Add low-rank kernel geometry (#440)
* Add Gaussian kernel

* Make runnable with Sinkhorn

* Fix epsilon

* Use the same random vectors, fix epsilon

* Add arccos kernel

* Add citation

* Fix citation, start working on docs

* Add LambertW

* Polish docstrings

* Update Lambert W docs

* Format references

* Remove useless TYPE_CHECKING

* Update docstring

* Add test skeletons

* Add rank test

* Update tests

* Start working on arccos cost

* Add arccos to docs

* Add `s=2` option for `Arccos`

* Fix LRK tests

* Fix tests and tree API

* Rename argument

* Add generic order using `jax.grad`

* Test also the implementation of `J(theta)`

* Polish flaky test

* Address comments
michalk8 authored Jan 2, 2024
1 parent a421e86 commit 31b26f0
Showing 8 changed files with 459 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/geometry.rst
@@ -47,6 +47,7 @@ Geometries
graph.Graph
geodesic.Geodesic
low_rank.LRCGeometry
low_rank.LRKGeometry
epsilon_scheduler.Epsilon

Cost Functions
@@ -60,6 +61,7 @@ Cost Functions
costs.SqEuclidean
costs.Euclidean
costs.Cosine
costs.Arccos
costs.Bures
costs.UnbalancedBures
costs.ElasticL1
1 change: 1 addition & 0 deletions docs/math.rst
@@ -40,3 +40,4 @@ Miscellaneous
utils.norm
utils.logsumexp
utils.softmin
utils.lambertw
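The `utils.lambertw` entry added above is the helper used further down by the Gaussian feature map (as `mu.lambertw`). As a quick, hedged sanity check (not part of the diff), assuming the helper accepts a scalar array with its default arguments, the principal branch should satisfy W(z) * exp(W(z)) = z:

import jax.numpy as jnp
from ott.math import utils as mu

z = jnp.asarray(2.0)
w = mu.lambertw(z)  # principal branch W_0(z)
# By definition of the Lambert W function, w * exp(w) recovers z.
print(jnp.allclose(w * jnp.exp(w), z, atol=1e-5))  # expected: True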
36 changes: 36 additions & 0 deletions docs/references.bib
@@ -806,6 +806,18 @@ @misc{klein:23
year = {2023},
}

@inproceedings{scetbon:20,
author = {Scetbon, Meyer and Cuturi, Marco},
editor = {Larochelle, H. and Ranzato, M. and Hadsell, R. and Balcan, M.F. and Lin, H.},
publisher = {Curran Associates, Inc.},
url = {https://proceedings.neurips.cc/paper_files/paper/2020/file/9bde76f262285bb1eaeb7b40c758b53e-Paper.pdf},
booktitle = {Advances in Neural Information Processing Systems},
pages = {13468--13480},
title = {Linear Time Sinkhorn Divergences using Positive Features},
volume = {33},
year = {2020},
}

@misc{huguet:2023,
author = {Huguet, Guillaume and Tong, Alexander and Zapatero, María Ramos and Tape, Christopher J. and Wolf, Guy and Krishnaswamy, Smita},
eprint = {2211.00805},
@@ -814,3 +826,27 @@ @misc{huguet:2023
title = {Geodesic Sinkhorn for Fast and Accurate Optimal Transport on Manifolds},
year = {2023},
}

@article{iacono:17,
author = {Iacono, Roberto and Boyd, John P.},
url = {https://doi.org/10.1007/s10444-017-9530-3},
doi = {10.1007/s10444-017-9530-3},
issn = {1572-9044},
journal = {Advances in Computational Mathematics},
number = {6},
pages = {1403--1436},
title = {New approximations to the principal real-valued branch of the Lambert W-function},
volume = {43},
year = {2017},
}

@inproceedings{cho:09,
author = {Cho, Youngmin and Saul, Lawrence},
editor = {Bengio, Y. and Schuurmans, D. and Lafferty, J. and Williams, C. and Culotta, A.},
publisher = {Curran Associates, Inc.},
url = {https://proceedings.neurips.cc/paper_files/paper/2009/file/5751ec3e9a4feab575962e78e006250d-Paper.pdf},
booktitle = {Advances in Neural Information Processing Systems},
title = {Kernel Methods for Deep Learning},
volume = {22},
year = {2009},
}
70 changes: 68 additions & 2 deletions src/ott/geometry/costs.py
@@ -30,6 +30,7 @@
"Euclidean",
"SqEuclidean",
"Cosine",
"Arccos",
"ElasticL1",
"ElasticL2",
"ElasticSTVS",
@@ -311,17 +312,82 @@ def __init__(self, ridge: float = 1e-8):

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Cosine distance between vectors, denominator regularized with ridge."""
- ridge = self._ridge
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
- cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
+ cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
return 1.0 - cosine_similarity

@classmethod
def _padder(cls, dim: int) -> jnp.ndarray:
return jnp.ones((1, dim))


@jax.tree_util.register_pytree_node_class
class Arccos(CostFn):
r"""Arc-cosine cost function :cite:`cho:09`.

The cost is implemented as:

.. math::
  c_n(x, y) = -\log(\frac{1}{\pi} \|x\|^n \|y\|^n J_n(\theta))

where :math:`\theta := \arccos(\frac{x \cdot y}{\|x\| \|y\|})` and
:math:`J_n(\theta) := (-1)^n (\sin \theta)^{2n + 1}
(\frac{1}{\sin \theta}\frac{\partial}{\partial \theta})^n
(\frac{\pi - \theta}{\sin \theta})`.

Args:
  n: Order of the kernel. For :math:`n > 2`, successive applications of
    :func:`~jax.grad` are used to compute the :math:`J_n(\theta)`.
  ridge: Ridge regularization.
"""

def __init__(self, n: int, ridge: float = 1e-8):
self.n = n
self._ridge = ridge

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # noqa: D102
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
theta = jnp.arccos(cosine_similarity)

if self.n == 0:
m = 1.0 - theta / jnp.pi
elif self.n == 1:
j = jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta)
m = (x_norm * y_norm) * (j / jnp.pi)
elif self.n == 2:
j = 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * (
1.0 + 2.0 * jnp.cos(theta) ** 2
)
m = (x_norm * y_norm) ** 2 * (j / jnp.pi)
else:
j = self._j(theta) # less optimized version using autodiff
m = (x_norm * y_norm) ** self.n * (j / jnp.pi)

return -jnp.log(m + self._ridge)

@jax.jit
def _j(self, theta: float) -> float:

def f(t: float, i: int) -> float:
if i == 0:
return (jnp.pi - t) / jnp.sin(t)
return jax.grad(f)(t, i - 1) / jnp.sin(t)

n = self.n
return (-1) ** n * jnp.sin(theta) ** (2.0 * n + 1.0) * f(theta, n)

def tree_flatten(self): # noqa: D102
return [], {"n": self.n, "ridge": self._ridge}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(**aux_data)


class RegTICost(TICost, abc.ABC):
r"""Base class for regularized translation-invariant costs.
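A short, hedged usage sketch of the `Arccos` cost added above (not part of the diff); it assumes the usual `CostFn.__call__` dispatch to `pairwise`. Orders `n <= 2` take the closed-form branches, while `n > 2` exercises the `jax.grad`-based `_j`:

import jax.numpy as jnp
from ott.geometry import costs

x = jnp.array([1.0, 2.0, 0.5])
y = jnp.array([-0.3, 1.0, 2.0])

closed_form = costs.Arccos(n=1)  # explicit J_1(theta) expression
autodiff = costs.Arccos(n=3)     # n > 2 falls back to the jax.grad recursion
print(closed_form(x, y), autodiff(x, y))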
171 changes: 167 additions & 4 deletions src/ott/geometry/low_rank.py
@@ -16,9 +16,11 @@
import jax
import jax.numpy as jnp

- from ott.geometry import geometry
+ from ott import utils
+ from ott.geometry import costs, geometry
+ from ott.math import utils as mu

- __all__ = ["LRCGeometry"]
+ __all__ = ["LRCGeometry", "LRKGeometry"]


@jax.tree_util.register_pytree_node_class
@@ -33,8 +35,8 @@ class LRCGeometry(geometry.Geometry):
if :math:`C = AB^T` and :math:`D = EF^T` then :math:`C + D = [A,E][B,F]^T`

Args:
- cost_1: jnp.ndarray<float>[num_a, r]
- cost_2: jnp.ndarray<float>[num_b, r]
+ cost_1: Array of shape ``[num_a, r]``.
+ cost_2: Array of shape ``[num_b, r]``.
bias: constant added to entire cost matrix.
scale: Value used to rescale the factors of the low-rank geometry.
scale_cost: option to rescale the cost matrix. Implemented scalings are
Expand Down Expand Up @@ -343,3 +345,164 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
tgt_mask=tgt_mask,
**aux_data
)
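As a hedged illustration of the factorization described in the `LRCGeometry` docstring above (not part of the diff), a low-rank cost geometry can be built directly from its factors; with the assumed default bias and scaling, materializing the cost should recover `A @ B.T`:

import jax
import jax.numpy as jnp
from ott.geometry import low_rank

rng_a, rng_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(rng_a, (64, 3))  # cost_1, shape [num_a, r]
B = jax.random.normal(rng_b, (81, 3))  # cost_2, shape [num_b, r]

geom = low_rank.LRCGeometry(A, B)
# With bias=0 and no rescaling (the assumed defaults), C = A @ B.T.
print(geom.shape, geom.cost_rank, jnp.allclose(geom.cost_matrix, A @ B.T))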


@jax.tree_util.register_pytree_node_class
class LRKGeometry(geometry.Geometry):
"""Low-rank kernel geometry.

.. note::
  This constructor is not meant to be called by the user,
  please use the :meth:`from_pointcloud` method instead.

Args:
  k1: Array of shape ``[num_a, r]`` with positive features.
  k2: Array of shape ``[num_b, r]`` with positive features.
  epsilon: Epsilon regularization.
  kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
"""

def __init__(
self,
k1: jnp.ndarray,
k2: jnp.ndarray,
epsilon: Optional[float] = None,
**kwargs: Any
):
super().__init__(epsilon=epsilon, relative_epsilon=False, **kwargs)
self.k1 = k1
self.k2 = k2

@classmethod
def from_pointcloud(
cls,
x: jnp.ndarray,
y: jnp.ndarray,
*,
kernel: Literal["gaussian", "arccos"],
rank: int = 100,
std: float = 1.0,
n: int = 1,
rng: Optional[jax.Array] = None
) -> "LRKGeometry":
r"""Low-rank kernel approximation :cite:`scetbon:20`.

Args:
  x: Array of shape ``[n, d]``.
  y: Array of shape ``[m, d]``.
  kernel: Type of the kernel to approximate.
  rank: Rank of the approximation.
  std: Depending on the ``kernel`` approximation:

    - ``'gaussian'`` - scale of the Gibbs kernel.
    - ``'arccos'`` - standard deviation of the random projections.

  n: Order of the arc-cosine kernel, see :cite:`cho:09` for reference.
  rng: Random key used for seeding.

Returns:
  Low-rank kernel geometry.
"""
rng = utils.default_prng_key(rng)
if kernel == "gaussian":
r = jnp.maximum(
jnp.linalg.norm(x, axis=-1).max(),
jnp.linalg.norm(y, axis=-1).max()
)
k1 = _gaussian_kernel(rng, x, rank, eps=std, R=r)
k2 = _gaussian_kernel(rng, y, rank, eps=std, R=r)
eps = std
elif kernel == "arccos":
k1 = _arccos_kernel(rng, x, rank, n=n, std=std)
k2 = _arccos_kernel(rng, y, rank, n=n, std=std)
eps = 1.0
else:
raise NotImplementedError(kernel)

return cls(k1, k2, epsilon=eps)

def apply_kernel( # noqa: D102
self,
scaling: jnp.ndarray,
eps: Optional[float] = None,
axis: int = 0,
) -> jnp.ndarray:
if axis == 0:
return self.k2 @ (self.k1.T @ scaling)
return self.k1 @ (self.k2.T @ scaling)

@property
def kernel_matrix(self) -> jnp.ndarray: # noqa: D102
return self.k1 @ self.k2.T

@property
def cost_matrix(self) -> jnp.ndarray: # noqa: D102
eps = jnp.finfo(self.dtype).tiny
return -self.epsilon * jnp.log(self.kernel_matrix + eps)

@property
def rank(self) -> int: # noqa: D102
return self.k1.shape[1]

@property
def shape(self) -> Tuple[int, int]: # noqa: D102
return self.k1.shape[0], self.k2.shape[0]

@property
def dtype(self) -> jnp.dtype: # noqa: D102
return self.k1.dtype

def transport_from_potentials(
self, f: jnp.ndarray, g: jnp.ndarray
) -> jnp.ndarray:
"""Not implemented."""
raise ValueError("Not implemented.")

def tree_flatten(self): # noqa: D102
return [self.k1, self.k2, self._epsilon_init], {}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*children, **aux_data)


def _gaussian_kernel(
rng: jax.Array,
x: jnp.ndarray,
n_features: int,
eps: float,
R: jnp.ndarray,
) -> jnp.ndarray:
_, d = x.shape
cost_fn = costs.SqEuclidean()

y = (R ** 2) / (eps * d)
q = y / (2.0 * mu.lambertw(y))
sigma = jnp.sqrt(q * eps * 0.25)

u = jax.random.normal(rng, shape=(n_features, d)) * sigma
cost = cost_fn.all_pairs(x, u)
norm_u = cost_fn.norm(u)

tmp = -2.0 * (cost / eps) + (norm_u / (eps * q))
phi = (2 * q) ** (d / 4) * jnp.exp(tmp)

return (1.0 / jnp.sqrt(n_features)) * phi


def _arccos_kernel(
rng: jax.Array,
x: jnp.ndarray,
n_features: int,
n: int,
std: float = 1.0,
kappa: float = 1e-6,
) -> jnp.ndarray:
n_points, d = x.shape
c = jnp.sqrt(2) * (std ** (d / 2))

u = jax.random.normal(rng, shape=(n_features, d)) * std
tmp = -(1 / 4) * jnp.sum(u ** 2, axis=-1) * (1.0 - (1.0 / (std ** 2)))
phi = c * (jnp.maximum(0.0, (x @ u.T)) ** n) * jnp.exp(tmp)

return jnp.c_[(1.0 / jnp.sqrt(n_features)) * phi,
jnp.full((n_points,), fill_value=kappa)]
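A hedged end-to-end sketch of how the new geometry is meant to be used (the point clouds, `rank`, `std` and solver settings below are illustrative assumptions, not values from this commit): build the positive features from two point clouds and run Sinkhorn in kernel mode, which only needs `apply_kernel` and therefore stays linear in the number of points.

import jax
import jax.numpy as jnp
from ott.geometry import low_rank
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

rng = jax.random.PRNGKey(0)
rng_x, rng_y, rng_feat = jax.random.split(rng, 3)
x = 0.25 * jax.random.normal(rng_x, (128, 5))
y = 0.25 * jax.random.normal(rng_y, (143, 5))

std = 1.0
geom = low_rank.LRKGeometry.from_pointcloud(
    x, y, kernel="gaussian", rank=200, std=std, rng=rng_feat
)
print(geom.shape, geom.rank)  # (128, 143) 200

# The positive features target the Gibbs kernel exp(-||x - y||^2 / std) of the
# squared Euclidean cost; the Monte-Carlo error should shrink roughly as
# 1 / sqrt(rank).
exact = jnp.exp(-((x[:, None] - y[None, :]) ** 2).sum(-1) / std)
print(jnp.abs(geom.kernel_matrix - exact).mean())

# Kernel-mode Sinkhorn multiplies by k1 / k2 only, never materializing the cost.
prob = linear_problem.LinearProblem(geom)
out = sinkhorn.Sinkhorn(lse_mode=False)(prob)
print(out.reg_ot_cost)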