changes in definition of Monge gap (#389)
* changes in definition of Monge gap

* fixes following review

* fixes

* typos
marcocuturi authored Jul 6, 2023
1 parent f55eb8e commit bfb9b51
Showing 5 changed files with 131 additions and 45 deletions.
7 changes: 6 additions & 1 deletion docs/tutorials/notebooks/Monge_Gap.ipynb
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -45,6 +46,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -70,6 +72,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -301,6 +304,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -485,7 +489,7 @@
" regularizer = None\n",
" else:\n",
" regularizer = jax.tree_util.Partial(\n",
" losses.monge_gap,\n",
" losses.monge_gap_from_samples,\n",
" cost_fn=cost_fn,\n",
" epsilon=EPSILON,\n",
" **SINKHORN_KWARGS,\n",
@@ -525,6 +529,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
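For reference, the notebook change above amounts to binding keyword arguments onto the new losses.monge_gap_from_samples with jax.tree_util.Partial. A minimal sketch of that pattern, assuming a toy dataset and an arbitrary epsilon standing in for the notebook's EPSILON and SINKHORN_KWARGS:

import jax

from ott.geometry import costs
from ott.solvers.nn import losses

rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
source = jax.random.normal(rng1, (16, 2))
target = jax.random.normal(rng2, (16, 2)) * 0.1 + 3.0

# Bind the cost and entropic regularization once; Partial keeps the result a
# valid pytree leaf, so the bound regularizer only needs (source, target).
regularizer = jax.tree_util.Partial(
    losses.monge_gap_from_samples,
    cost_fn=costs.SqEuclidean(),
    epsilon=1e-2,  # placeholder value, not the notebook's EPSILON
)
gap = regularizer(source=source, target=target)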
100 changes: 79 additions & 21 deletions src/ott/solvers/nn/losses.py
@@ -12,20 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal, Optional, Tuple, Union
from typing import Any, Callable, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ott.geometry import costs, pointcloud
from ott.solvers.linear import sinkhorn

__all__ = ["monge_gap"]
__all__ = ["monge_gap", "monge_gap_from_samples"]


def monge_gap(
source: jnp.ndarray,
target: jnp.ndarray,
map_fn: Callable[[jnp.ndarray], jnp.ndarray],
reference_points: jnp.ndarray,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
@@ -36,9 +36,10 @@ def monge_gap(
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
r"""Monge gap regularizer :cite:`uscidda:23`.
For a cost function :math:`c` and an empirical reference :math:`\hat{\rho}_n`
defined by samples :math:`(x_i)_{i=1,\dots,n}`, the (entropic) Monge gap
of a vector field :math:`T` is defined as:
For a cost function :math:`c` and empirical reference measure
:math:`\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}`, the
(entropic) Monge gap of a map function
:math:`T:\mathbb{R}^d\rightarrow\mathbb{R}^d` is defined as:
.. math::
\mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T)
@@ -48,35 +49,92 @@
See :cite:`uscidda:23` Eq. (8).
Args:
source: samples from the reference measure :math:`\rho`,
array of shape ``[n, d]``.
target: samples from the mapped reference measure :math:`T \sharp \rho`
mapped with :math:`T`, i.e. samples from :math:`T \sharp \rho`,
array of shape ``[n, d]``.
map_fn: Callable corresponding to the map :math:`T` in the definition above.
The callable should be vectorized (e.g. using :func:`jax.vmap`), i.e.,
able to process a *batch* of vectors of size ``d``: ``map_fn`` applied to
an ``[n, d]`` array returns an array of the same shape.
reference_points: Array of shape ``[n, d]`` of points defining the empirical
reference measure :math:`\hat\rho_n` in the definition above.
cost_fn: An object of class :class:`~ott.geometry.costs.CostFn`.
epsilon: Regularization parameter. See
:class:`~ott.geometry.pointcloud.PointCloud`
relative_epsilon: when `False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When `True`, ``epsilon``
refers to a fraction of the
:attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is
computed adaptively using ``source`` and ``target`` points.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
kwargs: holds the kwargs used to instantiate the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver that
computes the regularized OT cost.
Returns:
The Monge gap value and optionally the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
"""
target = map_fn(reference_points)
return monge_gap_from_samples(
source=reference_points,
target=target,
cost_fn=cost_fn,
epsilon=epsilon,
relative_epsilon=relative_epsilon,
scale_cost=scale_cost,
**kwargs
)


def monge_gap_from_samples(
source: jnp.ndarray,
target: jnp.ndarray,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
r"""Monge gap, instantiated in terms of samples before / after applying map.
.. math::
\frac{1}{n} \sum_{i=1}^n c(x_i, y_i) -
W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i},
\frac{1}{n}\sum_i \delta_{y_i})
where :math:`W_{c, \varepsilon}` is the entropy-regularized optimal transport
cost, see :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`.
Args:
source: samples from first measure, array of shape ``[n, d]``.
target: samples from second measure, array of shape ``[n, d]``.
cost_fn: a cost function between two points in dimension :math:`d`.
If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used.
epsilon: Regularization parameter. If ``scale_epsilon = None`` and either
``relative_epsilon = True`` or ``relative_epsilon = None`` and
``epsilon = None`` in :class:`~ott.geometry.epsilon_scheduler.Epsilon`
is used, ``scale_epsilon`` is the
:attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`.
If ``epsilon = None``, use :math:`0.05`.
epsilon: Regularization parameter. See
:class:`~ott.geometry.pointcloud.PointCloud`
relative_epsilon: when `False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When `True`, ``epsilon``
refers to a fraction of the
:attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`, which is
computed adaptively from data.
computed adaptively using ``source`` and ``target`` points.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return Sinkhorn output.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
kwargs: holds the kwargs used to instantiate the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver that
computes the regularized OT cost.
Returns:
The Monge gap value and optionally the Sinkhorn output.
The Monge gap value and optionally the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`
"""
cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
geom = pointcloud.PointCloud(
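In short, the refactor above leaves two entry points: monge_gap takes a map and reference points, while monge_gap_from_samples takes the samples before and after the map. A minimal sketch of the equivalence, using a toy affine map as a stand-in for a learned model:

import jax
import jax.numpy as jnp

from ott.solvers.nn import losses

rng = jax.random.PRNGKey(0)
reference_points = jax.random.normal(rng, (31, 7))

def toy_map(x: jnp.ndarray) -> jnp.ndarray:
  # Stand-in for a learned map T: applied to an [n, d] array it returns an
  # array of the same shape, as required by the map_fn argument.
  return 0.5 * x + 1.0

# New interface: pass the map and the reference samples.
gap = losses.monge_gap(map_fn=toy_map, reference_points=reference_points)

# Equivalent value, computed from samples before / after applying the map.
gap_from_samples = losses.monge_gap_from_samples(
    source=reference_points, target=toy_map(reference_points)
)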
13 changes: 7 additions & 6 deletions src/ott/solvers/nn/models.py
@@ -58,11 +58,11 @@ class ModelBase(abc.ABC, nn.Module):
@property
@abc.abstractmethod
def is_potential(self) -> bool:
"""Indicates if the module defines the potential's value or the gradient.
"""Indicates if the module implements a potential value or a vector field.
Returns:
``True`` if the module defines the potential's value, ``False``
if it defines the gradient.
``True`` if the module defines a potential, ``False`` if it defines a
vector field.
"""

def potential_value_fn(
@@ -89,7 +89,8 @@ def potential_value_fn(
potential. Only needed when :attr:`is_potential` is ``False``.
Returns:
A function that can be evaluated to obtain the potential's value
A function that can be evaluated to obtain a potential value, or a linear
interpolation of a potential.
"""
if self.is_potential:
return lambda x: self.apply({"params": params}, x)
@@ -113,7 +114,7 @@ def potential_gradient_fn(
self,
params: frozen_dict.FrozenDict[str, jnp.ndarray],
) -> PotentialGradientFn_t:
"""Return a function giving the gradient of the potential.
"""Return a function returning a vector or the gradient of the potential.
Args:
params: parameters of the module
@@ -302,7 +303,7 @@ def __call__(self, x: jnp.ndarray) -> float: # noqa: D102


class MLP(ModelBase):
"""A non-convex MLP.
"""A generic, typically not-convex (w.r.t input) MLP.
Args:
dim_hidden: sequence specifying size of hidden dimensions. The output
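To illustrate the reworded is_potential contract, a short sketch mirroring the usage in the tests below; hidden sizes and dimensions are arbitrary:

import jax

from ott.solvers.nn import models

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (4, 7))  # a small batch of 7-dimensional points

# With is_potential=False the MLP models a vector field R^d -> R^d directly:
# applied to a batch of points it returns an array of the same shape.
vector_field = models.MLP(dim_hidden=[8, 8], is_potential=False)
params = vector_field.init(rng, x=x[0])
mapped = vector_field.apply(params, x)
assert mapped.shape == x.shape

# With is_potential=True the MLP instead outputs a scalar potential per point,
# and the induced map is the gradient exposed via potential_gradient_fn.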
52 changes: 37 additions & 15 deletions tests/solvers/nn/losses_test.py
@@ -16,7 +16,7 @@
import numpy as np
import pytest
from ott.geometry import costs
from ott.solvers.nn import losses
from ott.solvers.nn import losses, models


@pytest.mark.fast()
@@ -30,25 +30,40 @@ def test_monge_gap_non_negativity(

# generate data
rng1, rng2 = jax.random.split(rng, 2)
source = jax.random.normal(rng1, (n_samples, n_features))
target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3.
reference_points = jax.random.normal(rng1, (n_samples, n_features))

model = models.MLP(dim_hidden=[8, 8], is_potential=False)
params = model.init(rng2, x=reference_points[0])
target = model.apply(params, reference_points)

# compute the Monge gap
monge_gap_value = losses.monge_gap(source=source, target=target)
# compute the Monge gap based on samples
monge_gap_from_samples_value = losses.monge_gap_from_samples(
source=reference_points, target=target
)
np.testing.assert_array_equal(monge_gap_from_samples_value >= 0, True)

# compute the Monge gap using the model directly
monge_gap_value = losses.monge_gap(
map_fn=lambda x: model.apply(params, x),
reference_points=reference_points
)
np.testing.assert_array_equal(monge_gap_value >= 0, True)

np.testing.assert_array_equal(monge_gap_value, monge_gap_from_samples_value)

def test_monge_gap_jit(self, rng: jax.random.PRNGKey):
n_samples, n_features = 31, 17
# generate data
rng1, rng2 = jax.random.split(rng, 2)
source = jax.random.normal(rng1, (n_samples, n_features))
target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3.

target = jax.random.normal(rng2, (n_samples, n_features))
# define jitted monge gap
jit_monge_gap = jax.jit(losses.monge_gap)
jit_monge_gap = jax.jit(losses.monge_gap_from_samples)

# compute the Monge gaps for different costs
monge_gap_value = losses.monge_gap(source, target)
monge_gap_value = losses.monge_gap_from_samples(
source=source, target=target
)
jit_monge_gap_value = jit_monge_gap(source, target)
np.testing.assert_allclose(monge_gap_value, jit_monge_gap_value, rtol=1e-3)

@@ -67,7 +82,7 @@ def test_monge_gap_jit(self, rng: jax.random.PRNGKey):
"stvs-gam2",
],
)
def test_monge_gap_different_cost(
def test_monge_gap_from_samples_different_cost(
self, rng: jax.random.PRNGKeyArray, cost_fn: costs.CostFn, n_samples: int,
n_features: int
):
@@ -84,17 +99,24 @@ def test_monge_gap_different_cost(
target = jax.random.normal(rng2, (n_samples, n_features)) * .1 + 3.

# compute the Monge gaps for the euclidean cost
monge_gap_value_eucl = losses.monge_gap(
monge_gap_from_samples_value_eucl = losses.monge_gap_from_samples(
source=source, target=target, cost_fn=costs.Euclidean()
)
monge_gap_value_cost_fn = losses.monge_gap(
monge_gap_from_samples_value_cost_fn = losses.monge_gap_from_samples(
source=source, target=target, cost_fn=cost_fn
)

with pytest.raises(AssertionError, match=r"tolerance"):
np.testing.assert_allclose(
monge_gap_value_eucl, monge_gap_value_cost_fn, rtol=1e-1, atol=1e-1
monge_gap_from_samples_value_eucl,
monge_gap_from_samples_value_cost_fn,
rtol=1e-1,
atol=1e-1
)

np.testing.assert_array_equal(np.isfinite(monge_gap_value_eucl), True)
np.testing.assert_array_equal(np.isfinite(monge_gap_value_cost_fn), True)
np.testing.assert_array_equal(
np.isfinite(monge_gap_from_samples_value_eucl), True
)
np.testing.assert_array_equal(
np.isfinite(monge_gap_from_samples_value_cost_fn), True
)
4 changes: 2 additions & 2 deletions tests/tools/map_estimator_test.py
@@ -28,7 +28,7 @@ class TestMapEstimator:
def test_map_estimator_convergence(self):
"""Tests convergence of a simple
map estimator with Sinkhorn divergence fitting loss
and Monge gap regularizer.
and Monge (coupling) gap regularizer.
"""

# define the fitting loss and the regularizer
Expand All @@ -43,7 +43,7 @@ def fitting_loss(
y=mapped_samples,
).divergence

regularizer = losses.monge_gap
regularizer = losses.monge_gap_from_samples

# define the model
model = models.MLP(dim_hidden=[64, 32], is_potential=False)
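The test above pairs a Sinkhorn divergence fitting term with the sample-based Monge gap regularizer. A hedged sketch of how such a combined objective could be written by hand, outside the map estimator wrapper; the weight lam and the argument names are illustrative placeholders, not part of this diff:

import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.solvers.nn import losses
from ott.tools import sinkhorn_divergence

def total_loss(
    source: jnp.ndarray,  # samples from the source measure, shape [n, d]
    target: jnp.ndarray,  # samples from the target measure, shape [m, d]
    mapped_source: jnp.ndarray,  # source samples pushed through the map T
    lam: float = 1.0,  # placeholder regularization strength
) -> jnp.ndarray:
  # Fitting term: Sinkhorn divergence between the target samples and the
  # mapped source samples.
  fitting = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x=target, y=mapped_source
  ).divergence
  # Regularizer: Monge gap computed from samples before / after the map.
  reg = losses.monge_gap_from_samples(source=source, target=mapped_source)
  return fitting + lam * reg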
