diff --git a/docs/geometry.rst b/docs/geometry.rst index 788acb8e6..2ee4a72ea 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -47,6 +47,7 @@ Geometries graph.Graph geodesic.Geodesic low_rank.LRCGeometry + low_rank.LRKGeometry epsilon_scheduler.Epsilon Cost Functions @@ -60,6 +61,7 @@ Cost Functions costs.SqEuclidean costs.Euclidean costs.Cosine + costs.Arccos costs.Bures costs.UnbalancedBures costs.ElasticL1 diff --git a/docs/math.rst b/docs/math.rst index df02bd0ac..960ac9e78 100644 --- a/docs/math.rst +++ b/docs/math.rst @@ -40,3 +40,4 @@ Miscellaneous utils.norm utils.logsumexp utils.softmin + utils.lambertw diff --git a/docs/references.bib b/docs/references.bib index c5d4c4678..4497a8f43 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -806,6 +806,18 @@ @misc{klein:23 year = {2023}, } +@inproceedings{scetbon:20, + author = {Scetbon, Meyer and Cuturi, Marco}, + editor = {Larochelle, H. and Ranzato, M. and Hadsell, R. and Balcan, M.F. and Lin, H.}, + publisher = {Curran Associates, Inc.}, + url = {https://proceedings.neurips.cc/paper_files/paper/2020/file/9bde76f262285bb1eaeb7b40c758b53e-Paper.pdf}, + booktitle = {Advances in Neural Information Processing Systems}, + pages = {13468--13480}, + title = {Linear Time Sinkhorn Divergences using Positive Features}, + volume = {33}, + year = {2020}, +} + @misc{huguet:2023, author = {Huguet, Guillaume and Tong, Alexander and Zapatero, MarĂ­a Ramos and Tape, Christopher J. and Wolf, Guy and Krishnaswamy, Smita}, eprint = {2211.00805}, @@ -814,3 +826,27 @@ @misc{huguet:2023 title = {Geodesic Sinkhorn for Fast and Accurate Optimal Transport on Manifolds}, year = {2023}, } + +@article{iacono:17, + author = {Iacono, Roberto and Boyd, John P.}, + url = {https://doi.org/10.1007/s10444-017-9530-3}, + doi = {10.1007/s10444-017-9530-3}, + issn = {1572-9044}, + journal = {Advances in Computational Mathematics}, + number = {6}, + pages = {1403--1436}, + title = {New approximations to the principal real-valued branch of the Lambert W-function}, + volume = {43}, + year = {2017}, +} + +@inproceedings{cho:09, + author = {Cho, Youngmin and Saul, Lawrence}, + editor = {Bengio, Y. and Schuurmans, D. and Lafferty, J. and Williams, C. and Culotta, A.}, + publisher = {Curran Associates, Inc.}, + url = {https://proceedings.neurips.cc/paper_files/paper/2009/file/5751ec3e9a4feab575962e78e006250d-Paper.pdf}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {Kernel Methods for Deep Learning}, + volume = {22}, + year = {2009}, +} diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 9f1a6c3a0..297f1f2c8 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -30,6 +30,7 @@ "Euclidean", "SqEuclidean", "Cosine", + "Arccos", "ElasticL1", "ElasticL2", "ElasticSTVS", @@ -311,10 +312,9 @@ def __init__(self, ridge: float = 1e-8): def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Cosine distance between vectors, denominator regularized with ridge.""" - ridge = self._ridge x_norm = jnp.linalg.norm(x, axis=-1) y_norm = jnp.linalg.norm(y, axis=-1) - cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge) + cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge) return 1.0 - cosine_similarity @classmethod @@ -322,6 +322,72 @@ def _padder(cls, dim: int) -> jnp.ndarray: return jnp.ones((1, dim)) +@jax.tree_util.register_pytree_node_class +class Arccos(CostFn): + r"""Arc-cosine cost function :cite:`cho:09`. + + The cost is implemented as: + + .. 
math:: + c_n(x, y) = -\log(\frac{1}{\pi} \|x\|^n \|y\|^n J_n(\theta)) + + where :math:`\theta := \arccos(\frac{x \cdot y}{\|x\| \|y\|})` and + :math:`J_n(\theta) := (-1)^n (\sin \theta)^{2n + 1} + (\frac{1}{\sin \theta}\frac{\partial}{\partial \theta})^n + (\frac{\pi - \theta}{\sin \theta})`. + + Args: + n: Order of the kernel. For :math:`n > 2`, successive applications of + :func:`~jax.grad` are used to compute the :math:`J_n(\theta)`. + ridge: Ridge regularization. + """ + + def __init__(self, n: int, ridge: float = 1e-8): + self.n = n + self._ridge = ridge + + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray): # noqa: D102 + x_norm = jnp.linalg.norm(x, axis=-1) + y_norm = jnp.linalg.norm(y, axis=-1) + cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge) + theta = jnp.arccos(cosine_similarity) + + if self.n == 0: + m = 1.0 - theta / jnp.pi + elif self.n == 1: + j = jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta) + m = (x_norm * y_norm) * (j / jnp.pi) + elif self.n == 2: + j = 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * ( + 1.0 + 2.0 * jnp.cos(theta) ** 2 + ) + m = (x_norm * y_norm) ** 2 * (j / jnp.pi) + else: + j = self._j(theta) # less optimized version using autodiff + m = (x_norm * y_norm) ** self.n * (j / jnp.pi) + + return -jnp.log(m + self._ridge) + + @jax.jit + def _j(self, theta: float) -> float: + + def f(t: float, i: int) -> float: + if i == 0: + return (jnp.pi - t) / jnp.sin(t) + return jax.grad(f)(t, i - 1) / jnp.sin(t) + + n = self.n + return (-1) ** n * jnp.sin(theta) ** (2.0 * n + 1.0) * f(theta, n) + + def tree_flatten(self): # noqa: D102 + return [], {"n": self.n, "ridge": self._ridge} + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(**aux_data) + + class RegTICost(TICost, abc.ABC): r"""Base class for regularized translation-invariant costs. diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index e759b4cb9..d2d2bfbab 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -16,9 +16,11 @@ import jax import jax.numpy as jnp -from ott.geometry import geometry +from ott import utils +from ott.geometry import costs, geometry +from ott.math import utils as mu -__all__ = ["LRCGeometry"] +__all__ = ["LRCGeometry", "LRKGeometry"] @jax.tree_util.register_pytree_node_class @@ -33,8 +35,8 @@ class LRCGeometry(geometry.Geometry): if :math:`C = AB^T` and :math:`D = EF^T` then :math:`C + D = [A,E][B,F]^T` Args: - cost_1: jnp.ndarray[num_a, r] - cost_2: jnp.ndarray[num_b, r] + cost_1: Array of shape ``[num_a, r]``. + cost_2: Array of shape ``[num_b, r]``. bias: constant added to entire cost matrix. scale: Value used to rescale the factors of the low-rank geometry. scale_cost: option to rescale the cost matrix. Implemented scalings are @@ -343,3 +345,164 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 tgt_mask=tgt_mask, **aux_data ) + + +@jax.tree_util.register_pytree_node_class +class LRKGeometry(geometry.Geometry): + """Low-rank kernel geometry. + + .. note:: + This constructor is not meant to be called by the user, + please use the :meth:`from_pointcloud` method instead. + + Args: + k1: Array of shape ``[num_a, r]`` with positive features. + k2: Array of shape ``[num_b, r]`` with positive features. + epsilon: Epsilon regularization. + kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`. 
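+
+  A minimal usage sketch (``x`` and ``y`` are assumed to be point clouds of
+  shape ``[n, d]`` and ``[m, d]``, respectively):
+
+  .. code-block:: python
+
+    geom = LRKGeometry.from_pointcloud(x, y, kernel="gaussian", rank=64)
+    # Apply the factored kernel to a vector of size ``n`` without ever
+    # materializing the dense ``[n, m]`` kernel matrix.
+    v = geom.apply_kernel(jnp.ones(x.shape[0]))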
+ """ + + def __init__( + self, + k1: jnp.ndarray, + k2: jnp.ndarray, + epsilon: Optional[float] = None, + **kwargs: Any + ): + super().__init__(epsilon=epsilon, relative_epsilon=False, **kwargs) + self.k1 = k1 + self.k2 = k2 + + @classmethod + def from_pointcloud( + cls, + x: jnp.ndarray, + y: jnp.ndarray, + *, + kernel: Literal["gaussian", "arccos"], + rank: int = 100, + std: float = 1.0, + n: int = 1, + rng: Optional[jax.Array] = None + ) -> "LRKGeometry": + r"""Low-rank kernel approximation :cite:`scetbon:20`. + + Args: + x: Array of shape ``[n, d]``. + y: Array of shape ``[m, d]``. + kernel: Type of the kernel to approximate. + rank: Rank of the approximation. + std: Depending on the ``kernel`` approximation: + + - ``'gaussian'`` - scale of the Gibbs kernel. + - ``'arccos'`` - standard deviation of the random projections. + n: Order of the arc-cosine kernel, see :cite:`cho:09` for reference. + rng: Random key used for seeding. + + Returns: + Low-rank kernel geometry. + """ + rng = utils.default_prng_key(rng) + if kernel == "gaussian": + r = jnp.maximum( + jnp.linalg.norm(x, axis=-1).max(), + jnp.linalg.norm(y, axis=-1).max() + ) + k1 = _gaussian_kernel(rng, x, rank, eps=std, R=r) + k2 = _gaussian_kernel(rng, y, rank, eps=std, R=r) + eps = std + elif kernel == "arccos": + k1 = _arccos_kernel(rng, x, rank, n=n, std=std) + k2 = _arccos_kernel(rng, y, rank, n=n, std=std) + eps = 1.0 + else: + raise NotImplementedError(kernel) + + return cls(k1, k2, epsilon=eps) + + def apply_kernel( # noqa: D102 + self, + scaling: jnp.ndarray, + eps: Optional[float] = None, + axis: int = 0, + ) -> jnp.ndarray: + if axis == 0: + return self.k2 @ (self.k1.T @ scaling) + return self.k1 @ (self.k2.T @ scaling) + + @property + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 + return self.k1 @ self.k2.T + + @property + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 + eps = jnp.finfo(self.dtype).tiny + return -self.epsilon * jnp.log(self.kernel_matrix + eps) + + @property + def rank(self) -> int: # noqa: D102 + return self.k1.shape[1] + + @property + def shape(self) -> Tuple[int, int]: # noqa: D102 + return self.k1.shape[0], self.k2.shape[0] + + @property + def dtype(self) -> jnp.dtype: # noqa: D102 + return self.k1.dtype + + def transport_from_potentials( + self, f: jnp.ndarray, g: jnp.ndarray + ) -> jnp.ndarray: + """Not implemented.""" + raise ValueError("Not implemented.") + + def tree_flatten(self): # noqa: D102 + return [self.k1, self.k2, self._epsilon_init], {} + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + return cls(*children, **aux_data) + + +def _gaussian_kernel( + rng: jax.Array, + x: jnp.ndarray, + n_features: int, + eps: float, + R: jnp.ndarray, +) -> jnp.ndarray: + _, d = x.shape + cost_fn = costs.SqEuclidean() + + y = (R ** 2) / (eps * d) + q = y / (2.0 * mu.lambertw(y)) + sigma = jnp.sqrt(q * eps * 0.25) + + u = jax.random.normal(rng, shape=(n_features, d)) * sigma + cost = cost_fn.all_pairs(x, u) + norm_u = cost_fn.norm(u) + + tmp = -2.0 * (cost / eps) + (norm_u / (eps * q)) + phi = (2 * q) ** (d / 4) * jnp.exp(tmp) + + return (1.0 / jnp.sqrt(n_features)) * phi + + +def _arccos_kernel( + rng: jax.Array, + x: jnp.ndarray, + n_features: int, + n: int, + std: float = 1.0, + kappa: float = 1e-6, +) -> jnp.ndarray: + n_points, d = x.shape + c = jnp.sqrt(2) * (std ** (d / 2)) + + u = jax.random.normal(rng, shape=(n_features, d)) * std + tmp = -(1 / 4) * jnp.sum(u ** 2, axis=-1) * (1.0 - (1.0 / (std ** 2))) + phi = c * (jnp.maximum(0.0, (x @ u.T)) ** n) * 
jnp.exp(tmp)
+
+  return jnp.c_[(1.0 / jnp.sqrt(n_features)) * phi,
+                jnp.full((n_points,), fill_value=kappa)]
diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py
index 3331b3a61..67e39989c 100644
--- a/src/ott/math/utils.py
+++ b/src/ott/math/utils.py
@@ -29,8 +29,9 @@
     "gen_js",
     "logsumexp",
     "softmin",
-    "sort_and_argsort",
     "barycentric_projection",
+    "sort_and_argsort",
+    "lambertw",
 ]
@@ -233,3 +234,65 @@ def sort_and_argsort(
     i_x = jnp.argsort(x)
     return x[i_x], i_x
   return jnp.sort(x), None
+
+
+@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
+def lambertw(
+    z: jnp.ndarray, tol: float = 1e-8, max_iter: int = 100
+) -> jnp.ndarray:
+  """Principal branch of the
+  `Lambert W function <https://en.wikipedia.org/wiki/Lambert_W_function>`_.
+
+  This implementation uses Halley's iteration and the global initialization
+  proposed in :cite:`iacono:17`, Eq. 20.
+
+  Args:
+    z: Array.
+    tol: Tolerance threshold.
+    max_iter: Maximum number of iterations.
+
+  Returns:
+    The Lambert W function evaluated at ``z``.
+  """  # noqa: D205
+
+  def initial_iacono(x: jnp.ndarray) -> jnp.ndarray:
+    y = jnp.sqrt(1.0 + jnp.e * x)
+    num = 1.0 + 1.14956131 * y
+    denom = 1.0 + 0.45495740 * jnp.log1p(y)
+    return -1.0 + 2.036 * jnp.log(num / denom)
+
+  def cond_fun(container):
+    it, converged, _ = container
+    return jnp.logical_and(jnp.any(~converged), it < max_iter)
+
+  def halley_iteration(container):
+    it, _, w = container
+
+    # modified from `tensorflow_probability`
+    f = w - z * jnp.exp(-w)
+    delta = f / (w + 1.0 - 0.5 * (w + 2.0) * f / (w + 1.0))
+
+    w_next = w - delta
+
+    converged = jnp.abs(delta) <= tol * jnp.abs(w_next)
+    return it + 1, converged, w_next
+
+  w0 = initial_iacono(z)
+  converged = jnp.zeros_like(w0, dtype=bool)
+
+  _, _, w = jax.lax.while_loop(
+      cond_fun=cond_fun, body_fun=halley_iteration, init_val=(0, converged, w0)
+  )
+  return w
+
+
+@lambertw.defjvp
+def _lambertw_jvp(
+    tol: float, max_iter: int, primals: Tuple[jnp.ndarray, ...],
+    tangents: Tuple[jnp.ndarray, ...]
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  z, = primals
+  dz, = tangents
+  w = lambertw(z, tol=tol, max_iter=max_iter)
+  pz = jnp.where(z == 0.0, 1.0, w / ((1.0 + w) * z))
+  return w, pz * dz
diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/lr_cost_test.py
similarity index 100%
rename from tests/geometry/low_rank_test.py
rename to tests/geometry/lr_cost_test.py
diff --git a/tests/geometry/lr_kernel_test.py b/tests/geometry/lr_kernel_test.py
new file mode 100644
index 000000000..1f0a42e7d
--- /dev/null
+++ b/tests/geometry/lr_kernel_test.py
@@ -0,0 +1,121 @@
+from typing import Literal, Optional
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from ott.geometry import costs, low_rank, pointcloud
+from ott.solvers import linear
+
+
+@pytest.mark.fast()
+class TestLRKGeometry:
+
+  @pytest.mark.parametrize("std", [1e-1, 1.0, 1e2])
+  @pytest.mark.parametrize("kernel", ["gaussian", "arccos"])
+  def test_positive_features(
+      self, rng: jax.Array, kernel: Literal["gaussian", "arccos"], std: float
+  ):
+    rng1, rng2 = jax.random.split(rng, 2)
+    x = jax.random.normal(rng1, (10, 2))
+    y = jax.random.normal(rng2, (12, 2))
+    rank = 5
+
+    geom = low_rank.LRKGeometry.from_pointcloud(
+        x, y, kernel=kernel, std=std, rank=rank
+    )
+
+    if kernel == "gaussian":
+      assert geom.rank == rank
+    else:
+      assert geom.rank == rank + 1
+    np.testing.assert_array_equal(geom.k1 >= 0.0, True)
+    np.testing.assert_array_equal(geom.k2 >= 0.0, True)
+
+  @pytest.mark.parametrize("n", [0, 1, 2])
+  def test_arccos_j_function(self, rng: jax.Array, n: int):
+
+    def j(theta: float) -> float:
+      if n == 0:
+        return jnp.pi - theta
+      if n == 1:
+        return jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta)
+      if n == 2:
+        return 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * (
+            1.0 + 2.0 * jnp.cos(theta) ** 2
+        )
+      raise NotImplementedError(n)
+
+    x = jnp.abs(jax.random.normal(rng, (32,)))
+    cost_fn = costs.Arccos(n)
+
+    gt = jax.vmap(j)(x)
+    pred = jax.vmap(cost_fn._j)(x)
+
+    np.testing.assert_allclose(gt, pred, rtol=1e-4, atol=1e-4)
+
+  @pytest.mark.parametrize("std", [1e-2, 1e-1, 1.0])
+  @pytest.mark.parametrize("kernel", ["gaussian", "arccos"])
+  def test_kernel_approximation(
+      self, rng: jax.Array, kernel: Literal["gaussian", "arccos"], std: float
+  ):
+    rng, rng1, rng2 = jax.random.split(rng, 3)
+    x = jax.random.normal(rng1, (230, 5))
+    y = jax.random.normal(rng2, (260, 5))
+    n = 1
+
+    cost_fn = costs.SqEuclidean() if kernel == "gaussian" else costs.Arccos(n)
+    pc = pointcloud.PointCloud(x, y, epsilon=std, cost_fn=cost_fn)
+    gt_cost = pc.cost_matrix
+
+    max_abs_diff = []
+    for rank in [10, 50, 100, 200]:
+      rng, rng_approx = jax.random.split(rng, 2)
+      geom = low_rank.LRKGeometry.from_pointcloud(
+          x, y, rank=rank, kernel=kernel, std=std, n=n, rng=rng_approx
+      )
+      pred_cost = geom.cost_matrix
+      max_abs_diff.append(np.max(np.abs(gt_cost - pred_cost)))
+
+    # test higher rank better approximates the cost
+    np.testing.assert_array_equal(np.diff(max_abs_diff) <= 0.0, True)
+
+  @pytest.mark.parametrize(("kernel", "n", "std"), [("gaussian", None, 1e-2),
+                                                    ("gaussian", None, 1e-1),
+                                                    ("arccos", 0, 1.0001),
+                                                    ("arccos", 1, 2.0),
+                                                    ("arccos", 2, 1.05)])
+  def test_sinkhorn_approximation(
+      self,
+      rng: jax.Array,
+      kernel: Literal["gaussian", "arccos"],
+      std: float,
+      n: Optional[int],
+  ):
+    rng, rng1, rng2 = jax.random.split(rng, 3)
+    x = jax.random.normal(rng1, (83, 5))
+    x /= jnp.linalg.norm(x, keepdims=True)
+    y = jax.random.normal(rng2, (96, 5))
+    y /= jnp.linalg.norm(y, keepdims=True)
+
+    solve_fn = jax.jit(lambda g: linear.solve(g, lse_mode=False))
+
+    cost_fn = costs.SqEuclidean() if kernel == "gaussian" else costs.Arccos(n)
+    geom = pointcloud.PointCloud(x, y, epsilon=std, cost_fn=cost_fn)
+    gt_out = solve_fn(geom)
+
+    cs = []
+    for rank in [5, 40, 80]:
+      rng, rng_approx = jax.random.split(rng, 2)
+      geom = low_rank.LRKGeometry.from_pointcloud(
+          x, y, rank=rank, kernel=kernel, std=std, n=n, rng=rng_approx
+      )
+
+      pred_out = solve_fn(geom)
+      cs.append(pred_out.reg_ot_cost)
+
+    diff = np.diff(np.abs(gt_out.reg_ot_cost - np.array(cs)))
+    try:
+      # test higher rank better approximates the Sinkhorn solution
+      np.testing.assert_array_equal(diff <= 0.0, True)
+    except AssertionError:
+      np.testing.assert_allclose(diff, 0.0, rtol=1e-3, atol=1e-3)
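
A minimal end-to-end sketch of the API added above (array sizes, rank and std
are illustrative only; the solver is run in kernel mode, i.e. lse_mode=False,
so that only the low-rank factors k1 and k2 are ever applied):

    import jax
    from ott.geometry import low_rank
    from ott.solvers import linear

    rng = jax.random.PRNGKey(0)
    rng, rng1, rng2 = jax.random.split(rng, 3)
    x = jax.random.normal(rng1, (128, 5))  # source points, shape [n, d]
    y = jax.random.normal(rng2, (164, 5))  # target points, shape [m, d]

    # Positive-feature approximation of the Gibbs kernel with 64 features.
    geom = low_rank.LRKGeometry.from_pointcloud(
        x, y, kernel="gaussian", rank=64, std=1e-1, rng=rng
    )
    out = linear.solve(geom, lse_mode=False)  # Sinkhorn with kernel updates
    print(out.reg_ot_cost)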