diff --git a/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb b/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb index 3ce57c822..fffa06fd2 100644 --- a/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb +++ b/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb @@ -9,7 +9,7 @@ "\n", "This short tutorial covers a basic use case for {mod}`ott`:\n", "\n", - "- Compute a optimal transport distance between two point clouds using the {class}`~ott.geometry.point_cloud.PointCloud` geometry, solved using the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. \n", + "- Compute a optimal transport distance between two point clouds using the {class}`~ott.geometry.pointcloud.PointCloud` geometry, solved using the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. \n", "- Showcase the seamless integration with `JAX`, to differentiate through that cost and plot the gradient flow to morph the first point cloud into the second." ] }, @@ -367,7 +367,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.9" }, "vscode": { "interpreter": { diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 5e874f93b..e0712379a 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -175,17 +175,16 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: @jax.tree_util.register_pytree_node_class class SqPNorm(TICost): - """Squared p-norm of the difference of two vectors. + r"""Squared p-norm of the difference of two vectors. Args: - p: Power of the p-norm. + p: Power of the p-norm, :math:`\ge 1`. """ def __init__(self, p: float): super().__init__() - assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0" self.p = p - self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf + self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf def h(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * jnp.linalg.norm(z, self.p) ** 2 @@ -203,29 +202,28 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 del children - return cls(aux_data[0]) + return cls(*aux_data) @jax.tree_util.register_pytree_node_class class PNormP(TICost): - """p-norm to the power p (and divided by p) of the difference of two vectors. + r"""p-norm to the power p (and divided by p) of the difference of two vectors. Args: - p: Power of the p-norm, a finite float larger than 1.0. + p: Power of the p-norm in :math:`[1, +\infty)`. + Note that :func:`h_legendre` is not defined for ``p = 1``. """ def __init__(self, p: float): super().__init__() - assert p >= 1.0, "p parameter in p-norm should be larger than 1.0" - assert p < jnp.inf, "p parameter in p-norm should be finite" self.p = p - self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf + self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.linalg.norm(z, self.p) ** self.p / self.p def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 - assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`" + # not defined for `p=1` return jnp.linalg.norm(z, self.q) ** self.q / self.q def tree_flatten(self): # noqa: D102 @@ -234,7 +232,7 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 del children - return cls(aux_data[0]) + return cls(*aux_data) @jax.tree_util.register_pytree_node_class @@ -337,12 +335,11 @@ class ElasticL1(RegTICost): \frac{1}{2} \|\cdot\|_2^2 + \gamma \|\cdot\|_1 Args: - gamma: Strength of the :math:`\|\cdot\|_1` regularization. + gamma: Strength of the :math:`\|\cdot\|_1` regularization, :math:`\ge 0`. """ def __init__(self, gamma: float = 1.0): super().__init__() - assert gamma >= 0, "Gamma must be non-negative." self.gamma = gamma def reg(self, z: jnp.ndarray) -> float: # noqa: D102 @@ -352,12 +349,12 @@ def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma) def tree_flatten(self): # noqa: D102 - return (), (self.gamma,) + return (self.gamma,), None @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - del children - return cls(*aux_data) + del aux_data + return cls(*children) @jax.tree_util.register_pytree_node_class @@ -373,12 +370,11 @@ class ElasticSTVS(RegTICost): where :math:`\sigma(\cdot) := \text{asinh}\left(\frac{\cdot}{2\gamma}\right)` Args: - gamma: Strength of the STVS regularization. + gamma: Strength of the STVS regularization, :math:`> 0`. """ # noqa def __init__(self, gamma: float = 1.0): super().__init__() - assert gamma > 0, "Gamma must be positive." self.gamma = gamma def reg(self, z: jnp.ndarray) -> float: # noqa: D102 @@ -390,12 +386,12 @@ def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z def tree_flatten(self): # noqa: D102 - return (), (self.gamma,) + return (self.gamma,), None @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - del children - return cls(*aux_data) + del aux_data + return cls(*children) @jax.tree_util.register_pytree_node_class @@ -412,12 +408,11 @@ class ElasticSqKOverlap(RegTICost): Args: k: Number of groups. Must be in ``[0, d)`` where :math:`d` is the dimensionality of the data. - gamma: Strength of the squared k-overlap norm regularization. + gamma: Strength of the squared k-overlap norm regularization, :math:`> 0`. """ def __init__(self, k: int, gamma: float = 1.0): super().__init__() - assert gamma > 0, "Gamma must be positive." self.k = k self.gamma = gamma @@ -488,12 +483,11 @@ def inner(r: int, l: int, return sgn * q[jnp.argsort(z_ixs.astype(float))] def tree_flatten(self): # noqa: D102 - return (), (self.k, self.gamma) + return (self.gamma,), {"k": self.k} @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - del children - return cls(*aux_data) + return cls(**aux_data, gamma=children[0]) @jax.tree_util.register_pytree_node_class