Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove assertions in cost functions #340

Merged
merged 7 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/tutorials/notebooks/basic_ot_between_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"\n",
"This short tutorial covers a basic use case for {mod}`ott`:\n",
"\n",
"- Compute a optimal transport distance between two point clouds using the {class}`~ott.geometry.point_cloud.PointCloud` geometry, solved using the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. \n",
"- Compute a optimal transport distance between two point clouds using the {class}`~ott.geometry.pointcloud.PointCloud` geometry, solved using the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. \n",
"- Showcase the seamless integration with `JAX`, to differentiate through that cost and plot the gradient flow to morph the first point cloud into the second."
]
},
Expand Down Expand Up @@ -367,7 +367,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.9"
},
"vscode": {
"interpreter": {
Expand Down
48 changes: 21 additions & 27 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,16 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:

@jax.tree_util.register_pytree_node_class
class SqPNorm(TICost):
"""Squared p-norm of the difference of two vectors.
r"""Squared p-norm of the difference of two vectors.

Args:
p: Power of the p-norm.
p: Power of the p-norm, :math:`\ge 1`.
"""

def __init__(self, p: float):
super().__init__()
assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0"
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf
self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.5 * jnp.linalg.norm(z, self.p) ** 2
Expand All @@ -203,29 +202,28 @@ def tree_flatten(self): # noqa: D102
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0])
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
class PNormP(TICost):
"""p-norm to the power p (and divided by p) of the difference of two vectors.
r"""p-norm to the power p (and divided by p) of the difference of two vectors.

Args:
p: Power of the p-norm, a finite float larger than 1.0.
p: Power of the p-norm in :math:`[1, +\infty)`.
Note that :func:`h_legendre` is not defined for ``p = 1``.
"""

def __init__(self, p: float):
super().__init__()
assert p >= 1.0, "p parameter in p-norm should be larger than 1.0"
assert p < jnp.inf, "p parameter in p-norm should be finite"
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf
self.q = 1.0 / (1.0 - (1.0 / p)) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.linalg.norm(z, self.p) ** self.p / self.p

def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`"
# not defined for `p=1`
return jnp.linalg.norm(z, self.q) ** self.q / self.q

def tree_flatten(self): # noqa: D102
Expand All @@ -234,7 +232,7 @@ def tree_flatten(self): # noqa: D102
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0])
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
Expand Down Expand Up @@ -337,12 +335,11 @@ class ElasticL1(RegTICost):
\frac{1}{2} \|\cdot\|_2^2 + \gamma \|\cdot\|_1

Args:
gamma: Strength of the :math:`\|\cdot\|_1` regularization.
gamma: Strength of the :math:`\|\cdot\|_1` regularization, :math:`\ge 0`.
"""

def __init__(self, gamma: float = 1.0):
super().__init__()
assert gamma >= 0, "Gamma must be non-negative."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float: # noqa: D102
Expand All @@ -352,12 +349,12 @@ def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma)

def tree_flatten(self): # noqa: D102
return (), (self.gamma,)
return (self.gamma,), None

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)
del aux_data
return cls(*children)


@jax.tree_util.register_pytree_node_class
Expand All @@ -373,12 +370,11 @@ class ElasticSTVS(RegTICost):
where :math:`\sigma(\cdot) := \text{asinh}\left(\frac{\cdot}{2\gamma}\right)`

Args:
gamma: Strength of the STVS regularization.
gamma: Strength of the STVS regularization, :math:`> 0`.
""" # noqa

def __init__(self, gamma: float = 1.0):
super().__init__()
assert gamma > 0, "Gamma must be positive."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float: # noqa: D102
Expand All @@ -390,12 +386,12 @@ def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z

def tree_flatten(self): # noqa: D102
return (), (self.gamma,)
return (self.gamma,), None

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)
del aux_data
return cls(*children)


@jax.tree_util.register_pytree_node_class
Expand All @@ -412,12 +408,11 @@ class ElasticSqKOverlap(RegTICost):
Args:
k: Number of groups. Must be in ``[0, d)`` where :math:`d` is the
dimensionality of the data.
gamma: Strength of the squared k-overlap norm regularization.
gamma: Strength of the squared k-overlap norm regularization, :math:`> 0`.
"""

def __init__(self, k: int, gamma: float = 1.0):
super().__init__()
assert gamma > 0, "Gamma must be positive."
self.k = k
self.gamma = gamma

Expand Down Expand Up @@ -488,12 +483,11 @@ def inner(r: int, l: int,
return sgn * q[jnp.argsort(z_ixs.astype(float))]

def tree_flatten(self): # noqa: D102
return (), (self.k, self.gamma)
return (self.gamma,), {"k": self.k}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)
return cls(**aux_data, gamma=children[0])


@jax.tree_util.register_pytree_node_class
Expand Down