diff --git a/docs/tutorials/notebooks/Monge_Gap.ipynb b/docs/tutorials/notebooks/Monge_Gap.ipynb index 50d1c7030..7ad3fcbb8 100644 --- a/docs/tutorials/notebooks/Monge_Gap.ipynb +++ b/docs/tutorials/notebooks/Monge_Gap.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -45,6 +46,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -70,6 +72,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -301,6 +304,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -485,7 +489,7 @@ " regularizer = None\n", " else:\n", " regularizer = jax.tree_util.Partial(\n", - " losses.monge_gap,\n", + " losses.monge_gap_from_samples,\n", " cost_fn=cost_fn,\n", " epsilon=EPSILON,\n", " **SINKHORN_KWARGS,\n", @@ -525,6 +529,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ diff --git a/src/ott/solvers/nn/losses.py b/src/ott/solvers/nn/losses.py index 265396016..82b869a86 100644 --- a/src/ott/solvers/nn/losses.py +++ b/src/ott/solvers/nn/losses.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, Optional, Tuple, Union +from typing import Any, Callable, Literal, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -20,12 +20,12 @@ from ott.geometry import costs, pointcloud from ott.solvers.linear import sinkhorn -__all__ = ["monge_gap"] +__all__ = ["monge_gap", "monge_gap_from_samples"] def monge_gap( - source: jnp.ndarray, - target: jnp.ndarray, + map_fn: Callable[[jnp.ndarray], jnp.ndarray], + reference_points: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, @@ -36,9 +36,10 @@ def monge_gap( ) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: r"""Monge gap regularizer :cite:`uscidda:23`. 
- For a cost function :math:`c` and an empirical reference :math:`\hat{\rho}_n` - defined by samples :math:`(x_i)_{i=1,\dots,n}`, the (entropic) Monge gap - of a vector field :math:`T` is defined as: + For a cost function :math:`c` and empirical reference measure + :math:`\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}`, the + (entropic) Monge gap of a map function + :math:`T:\mathbb{R}^d\rightarrow\mathbb{R}^d` is defined as: .. math:: \mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) @@ -48,35 +49,93 @@ def monge_gap( See :cite:`uscidda:23` Eq. (8). Args: - source: samples from the reference measure :math:`\rho`, - array of shape ``[n, d]``. - target: samples from the mapped reference measure :math:`T \sharp \rho` - mapped with :math:`T`, i.e. samples from :math:`T \sharp \rho`, - array of shape ``[n, d]``. + map_fn: Callable corresponding to map :math:`T` in definition above. The + callable should be vectorized (e.g. using :func:`jax.vmap`), i.e., + able to process a *batch* of vectors of size `d`, namely + ``map_fn`` applied to an array returns an array of the same shape. + reference_points: Array of `[n,d]` points, :math:`\hat\rho_n` in paper + cost_fn: An object of class :class:`~ott.geometry.costs.CostFn`. + epsilon: Regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud` + relative_epsilon: when `False`, the parameter ``epsilon`` specifies the + value of the entropic regularization parameter. When `True`, ``epsilon`` + refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is + computed adaptively from ``reference_points`` and their mapped images. + scale_cost: option to rescale the cost matrix. Implemented scalings are + 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be + given to rescale the cost such that ``cost_matrix /= scale_cost``. + If `True`, use 'mean'. 
+ return_output: boolean to also return the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` + kwargs: holds the kwargs to instantiate the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to + compute the regularized OT cost. + + Returns: + The Monge gap value and optionally the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` + """ + target = map_fn(reference_points) + return monge_gap_from_samples( + source=reference_points, + target=target, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + return_output=return_output, + **kwargs + ) + + +def monge_gap_from_samples( + source: jnp.ndarray, + target: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + relative_epsilon: Optional[bool] = None, + scale_cost: Union[bool, int, float, Literal["mean", "max_cost", + "median"]] = 1.0, + return_output: bool = False, + **kwargs: Any +) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: + r"""Monge gap, instantiated in terms of samples before / after applying map. + + .. math:: + + \frac{1}{n} \sum_{i=1}^n c(x_i, y_i) - + W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i}, + \frac{1}{n}\sum_i \delta_{y_i}) + + where :math:`W_{c, \varepsilon}` is an entropy-regularized optimal transport + cost, :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost` + + Args: + source: samples from first measure, array of shape ``[n, d]``. + target: samples from second measure, array of shape ``[n, d]``. cost_fn: a cost function between two points in dimension :math:`d`. If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. - epsilon: Regularization parameter. If ``scale_epsilon = None`` and either - ``relative_epsilon = True`` or ``relative_epsilon = None`` and - ``epsilon = None`` in :class:`~ott.geometry.epsilon_scheduler.Epsilon` - is used, ``scale_epsilon`` is the - :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`. - If ``epsilon = None``, use :math:`0.05`. 
+ epsilon: Regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud` relative_epsilon: when `False`, the parameter ``epsilon`` specifies the value of the entropic regularization parameter. When `True`, ``epsilon`` refers to a fraction of the :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is - computed adaptively from data. + computed adaptively using ``source`` and ``target`` points. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. If `True`, use 'mean'. - return_output: boolean to also return Sinkhorn output. + return_output: boolean to also return the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` kwargs: holds the kwargs to instantiate the or :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to compute the regularized OT cost. Returns: - The Monge gap value and optionally the Sinkhorn output. + The Monge gap value and optionally the + :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` """ cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn geom = pointcloud.PointCloud( diff --git a/src/ott/solvers/nn/models.py b/src/ott/solvers/nn/models.py index 6066cfd7f..227084619 100644 --- a/src/ott/solvers/nn/models.py +++ b/src/ott/solvers/nn/models.py @@ -58,11 +58,11 @@ class ModelBase(abc.ABC, nn.Module): @property @abc.abstractmethod def is_potential(self) -> bool: - """Indicates if the module defines the potential's value or the gradient. + """Indicates if the module implements a potential value or a vector field. Returns: - ``True`` if the module defines the potential's value, ``False`` - if it defines the gradient. + ``True`` if the module defines a potential, ``False`` if it defines a + vector field. """ def potential_value_fn( @@ -89,7 +89,8 @@ def potential_value_fn( potential. Only needed when :attr:`is_potential` is ``False``. 
Returns: - A function that can be evaluated to obtain the potential's value + A function that can be evaluated to obtain a potential value, or a linear + interpolation of a potential. """ if self.is_potential: return lambda x: self.apply({"params": params}, x) @@ -113,7 +114,7 @@ def potential_gradient_fn( self, params: frozen_dict.FrozenDict[str, jnp.ndarray], ) -> PotentialGradientFn_t: - """Return a function giving the gradient of the potential. + """Return a function returning a vector or the gradient of the potential. Args: params: parameters of the module @@ -302,7 +303,7 @@ def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 class MLP(ModelBase): - """A non-convex MLP. + """A generic, typically not-convex (w.r.t input) MLP. Args: dim_hidden: sequence specifying size of hidden dimensions. The output diff --git a/tests/solvers/nn/losses_test.py b/tests/solvers/nn/losses_test.py index 2225ad1b8..2b50df954 100644 --- a/tests/solvers/nn/losses_test.py +++ b/tests/solvers/nn/losses_test.py @@ -16,7 +16,7 @@ import numpy as np import pytest from ott.geometry import costs -from ott.solvers.nn import losses +from ott.solvers.nn import losses, models @pytest.mark.fast() @@ -30,25 +30,40 @@ def test_monge_gap_non_negativity( # generate data rng1, rng2 = jax.random.split(rng, 2) - source = jax.random.normal(rng1, (n_samples, n_features)) - target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3. 
+ reference_points = jax.random.normal(rng1, (n_samples, n_features)) + + model = models.MLP(dim_hidden=[8, 8], is_potential=False) + params = model.init(rng2, x=reference_points[0]) + target = model.apply(params, reference_points) - # compute the Monge gap - monge_gap_value = losses.monge_gap(source=source, target=target) + # compute the Monge gap based on samples + monge_gap_from_samples_value = losses.monge_gap_from_samples( + source=reference_points, target=target + ) + np.testing.assert_array_equal(monge_gap_from_samples_value >= 0, True) + + # Compute the Monge gap using model directly + monge_gap_value = losses.monge_gap( + map_fn=lambda x: model.apply(params, x), + reference_points=reference_points + ) np.testing.assert_array_equal(monge_gap_value >= 0, True) + np.testing.assert_array_equal(monge_gap_value, monge_gap_from_samples_value) + def test_monge_gap_jit(self, rng: jax.random.PRNGKey): n_samples, n_features = 31, 17 # generate data rng1, rng2 = jax.random.split(rng, 2) source = jax.random.normal(rng1, (n_samples, n_features)) - target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3. 
- + target = jax.random.normal(rng2, (n_samples, n_features)) # define jitted monge gap - jit_monge_gap = jax.jit(losses.monge_gap) + jit_monge_gap = jax.jit(losses.monge_gap_from_samples) # compute the Monge gaps for different costs - monge_gap_value = losses.monge_gap(source, target) + monge_gap_value = losses.monge_gap_from_samples( + source=source, target=target + ) jit_monge_gap_value = jit_monge_gap(source, target) np.testing.assert_allclose(monge_gap_value, jit_monge_gap_value, rtol=1e-3) @@ -67,7 +82,7 @@ def test_monge_gap_jit(self, rng: jax.random.PRNGKey): "stvs-gam2", ], ) - def test_monge_gap_different_cost( + def test_monge_gap_from_samples_different_cost( self, rng: jax.random.PRNGKeyArray, cost_fn: costs.CostFn, n_samples: int, n_features: int ): @@ -84,17 +99,24 @@ def test_monge_gap_different_cost( target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3. # compute the Monge gaps for the euclidean cost - monge_gap_value_eucl = losses.monge_gap( + monge_gap_from_samples_value_eucl = losses.monge_gap_from_samples( source=source, target=target, cost_fn=costs.Euclidean() ) - monge_gap_value_cost_fn = losses.monge_gap( + monge_gap_from_samples_value_cost_fn = losses.monge_gap_from_samples( source=source, target=target, cost_fn=cost_fn ) with pytest.raises(AssertionError, match=r"tolerance"): np.testing.assert_allclose( - monge_gap_value_eucl, monge_gap_value_cost_fn, rtol=1e-1, atol=1e-1 + monge_gap_from_samples_value_eucl, + monge_gap_from_samples_value_cost_fn, + rtol=1e-1, + atol=1e-1 ) - np.testing.assert_array_equal(np.isfinite(monge_gap_value_eucl), True) - np.testing.assert_array_equal(np.isfinite(monge_gap_value_cost_fn), True) + np.testing.assert_array_equal( + np.isfinite(monge_gap_from_samples_value_eucl), True + ) + np.testing.assert_array_equal( + np.isfinite(monge_gap_from_samples_value_cost_fn), True + ) diff --git a/tests/tools/map_estimator_test.py b/tests/tools/map_estimator_test.py index 364a1122a..5cbec8373 100644 --- 
a/tests/tools/map_estimator_test.py +++ b/tests/tools/map_estimator_test.py @@ -28,7 +28,7 @@ class TestMapEstimator: def test_map_estimator_convergence(self): """Tests convergence of a simple map estimator with Sinkhorn divergence fitting loss - and Monge gap regularizer. + and Monge (coupling) gap regularizer. """ # define the fitting loss and the regularizer @@ -43,7 +43,7 @@ def fitting_loss( y=mapped_samples, ).divergence - regularizer = losses.monge_gap + regularizer = losses.monge_gap_from_samples # define the model model = models.MLP(dim_hidden=[64, 32], is_potential=False)