Skip to content

Commit

Permalink
Maint/third lower bound (#507)
Browse files Browse the repository at this point in the history
* Use `jax.debug.callback`

* Update tutorial

* Clean `LowerBoundSolver`
  • Loading branch information
michalk8 authored Mar 27, 2024
1 parent 14d4b81 commit b6ea832
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 73 deletions.
2 changes: 1 addition & 1 deletion docs/solvers/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Gromov-Wasserstein Solvers
gromov_wasserstein.GWOutput
gromov_wasserstein_lr.LRGromovWasserstein
gromov_wasserstein_lr.LRGWOutput
lower_bound.LowerBoundSolver
lower_bound.third_lower_bound


Barycenter Solvers
Expand Down
92 changes: 27 additions & 65 deletions src/ott/solvers/quadratic/lower_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional

import jax
import jax.tree_util as jtu

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers import linear
Expand All @@ -24,73 +21,38 @@
if TYPE_CHECKING:
from ott.geometry import distrib_costs

__all__ = ["LowerBoundSolver"]

__all__ = ["third_lower_bound"]

@jtu.register_pytree_node_class
class LowerBoundSolver:
"""Lower bound OT solver.

Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3.
def third_lower_bound(
prob: quadratic_problem.QuadraticProblem,
distrib_cost: Optional["distrib_costs.UnivariateWasserstein"] = None,
epsilon: Optional[float] = None,
**kwargs: Any,
) -> sinkhorn.SinkhornOutput:
"""Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3.
Args:
epsilon: Entropy regularization for the resulting linear problem.
distrib_cost: Univariate Wasserstein cost, used to compare two point clouds
in different spaces, where each point is seen as its distribution of costs
to other points in its point cloud.
prob: Quadratic OT problem.
distrib_cost: Univariate Wasserstein cost used to compare two point clouds
in different spaces. Each point is seen as its distribution of costs
to other points in its respective point cloud.
epsilon: Entropy regularization.
kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.
Returns:
An approximation of the GW coupling that can be used to initialize
the solution of the quadratic OT problem.
"""
from ott.geometry import distrib_costs

def __init__(
self,
epsilon: Optional[float] = None,
distrib_cost: Optional["distrib_costs.UnivariateWasserstein"] = None,
):
from ott.geometry import distrib_costs

self.epsilon = epsilon
self.distrib_cost = (
distrib_costs.UnivariateWasserstein()
if distrib_cost is None else distrib_cost
)

def __call__(
self,
prob: quadratic_problem.QuadraticProblem,
epsilon: Optional[float] = None,
rng: Optional[jax.Array] = None,
**kwargs: Any
) -> sinkhorn.SinkhornOutput:
"""Compute a lower-bound for the GW problem using a simple linearization.
This solver handles a quadratic problem by computing a proxy ``[n, m]``
cost-matrix, injecting it into a linear OT solver to output a first an OT
matrix that can be used either to linearize/initialize the resolution
ot the GW problem, or more simply as a simple GW solution.
Args:
prob: Quadratic OT problem.
epsilon: Entropic regularization passed on to solve the linearization of
the quadratic problem using 1D costs.
rng: Random key, possibly used when computing 1D costs when using
subsampling.
kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.
Returns:
A linear OT output, an approximation of the OT coupling obtained using
the lower bound provided by :cite:`memoli:11`.
"""
dists_xx = prob.geom_xx.cost_matrix
dists_yy = prob.geom_yy.cost_matrix

geom_xy = pointcloud.PointCloud(
dists_xx, dists_yy, cost_fn=self.distrib_cost, epsilon=self.epsilon
)
return linear.solve(geom_xy, **kwargs)
if distrib_cost is None:
distrib_cost = distrib_costs.UnivariateWasserstein()

def tree_flatten(self): # noqa: D102
return (self.epsilon, self.distrib_cost), None
dists_xx = prob.geom_xx.cost_matrix
dists_yy = prob.geom_yy.cost_matrix
geom_xy = pointcloud.PointCloud(
dists_xx, dists_yy, cost_fn=distrib_cost, epsilon=epsilon
)

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)
return linear.solve(geom_xy, **kwargs)
11 changes: 4 additions & 7 deletions tests/solvers/quadratic/lower_bound_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ott.solvers.quadratic import lower_bound


class TestLowerBoundSolver:
class TestLowerBound:

@pytest.fixture(autouse=True)
def initialize(self, rng: jax.Array):
Expand All @@ -30,7 +30,6 @@ def initialize(self, rng: jax.Array):
rngs = jax.random.split(rng, 4)
self.x = jax.random.uniform(rngs[0], (self.n, d_x))
self.y = jax.random.uniform(rngs[1], (self.m, d_y))
# Currently the Lower Bound only supports uniform distributions:
a = jnp.ones(self.n)
b = jnp.ones(self.m)
self.a = a / jnp.sum(a)
Expand All @@ -52,10 +51,8 @@ def test_lb_pointcloud(self, ground_cost: costs.TICost):
geom_x, geom_y, a=self.a, b=self.b
)
distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost)
solver = lower_bound.LowerBoundSolver(
epsilon=1e-1, distrib_cost=distrib_cost
)

out = jax.jit(solver)(prob)
out = jax.jit(lower_bound.third_lower_bound
)(prob, distrib_cost, epsilon=1e-1)

assert not jnp.isnan(out.reg_ot_cost)
assert jnp.isfinite(out.reg_ot_cost)

0 comments on commit b6ea832

Please sign in to comment.