diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst index e7b08327e..ac075ab04 100644 --- a/docs/solvers/quadratic.rst +++ b/docs/solvers/quadratic.rst @@ -17,7 +17,7 @@ Gromov-Wasserstein Solvers gromov_wasserstein.GWOutput gromov_wasserstein_lr.LRGromovWasserstein gromov_wasserstein_lr.LRGWOutput - lower_bound.LowerBoundSolver + lower_bound.third_lower_bound Barycenter Solvers diff --git a/src/ott/solvers/quadratic/lower_bound.py b/src/ott/solvers/quadratic/lower_bound.py index f0868ad2a..ed2ac36c1 100644 --- a/src/ott/solvers/quadratic/lower_bound.py +++ b/src/ott/solvers/quadratic/lower_bound.py @@ -13,9 +13,6 @@ # limitations under the License. from typing import TYPE_CHECKING, Any, Optional -import jax -import jax.tree_util as jtu - from ott.geometry import pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers import linear @@ -24,73 +21,38 @@ if TYPE_CHECKING: from ott.geometry import distrib_costs -__all__ = ["LowerBoundSolver"] - +__all__ = ["third_lower_bound"] -@jtu.register_pytree_node_class -class LowerBoundSolver: - """Lower bound OT solver. - Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3. +def third_lower_bound( + prob: quadratic_problem.QuadraticProblem, + distrib_cost: Optional["distrib_costs.UnivariateWasserstein"] = None, + epsilon: Optional[float] = None, + **kwargs: Any, +) -> sinkhorn.SinkhornOutput: + """Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3. Args: - epsilon: Entropy regularization for the resulting linear problem. - distrib_cost: Univariate Wasserstein cost, used to compare two point clouds - in different spaces, where each point is seen as its distribution of costs - to other points in its point cloud. + prob: Quadratic OT problem. + distrib_cost: Univariate Wasserstein cost used to compare two point clouds + in different spaces. Each point is seen as its distribution of costs + to other points in its respective point cloud. + epsilon: Entropy regularization. + kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`. + + Returns: + An approximation of the GW coupling that can be used to initialize + the solution of the quadratic OT problem. """ + from ott.geometry import distrib_costs - def __init__( - self, - epsilon: Optional[float] = None, - distrib_cost: Optional["distrib_costs.UnivariateWasserstein"] = None, - ): - from ott.geometry import distrib_costs - - self.epsilon = epsilon - self.distrib_cost = ( - distrib_costs.UnivariateWasserstein() - if distrib_cost is None else distrib_cost - ) - - def __call__( - self, - prob: quadratic_problem.QuadraticProblem, - epsilon: Optional[float] = None, - rng: Optional[jax.Array] = None, - **kwargs: Any - ) -> sinkhorn.SinkhornOutput: - """Compute a lower-bound for the GW problem using a simple linearization. - - This solver handles a quadratic problem by computing a proxy ``[n, m]`` - cost-matrix, injecting it into a linear OT solver to output a first an OT - matrix that can be used either to linearize/initialize the resolution - ot the GW problem, or more simply as a simple GW solution. - - Args: - prob: Quadratic OT problem. - epsilon: Entropic regularization passed on to solve the linearization of - the quadratic problem using 1D costs. - rng: Random key, possibly used when computing 1D costs when using - subsampling. - kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`. - - Returns: - A linear OT output, an approximation of the OT coupling obtained using - the lower bound provided by :cite:`memoli:11`. - """ - dists_xx = prob.geom_xx.cost_matrix - dists_yy = prob.geom_yy.cost_matrix - - geom_xy = pointcloud.PointCloud( - dists_xx, dists_yy, cost_fn=self.distrib_cost, epsilon=self.epsilon - ) - return linear.solve(geom_xy, **kwargs) + if distrib_cost is None: + distrib_cost = distrib_costs.UnivariateWasserstein() - def tree_flatten(self): # noqa: D102 - return (self.epsilon, self.distrib_cost), None + dists_xx = prob.geom_xx.cost_matrix + dists_yy = prob.geom_yy.cost_matrix + geom_xy = pointcloud.PointCloud( + dists_xx, dists_yy, cost_fn=distrib_cost, epsilon=epsilon + ) - @classmethod - def tree_unflatten(cls, aux_data, children): # noqa: D102 - del aux_data - return cls(*children) + return linear.solve(geom_xy, **kwargs) diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 2766e564d..7e8a7a160 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -20,7 +20,7 @@ from ott.solvers.quadratic import lower_bound -class TestLowerBoundSolver: +class TestLowerBound: @pytest.fixture(autouse=True) def initialize(self, rng: jax.Array): @@ -30,7 +30,6 @@ def initialize(self, rng: jax.Array): rngs = jax.random.split(rng, 4) self.x = jax.random.uniform(rngs[0], (self.n, d_x)) self.y = jax.random.uniform(rngs[1], (self.m, d_y)) - # Currently the Lower Bound only supports uniform distributions: a = jnp.ones(self.n) b = jnp.ones(self.m) self.a = a / jnp.sum(a) @@ -52,10 +51,8 @@ def test_lb_pointcloud(self, ground_cost: costs.TICost): geom_x, geom_y, a=self.a, b=self.b ) distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost) - solver = lower_bound.LowerBoundSolver( - epsilon=1e-1, distrib_cost=distrib_cost - ) - out = jax.jit(solver)(prob) + out = jax.jit(lower_bound.third_lower_bound + )(prob, distrib_cost, epsilon=1e-1) - assert not jnp.isnan(out.reg_ot_cost) + assert jnp.isfinite(out.reg_ot_cost)