diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 67b3edae1..dcbccf36d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -45,10 +45,12 @@ jobs:
     steps:
     - uses: actions/checkout@v3
     - name: Install dependencies
-      # `jax[cuda]<0.4` because of: https://github.com/google/jax/issues/13758
+      # `jax[cuda]<0.4` because of Docker issues: https://github.com/google/jax/issues/13758
+      # `flax<0.6.5` because `flax>=0.6.5` requires `jax>=0.4.2`
       run: |
         python3 -m pip install --upgrade pip
         python3 -m pip install -e".[test]"
+        python3 -m pip install "flax<0.6.5"
         python3 -m pip install "jax[cuda]<0.4" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
     - name: Nvidia SMI
diff --git a/docs/geometry.rst b/docs/geometry.rst
index 85190a545..50cc34d93 100644
--- a/docs/geometry.rst
+++ b/docs/geometry.rst
@@ -52,8 +52,6 @@ Cost Functions
 .. autosummary::
     :toctree: _autosummary

-    costs.CostFn
-    costs.TICost
     costs.SqPNorm
     costs.PNormP
     costs.SqEuclidean
@@ -61,6 +59,9 @@
     costs.Cosine
     costs.Bures
     costs.UnbalancedBures
+    costs.ElasticL1
+    costs.ElasticSTVS
+    costs.ElasticSqKOverlap

 Utilities
 ---------
diff --git a/docs/index.rst b/docs/index.rst
index cc355c27c..6b6ab3780 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -15,13 +15,13 @@ as differentiable approximations to ranking or even clustering.

 To achieve this, ``OTT`` rests on two families of tools:

-  - the first family consists in *discrete* solvers computing transport between
-    point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
-    :cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
-    :cite:`memoli:11,peyre:16`;
-  - the second family consists in *continuous* solvers, using suitable neural
-    architectures :cite:`amos:17` coupled with SGD type estimators
-    :cite:`makkuva:20,korotin:21`.
+- the first family consists of *discrete* solvers computing transport between
+  point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
+  :cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
+  :cite:`memoli:11,peyre:16`;
+- the second family consists of *continuous* solvers, using suitable neural
+  architectures :cite:`amos:17` coupled with SGD-type estimators
+  :cite:`makkuva:20,korotin:21`.

 Installation
 ------------
diff --git a/docs/references.bib b/docs/references.bib
index 00cf4996a..00f86a5bc 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -703,3 +703,40 @@ @ARTICLE{chen:20
   pages={2133-2147},
   doi={10.1109/TPAMI.2019.2908635}
 }
+
+@ARTICLE{schreck:15,
+  author={Schreck, Amandine and Fort, Gersende and Le Corff, Sylvain and Moulines, Eric},
+  journal={IEEE Journal of Selected Topics in Signal Processing},
+  title={A Shrinkage-Thresholding Metropolis Adjusted Langevin Algorithm for Bayesian Variable Selection},
+  year={2016},
+  volume={10},
+  number={2},
+  pages={366-375},
+  doi={10.1109/JSTSP.2015.2496546}
+}
+
+@inproceedings{argyriou:12,
+  author = {Argyriou, Andreas and Foygel, Rina and Srebro, Nathan},
+  booktitle = {Advances in Neural Information Processing Systems},
+  editor = {F. Pereira and C.J. Burges and L. Bottou and K.Q. Weinberger},
+  pages = {},
+  publisher = {Curran Associates, Inc.},
+  title = {Sparse Prediction with the k-Support Norm},
+  url = {https://proceedings.neurips.cc/paper/2012/file/99bcfcd754a98ce89cb86f73acc04645-Paper.pdf},
+  volume = {25},
+  year = {2012}
+}
+
+@article{zou:05,
+  author = {Zou, Hui and Hastie, Trevor},
+  title = {Regularization and variable selection via the elastic net},
+  journal = {Journal of the Royal Statistical Society: Series B (Statistical Methodology)},
+  volume = {67},
+  number = {2},
+  pages = {301-320},
+  keywords = {Grouping effect, LARS algorithm, Lasso, Penalization, p≫n problem, Variable selection},
+  doi = {10.1111/j.1467-9868.2005.00503.x},
+  url = {https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.1467-9868.2005.00503.x},
+  eprint = {https://rss.onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-9868.2005.00503.x},
+  year = {2005}
+}
diff --git a/pyproject.toml b/pyproject.toml
index 261ee4ac1..2b98fec77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,7 @@ classifiers = [
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
 ]

 [project.urls]
diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py
index 2de72f682..09838a4ef 100644
--- a/src/ott/geometry/costs.py
+++ b/src/ott/geometry/costs.py
@@ -23,14 +23,14 @@
 from ott.math import fixed_point_loop, matrix_square_root

 __all__ = [
-    "PNormP", "SqPNorm", "Euclidean", "SqEuclidean", "Cosine", "Bures",
-    "UnbalancedBures"
+    "PNormP", "SqPNorm", "Euclidean", "SqEuclidean", "Cosine", "ElasticL1",
+    "ElasticSTVS", "ElasticSqKOverlap", "Bures", "UnbalancedBures"
 ]


 @jax.tree_util.register_pytree_node_class
 class CostFn(abc.ABC):
-  """A generic cost function, taking two vectors as input.
+  """Base class for all costs.

   Cost functions evaluate a function on a pair of inputs. For convenience,
   that function is split into two norms -- evaluated on each input separately --
   followed by a pairwise cost that involves both inputs, as in:

   ``c(x,y) = norm(x) + norm(y) + pairwise(x,y)``

   If the :attr:`norm` function is not implemented, that value is handled as a 0,
-  and only :attr:`pairwise` is used.
+  and only :func:`pairwise` is used.
   """

   # no norm function created by default.
   norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None

   @abc.abstractmethod
-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     pass

   def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
@@ -59,7 +59,7 @@ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
     Returns:
       The barycenter of `xs` using `weights` coefficients.
     """
-    raise NotImplementedError("Barycenter is not yet implemented.")
+    raise NotImplementedError("Barycenter is not implemented.")

   @classmethod
   def _padder(cls, dim: int) -> jnp.ndarray:
@@ -73,7 +73,7 @@ def _padder(cls, dim: int) -> jnp.ndarray:
     """
     return jnp.zeros((1, dim))

-  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     cost = self.pairwise(x, y)
     if self.norm is None:
       return cost
@@ -112,7 +112,7 @@ def tree_unflatten(cls, aux_data, children):

 @jax.tree_util.register_pytree_node_class
 class TICost(CostFn):
-  """A class for translation invariant (TI) costs.
+  """Base class for translation invariant (TI) costs.
   Such costs are defined using a function :math:`h`, mapping vectors to
   real-values, to be used as:
@@ -135,7 +135,7 @@ def h_legendre(self, z: jnp.ndarray) -> float:
     """Legendre transform of :func:`h` when it is convex."""
     raise NotImplementedError("`h_legendre` not implemented.")

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute cost as evaluation of :func:`h` on :math:`x-y`."""
     return self.h(x - y)

@@ -144,9 +144,6 @@
 class SqPNorm(TICost):
   """Squared p-norm of the difference of two vectors.

-  For details on the derivation of the Legendre transform of the norm, see e.g.
-  the reference :cite:`boyd:04`, p.93/94.
-
   Args:
     p: Power of the p-norm.
   """
@@ -161,6 +158,10 @@ def h(self, z: jnp.ndarray) -> float:
     return 0.5 * jnp.linalg.norm(z, self.p) ** 2

   def h_legendre(self, z: jnp.ndarray) -> float:
+    """Legendre transform of :func:`h`.
+
+    For details on the derivation, see e.g., :cite:`boyd:04`, p. 93/94.
+    """
     return 0.5 * jnp.linalg.norm(z, self.q) ** 2

   def tree_flatten(self):
@@ -213,7 +214,7 @@ class Euclidean(CostFn):
   because the function is not strictly convex (it is linear on rays).
   """

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute Euclidean norm."""
     return jnp.linalg.norm(x - y)

@@ -226,7 +227,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
     """Compute squared Euclidean norm for vector."""
     return jnp.sum(x ** 2, axis=-1)

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute minus twice the dot-product between vectors."""
     return -2. * jnp.vdot(x, y)

@@ -253,14 +254,14 @@ def __init__(self, ridge: float = 1e-8):
     super().__init__()
     self._ridge = ridge

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Cosine distance between vectors, denominator regularized with ridge."""
     ridge = self._ridge
     x_norm = jnp.linalg.norm(x, axis=-1)
     y_norm = jnp.linalg.norm(y, axis=-1)
     cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
     cosine_distance = 1.0 - cosine_similarity
-    # similarity is in [-1, 1], clip because of numerical imprecisions
+    # similarity is in [-1, 1], clip because of numerical imprecision
     return jnp.clip(cosine_distance, 0., 2.)

   @classmethod
@@ -268,13 +269,207 @@ def _padder(cls, dim: int) -> jnp.ndarray:
     return jnp.ones((1, dim))


+class RegTICost(TICost, abc.ABC):
+  r"""Base class for regularized translation-invariant costs.
+
+  .. math::
+
+    \frac{1}{2} \|\cdot\|_2^2 + \text{reg}\left(\cdot\right)
+
+  where :func:`reg` is the regularization function.
+  """
+
+  @abc.abstractmethod
+  def reg(self, z: jnp.ndarray) -> float:
+    """Regularization function."""
+
+  def prox_reg(self, z: jnp.ndarray) -> jnp.ndarray:
+    """Proximal operator of :func:`reg`."""
+    raise NotImplementedError("Proximal operator is not implemented.")
+
+  def h(self, z: jnp.ndarray) -> float:
+    return 0.5 * jnp.linalg.norm(z, ord=2) ** 2 + self.reg(z)
+
+  def h_legendre(self, z: jnp.ndarray) -> float:
+    q = jax.lax.stop_gradient(self.prox_reg(z))
+    return jnp.sum(q * z) - self.h(q)
+
+
+@jax.tree_util.register_pytree_node_class
+class ElasticL1(RegTICost):
+  r"""Cost inspired by elastic net :cite:`zou:05` regularization.
+
+  .. math::
+
+    \frac{1}{2} \|\cdot\|_2^2 + \gamma \|\cdot\|_1
+
+  Args:
+    gamma: Strength of the :math:`\|\cdot\|_1` regularization.
+  """
+
+  def __init__(self, gamma: float = 1.0):
+    super().__init__()
+    assert gamma >= 0, "Gamma must be non-negative."
+    self.gamma = gamma
+
+  def reg(self, z: jnp.ndarray) -> float:
+    return self.gamma * jnp.linalg.norm(z, ord=1)
+
+  def prox_reg(self, z: jnp.ndarray) -> jnp.ndarray:
+    return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma)
+
+  def tree_flatten(self):
+    return (), (self.gamma,)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    del children
+    return cls(*aux_data)
+
+
+@jax.tree_util.register_pytree_node_class
+class ElasticSTVS(RegTICost):
+  r"""Cost using the soft thresholding operator with vanishing shrinkage (STVS)
+  :cite:`schreck:15` regularization.
+
+  .. math::
+
+    \frac{1}{2} \|\cdot\|_2^2 + \gamma^2\mathbf{1}_d^T\left(\sigma(\cdot) -
+    \frac{1}{2} \exp\left(-2\sigma(\cdot)\right) + \frac{1}{2}\right)
+
+  where :math:`\sigma(\cdot) := \text{asinh}\left(\frac{\cdot}{2\gamma}\right)`
+
+  Args:
+    gamma: Strength of the STVS regularization.
+  """  # noqa
+
+  def __init__(self, gamma: float = 1.0):
+    super().__init__()
+    assert gamma > 0, "Gamma must be positive."
+    self.gamma = gamma
+
+  def reg(self, z: jnp.ndarray) -> float:
+    u = jnp.arcsinh(jnp.abs(z) / (2 * self.gamma))
+    out = u - 0.5 * jnp.exp(-2.0 * u)
+    return (self.gamma ** 2) * jnp.sum(out + 0.5)  # make positive
+
+  def prox_reg(self, z: jnp.ndarray) -> jnp.ndarray:
+    return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z
+
+  def tree_flatten(self):
+    return (), (self.gamma,)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    del children
+    return cls(*aux_data)
+
+
+@jax.tree_util.register_pytree_node_class
+class ElasticSqKOverlap(RegTICost):
+  r"""Cost with squared k-overlap norm regularization :cite:`argyriou:12`.
+
+  .. math::
+
+    \frac{1}{2} \|\cdot\|_2^2 + \frac{1}{2} \gamma \|\cdot\|_{ovk}^2
+
+  where :math:`\|\cdot\|_{ovk}^2` is the squared k-overlap norm,
+  see def. 2.1 of :cite:`argyriou:12`.
+
+  Args:
+    k: Number of groups. Must be in ``[0, d)`` where :math:`d` is the
+      dimensionality of the data.
+    gamma: Strength of the squared k-overlap norm regularization.
+  """

+  def __init__(self, k: int, gamma: float = 1.0):
+    super().__init__()
+    assert gamma > 0, "Gamma must be positive."
+    self.k = k
+    self.gamma = gamma
+
+  def reg(self, z: jnp.ndarray) -> float:
+    # Prop 2.1 in :cite:`argyriou:12`
+    k = self.k
+    top_w = jax.lax.top_k(jnp.abs(z), k)[0]  # Fetch largest k values
+    top_w = jnp.flip(top_w)  # Sort k-largest from smallest to largest
+    # sum (dim - k) smallest values
+    sum_bottom = jnp.sum(jnp.abs(z)) - jnp.sum(top_w)
+    cumsum_top = jnp.cumsum(top_w)
+    # Cesaro mean of top_w (each term offset with sum_bottom).
+    cesaro = sum_bottom + cumsum_top
+    cesaro /= jnp.arange(k) + 1
+    # Choose first index satisfying constraint in Prop 2.1
+    lower_bound = cesaro - top_w >= 0
+    # Last upper bound is always True.
+    upper_bound = jnp.concatenate(((top_w[1:] - cesaro[:-1] > 0),
+                                   jnp.array((True,))))
+    r = jnp.argmax(lower_bound * upper_bound)
+    s = jnp.sum(jnp.where(jnp.arange(k) < k - r - 1, jnp.flip(top_w) ** 2, 0))
+
+    return 0.5 * self.gamma * (s + (r + 1) * cesaro[r] ** 2)
+
+  def prox_reg(self, z: jnp.ndarray) -> jnp.ndarray:
+
+    @functools.partial(jax.vmap, in_axes=[0, None, None])
+    def find_indices(r: int, l: jnp.ndarray,
+                     z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+
+      @functools.partial(jax.vmap, in_axes=[None, 0, None])
+      def inner(r: int, l: int,
+                z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        i = k - r - 1
+        res = jnp.sum(z * ((i <= ixs) & (ixs < l)))
+        res /= l - k + (beta + 1) * r + beta + 1
+
+        cond1_left = jnp.logical_or(i == 0, (z[i - 1] / (beta + 1)) > res)
+        cond1_right = res >= (z[i] / (beta + 1))
+        cond1 = jnp.logical_and(cond1_left, cond1_right)
+
+        cond2_left = z[l - 1] > res
+        cond2_right = jnp.logical_or(l == d, res >= z[l])
+        cond2 = jnp.logical_and(cond2_left, cond2_right)
+
+        return res, cond1 & cond2
+
+      return inner(r, l, z)
+
+    # Alg. 1 of :cite:`argyriou:12`
+    k, d, beta = self.k, z.shape[-1], 1.0 / self.gamma
+
+    ixs = jnp.arange(d)
+    z, sgn = jnp.abs(z), jnp.sign(z)
+    z_ixs = jnp.argsort(z)[::-1]
+    z_sorted = z[z_ixs]
+
+    # (k, d - k + 1)
+    T, mask = find_indices(jnp.arange(k), jnp.arange(k, d + 1), z_sorted)
+    (r,), (l,) = jnp.where(mask, size=1)  # size=1 for jitting
+    T = T[r, l]
+
+    q1 = (beta / (beta + 1)) * z_sorted * (ixs < (k - r - 1))
+    q2 = (z_sorted - T) * jnp.logical_and((k - r - 1) <= ixs, ixs < (l + k))
+    q = q1 + q2
+
+    # change sign and reorder
+    return sgn * q[jnp.argsort(z_ixs.astype(float))]
+
+  def tree_flatten(self):
+    return (), (self.k, self.gamma)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    del children
+    return cls(*aux_data)
+
+
 @jax.tree_util.register_pytree_node_class
 class Bures(CostFn):
   """Bures distance between a pair of (mean, cov matrix) raveled as vectors.

   Args:
     dimension: Dimensionality of the data.
-    kwargs: Keyword arguments for :func:`ott.math.matrix_square_root.sqrtm`.
+    kwargs: Keyword arguments for :func:`~ott.math.matrix_square_root.sqrtm`.
   """

   def __init__(self, dimension: int, **kwargs: Any):
@@ -289,7 +484,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray:
     norm += jnp.trace(cov, axis1=-2, axis2=-1)
     return norm

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute - 2 x Bures dot-product."""
     mean_x, cov_x = x_to_means_and_covs(x, self._dimension)
     mean_y, cov_y = x_to_means_and_covs(y, self._dimension)
@@ -385,8 +580,8 @@ def barycenter(
       covariance (raveled).
     kwargs: Passed on to :meth:`covariance_fixpoint_iter`, and by extension to
       :func:`ott.math.matrix_square_root.sqrtm`. Note that `tolerance` is used
-      for the fixed-point iteration of the barycenter, whereas `threshold` will apply to the fixed
-      point iteration of Newton-Schulz iterations.
+      for the fixed-point iteration of the barycenter, whereas `threshold`
+      will apply to the fixed-point iteration of Newton-Schulz iterations.

   Returns:
     A concatenation of the mean and the raveled covariance of the barycenter.
@@ -460,7 +655,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray:
     """
     return self._gamma * x[..., 0]

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute dot-product for unbalanced Bures.
     Args:
diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py
index 5b18e1e04..ec206476d 100644
--- a/tests/geometry/costs_test.py
+++ b/tests/geometry/costs_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for the cost/norm functions."""
+from typing import Type

 import pytest

@@ -19,7 +20,8 @@
 import jax.numpy as jnp
 import numpy as np

-from ott.geometry import costs
+from ott.geometry import costs, pointcloud
+from ott.solvers.linear import sinkhorn


 @pytest.mark.fast
@@ -70,27 +72,113 @@ def test_cosine(self, rng: jnp.ndarray):
         atol=1e-5,
     )

-  @pytest.mark.fast
-  class TestBuresBarycenter:
-
-    def test_buresb(self, rng: jnp.ndarray):
-      d = 5
-      r = jnp.array([0.3206, 0.8825, 0.1113, 0.00052, 0.9454])
-      Sigma1 = r * jnp.eye(d)
-      s = jnp.array([0.3075, 0.8545, 0.1110, 0.0054, 0.9206])
-      Sigma2 = s * jnp.eye(d)
-      # initializing Bures cost function
-      weights = jnp.array([.3, .7])
-      bures = costs.Bures(d)
-      # stacking parameter values
-      xs = jnp.vstack((
-          costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma1, d),
-          costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma2, d)
-      ))
-
-      output = bures.barycenter(weights, xs, tolerance=1e-4, threshold=1e-6)
-      _, sigma = costs.x_to_means_and_covs(output, 5)
-      ground_truth = (weights[0] * jnp.sqrt(r) + weights[1] * jnp.sqrt(s)) ** 2
+
+@pytest.mark.fast
+class TestBuresBarycenter:
+
+  def test_bures(self, rng: jnp.ndarray):
+    d = 5
+    r = jnp.array([0.3206, 0.8825, 0.1113, 0.00052, 0.9454])
+    Sigma1 = r * jnp.eye(d)
+    s = jnp.array([0.3075, 0.8545, 0.1110, 0.0054, 0.9206])
+    Sigma2 = s * jnp.eye(d)
+    # initializing Bures cost function
+    weights = jnp.array([.3, .7])
+    bures = costs.Bures(d)
+    # stacking parameter values
+    xs = jnp.vstack((
+        costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma1, d),
+        costs.mean_and_cov_to_x(jnp.zeros((d,)), Sigma2, d)
+    ))
+
+    output = bures.barycenter(weights, xs, tolerance=1e-4, threshold=1e-6)
+    _, sigma = costs.x_to_means_and_covs(output, 5)
+    ground_truth = (weights[0] * jnp.sqrt(r) + weights[1] * jnp.sqrt(s)) ** 2
+    np.testing.assert_allclose(
+        ground_truth, jnp.diag(sigma), rtol=1e-5, atol=1e-5
+    )
+
+
+@pytest.mark.fast
+class TestRegTICost:
+
+  @pytest.mark.parametrize(
+      "cost_fn",
+      [
+          costs.ElasticL1(gamma=5),
+          costs.ElasticL1(gamma=0.0),
+          costs.ElasticSTVS(gamma=2.2),
+          costs.ElasticSTVS(gamma=10),
+      ],
+      ids=[
+          "elasticnet",
+          "elasticnet-gam0",
+          "stvs-gam2.2",
+          "stvs-gam10",
+      ],
+  )
+  def test_reg_cost_legendre(
+      self, rng: jax.random.PRNGKeyArray, cost_fn: costs.RegTICost
+  ):
+    for d in [5, 10, 50, 100, 1000]:
+      rng, rng1 = jax.random.split(rng)
+      expected = jax.random.normal(rng1, (d,))
+      actual = jax.grad(cost_fn.h_legendre)(jax.grad(cost_fn.h)(expected))
       np.testing.assert_allclose(
-          ground_truth, jnp.diag(sigma), rtol=1e-5, atol=1e-5
+          actual, expected, rtol=1e-5, atol=1e-5, err_msg=f"d={d}"
       )
+
+  @pytest.mark.parametrize("k", [1, 2, 7, 10])
+  @pytest.mark.parametrize("d", [10, 50, 100])
+  def test_elastic_sq_k_overlap(
+      self, rng: jax.random.PRNGKeyArray, k: int, d: int
+  ):
+    expected = jax.random.normal(rng, (d,))
+
+    cost_fn = costs.ElasticSqKOverlap(k=k, gamma=1e-2)
+    actual = jax.grad(cost_fn.h_legendre)(jax.grad(cost_fn.h)(expected))
+    # should hold for small gamma
+    assert np.corrcoef(expected, actual)[0, 1] > 0.97
+
+  @pytest.mark.parametrize(
+      "cost_fn", [
+          costs.ElasticL1(gamma=100),
+          costs.ElasticSTVS(gamma=10),
+          costs.ElasticSqKOverlap(k=3, gamma=20)
+      ]
+  )
+  def test_sparse_displacement(
+      self, rng: jax.random.PRNGKeyArray, cost_fn: costs.RegTICost
+  ):
+    frac_sparse = 0.8
+    key1, key2 = jax.random.split(rng, 2)
+    x = jax.random.normal(key1, (50, 30))
+    y = jax.random.normal(key2, (71, 30))
+    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
+
+    dp = sinkhorn.solve(geom).to_dual_potentials()
+
+    for arr, fwd in zip([x, y], [True, False]):
+      arr_t = dp.transport(arr, forward=fwd)
+      assert np.sum(np.isclose(arr, arr_t)) / arr.size > frac_sparse
+
+  @pytest.mark.parametrize("cost_clazz", [costs.ElasticL1, costs.ElasticSTVS])
+  def test_stronger_regularization_increases_sparsity(
+      self, rng: jax.random.PRNGKeyArray, cost_clazz: Type[costs.RegTICost]
+  ):
+    d, keys = 30, jax.random.split(rng, 4)
+    x = jax.random.normal(keys[0], (50, d))
+    y = jax.random.normal(keys[1], (71, d))
+    xx = jax.random.normal(keys[2], (25, d))
+    yy = jax.random.normal(keys[3], (35, d))
+
+    sparsity = {False: [], True: []}
+    for gamma in [9, 10, 100]:
+      cost_fn = cost_clazz(gamma=gamma)
+      geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
+
+      dp = sinkhorn.solve(geom).to_dual_potentials()
+      for arr, fwd in zip([xx, yy], [True, False]):
+        arr_t = dp.transport(arr, forward=fwd)
+        sparsity[fwd].append(np.sum(np.isclose(arr, arr_t)))
+
+    for fwd in [False, True]:
+      np.testing.assert_array_equal(np.diff(sparsity[fwd]) > 0.0, True)
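
The identity exercised by `test_reg_cost_legendre` can be checked directly: for the strictly convex `h` of a `RegTICost`, the gradient maps of `h` and of its Legendre transform are mutually inverse, and `h_legendre` only needs `prox_reg` (wrapped in `jax.lax.stop_gradient`, so differentiating it reduces to the envelope argument). A minimal sketch, assuming only the classes added in this diff; the `gamma` and dimension are illustrative:

import jax
import jax.numpy as jnp

from ott.geometry import costs

cost_fn = costs.ElasticSTVS(gamma=2.2)  # any RegTICost added above works
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

y = jax.grad(cost_fn.h)(x)                 # y = grad h(x)
x_back = jax.grad(cost_fn.h_legendre)(y)   # grad h*(y) should recover x
print(jnp.allclose(x_back, x, atol=1e-5))  # expected: True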
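
End to end, the point of these costs is sparse displacements: with an elastic ground cost, the entropic estimate of the Monge map leaves most coordinates of a transported point unchanged. A sketch mirroring `test_sparse_displacement`; shapes and `gamma` are illustrative assumptions:

import jax
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers.linear import sinkhorn

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (50, 30))  # source point cloud
y = jax.random.normal(key2, (71, 30))  # target point cloud

geom = pointcloud.PointCloud(x, y, cost_fn=costs.ElasticL1(gamma=100))
dp = sinkhorn.solve(geom).to_dual_potentials()

x_t = dp.transport(x, forward=True)  # displace source points towards target
print("fraction of unmoved coordinates:", np.mean(np.isclose(x, x_t)))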
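
That coordinate-wise sparsity comes from the proximal operator; for `ElasticL1` it is the classical soft-thresholding operator, as a quick check (again using only what the diff defines) confirms:

import jax.numpy as jnp

from ott.geometry import costs

gamma = 1.0
z = jnp.array([-2.0, -0.3, 0.0, 0.7, 3.0])
# soft thresholding: shrink |z| by gamma, zeroing entries with |z| <= gamma
soft_threshold = jnp.sign(z) * jnp.maximum(jnp.abs(z) - gamma, 0.0)
print(jnp.allclose(costs.ElasticL1(gamma=gamma).prox_reg(z), soft_threshold))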