Skip to content

Commit

Permalink
Introduce primal_cost and dual_cost for Sinkhorn outputs (only pr…
Browse files Browse the repository at this point in the history
…imal for LR) for arbitrary geometries. (#184)

* instantiate ot_cost for arbitrary geometry.

* lint

* fix assert in rank for to_LRCGeometry method.

* fixes

* linter

* linter

* fix extra bit of memory.

* _check_LRC_dim becomes private.

* bug

* linters and comments.

* doc

* change naming, introduce dual_cost

* linter
  • Loading branch information
marcocuturi authored Nov 28, 2022
1 parent be94b74 commit c00bc12
Show file tree
Hide file tree
Showing 14 changed files with 368 additions and 113 deletions.
129 changes: 77 additions & 52 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ def shape(self) -> Tuple[int, int]:
return mat.shape
return 0, 0

@property
def can_LRC(self) -> bool:
  """Cheap heuristic telling whether a cast to a low-rank cost is sensible.

  The answer relies only on basic considerations from the geometry; no
  rigorous test (e.g., one involving an SVD) is carried out. This
  implementation always answers ``False``.
  """
  return False

@property
def is_squared_euclidean(self) -> bool:
"""Whether cost is computed by taking squared-Eucl. distance of points."""
Expand Down Expand Up @@ -619,13 +628,16 @@ def prepare_divergences(

def to_LRCGeometry(
self,
rank: int,
rank: int = 0,
tol: float = 1e-2,
seed: int = 0,
scale: float = 1.
) -> 'low_rank.LRCGeometry':
r"""Factorize the cost matrix in sublinear time :cite:`indyk:19`.
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
When `rank` is `0` (the default) or at least `min(n, m)`, use
:func:`jax.numpy.linalg.svd`. For other values, use the sublinear-time
routine of :cite:`indyk:19`.
Uses the implementation of :cite:`scetbon:21`, algorithm 4.
It holds that with probability *0.99*,
Expand All @@ -645,59 +657,72 @@ def to_LRCGeometry(
Low-rank geometry.
"""
from ott.geometry import low_rank

assert rank > 0, f"Rank must be positive, got {rank}."
rng = jax.random.PRNGKey(seed)
key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
assert rank >= 0, f"Rank must be non-negative, got {rank}."
n, m = self.shape
n_subset = min(int(rank / tol), n, m)

i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)

# force `batch_size=None` since `cost_matrix` would be `None`
ci_star = self.subset(
i_star, None, batch_size=None
).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(
None, j_star, batch_size=None
).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
S = self.subset(row_ixs, None, batch_size=None).cost_matrix
S /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(S ** 2, axis=0) # (m,)
p_col /= jnp.sum(p_col)
# (n_subset,)
col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col)
# (n_subset, n_subset)
W = S[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])

U, _, V = jsp.linalg.svd(W)
U = U[:, :rank] # (n_subset, rank)
U = (S.T @ U) / jnp.linalg.norm(W.T @ U, axis=0) # (m, rank)

_, d, v = jnp.linalg.svd(U.T @ U) # (k,), (k, k)
v = v.T / jnp.sqrt(d)[None, :]

inv_scale = (1. / jnp.sqrt(n_subset))
col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(
None, col_ixs, batch_size=None
).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)

if rank == 0 or rank >= min(n, m):
# TODO(marcocuturi): add hermitian=self.is_symmetric, currently bugging.
u, s, vh = jnp.linalg.svd(
self.cost_matrix,
full_matrices=False,
compute_uv=True,
)

cost_1 = u
cost_2 = (s[:, None] * vh).T
else:
rng = jax.random.PRNGKey(seed)
key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
n_subset = min(int(rank / tol), n, m)

i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)

# force `batch_size=None` since `cost_matrix` would be `None`
ci_star = self.subset(
i_star, None, batch_size=None
).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(
None, j_star, batch_size=None
).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None, batch_size=None).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(s ** 2, axis=0) # (m,)
p_col /= jnp.sum(p_col)
# (n_subset,)
col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col)
# (n_subset, n_subset)
w = s[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])

U, _, V = jsp.linalg.svd(w)
U = U[:, :rank] # (n_subset, rank)
U = (s.T @ U) / jnp.linalg.norm(w.T @ U, axis=0) # (m, rank)

_, d, v = jnp.linalg.svd(U.T @ U) # (k,), (k, k)
v = v.T / jnp.sqrt(d)[None, :]

inv_scale = (1. / jnp.sqrt(n_subset))
col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(
None, col_ixs, batch_size=None
).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)
cost_1 = V
cost_2 = U

return low_rank.LRCGeometry(
cost_1=V,
cost_2=U,
cost_1=cost_1,
cost_2=cost_2,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
Expand Down
5 changes: 3 additions & 2 deletions ott/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ def apply_transport_from_potentials(
vec: jnp.ndarray,
axis: int = 0
) -> jnp.ndarray:
"""Not implemented."""
raise ValueError("Not implemented.")
"""Since applying from potentials is not feasible in graphs, use scalings."""
u, v = self.scaling_from_potential(f), self.scaling_from_potential(g)
return self.apply_transport_from_scalings(u, v, vec, axis=axis)

def marginal_from_potentials(
self,
Expand Down
113 changes: 93 additions & 20 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, geometry, pointcloud
from ott.geometry import costs, geometry, low_rank, pointcloud
from ott.math import utils

__all__ = ["Grid"]
Expand Down Expand Up @@ -115,31 +115,28 @@ def __init__(
super().__init__(**kwargs)

@property
def cost_matrices(self) -> List[jnp.ndarray]:
def geometries(self) -> List[geometry.Geometry]:
"""Point cloud geometries along each dimension of the grid."""
cost_matrices = []
geometries = []
for dimension, cost_fn in itertools.zip_longest(
range(self.grid_dimension), self.cost_fns, fillvalue=self.cost_fns[-1]
):
x_values = self.x[dimension][:, jnp.newaxis]
cost_matrices.append(
pointcloud.PointCloud(x_values, cost_fn=cost_fn).cost_matrix
geom = pointcloud.PointCloud(
x_values, cost_fn=cost_fn, epsilon=self._epsilon_init
)
return cost_matrices

@property
def kernel_matrices(self) -> List[jnp.ndarray]:
"""Kernel matrices along each dimension of the grid."""
kernel_matrices = []
for cost_matrix in self.cost_matrices:
kernel_matrices.append(jnp.exp(-cost_matrix / self.epsilon))
return kernel_matrices
geometries.append(geom)
return geometries

@property
def median_cost_matrix(self) -> NoReturn:
  """Median of the cost matrix; not available for grids.

  Raises:
    NotImplementedError: Always.
  """
  message = 'Median cost not implemented for grids.'
  raise NotImplementedError(message)

@property
def can_LRC(self) -> bool:
  """Always ``True``: casting a grid as a low-rank cost is deemed sensible."""
  return True

@property
def shape(self) -> Tuple[int, int]:
  """Square shape ``(num_a, num_a)`` of the geometry."""
  side = self.num_a
  return side, side
Expand Down Expand Up @@ -197,7 +194,7 @@ def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None):
f, g = jnp.transpose(f, indices), jnp.transpose(g, indices)
centered_cost = (
f[:, jnp.newaxis, ...] + g[jnp.newaxis, ...] - jnp.expand_dims(
self.cost_matrices[dimension],
self.geometries[dimension].cost_matrix,
axis=tuple(range(2, 1 + self.grid_dimension))
)
) / eps
Expand All @@ -219,8 +216,8 @@ def _apply_cost_to_vec(
The `apply_cost` operation on grids rests on the following identity.
If it were to be cast as a [num_a, num_a] matrix, the corresponding cost
matrix :math:`C` would be a sum of grid_dimension matrices, each of the form
(here for the j-th slice)
matrix :math:`C` would be a sum of `grid_dimension` matrices, each of the
form (here for the j-th slice)
:math:`\tilde{C}_j : = 1_{n_1} \otimes \dots \otimes C_j \otimes 1_{n_d}`
where each :math:`1_{n}` is the :math:`n\times n` square matrix full of 1's.
Expand All @@ -244,7 +241,8 @@ def _apply_cost_to_vec(
vec = jnp.reshape(vec, self.grid_size)
accum_vec = jnp.zeros_like(vec)
indices = list(range(1, self.grid_dimension))
for dimension, cost in enumerate(self.cost_matrices):
for dimension, geom in enumerate(self.geometries):
cost = geom.cost_matrix
ind = indices.copy()
ind.insert(dimension, 0)
if axis == 0:
Expand Down Expand Up @@ -281,10 +279,11 @@ def apply_kernel(
"""
scaling = jnp.reshape(scaling, self.grid_size)
indices = list(range(1, self.grid_dimension))
for dimension, kernel in enumerate(self.kernel_matrices):
for dimension, geom in enumerate(self.geometries):
kernel = geom.kernel_matrix
kernel = kernel if eps is None else kernel ** (self.epsilon / eps)
ind = indices.copy()
ind.insert(dimension, 0)
kernel = kernel if eps is None else kernel ** (self.epsilon / eps)
scaling = jnp.tensordot(
kernel, scaling, axes=([0], [dimension])
).transpose(ind)
Expand All @@ -301,6 +300,17 @@ def transport_from_potentials(
' cloud geometry instead'
)

def apply_transport_from_potentials(
    self,
    f: jnp.ndarray,
    g: jnp.ndarray,
    vec: jnp.ndarray,
    axis: int = 0
) -> jnp.ndarray:
  """Apply the transport given by potentials ``f`` and ``g`` to ``vec``.

  Applying directly from potentials is not feasible in grids, so both
  potentials are first turned into scalings and the scaling-based routine
  is used instead.

  Args:
    f: First potential, converted with ``scaling_from_potential``.
    g: Second potential, converted with ``scaling_from_potential``.
    vec: Vector forwarded to ``apply_transport_from_scalings``.
    axis: Axis forwarded to ``apply_transport_from_scalings``.

  Returns:
    The output of ``apply_transport_from_scalings``.
  """
  scaling_f = self.scaling_from_potential(f)
  scaling_g = self.scaling_from_potential(g)
  return self.apply_transport_from_scalings(
      scaling_f, scaling_g, vec, axis=axis
  )

def transport_from_scalings(
self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0
) -> NoReturn:
Expand Down Expand Up @@ -354,3 +364,66 @@ def tree_unflatten(cls, aux_data, children):
return cls(
x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data
)

def to_LRCGeometry(
    self,
    scale: float = 1.0,
    **kwargs: Any,
) -> low_rank.LRCGeometry:
  """Convert the grid to a low-rank geometry.

  The conversion takes advantage of the fact that the true cost matrix of a
  grid geometry is a sum of Kronecker products of local cost matrices (one
  for each dimension) with matrices of 1's (both on the left and right
  sides) of varying dimension. Each of the matrices in that sum can be
  factorized if each of these cost matrices can be factorized, which is done
  by forcing a conversion of every slice to a low-rank geometry object.

  Args:
    scale: Value used to rescale the factors of the low-rank geometry.
      Useful when this geometry is used in the linear term of fused GW.
    kwargs: Keyword arguments forwarded to the ``to_LRCGeometry`` method of
      the geometry on each slice. A ``rank`` entry is accepted but
      overridden with ``0``, because an exact factorization of every slice
      is required (see the comment in the body).

  Returns:
    :class:`~ott.geometry.low_rank.LRCGeometry` object.
  """
  # An overall low-rank conversion of the cost matrix on a grid, to an
  # object of :class:`~ott.geometry.low_rank.LRCGeometry`, necessitates an
  # exact low-rank matrix decomposition of the cost matrix of each slice
  # of that grid, even if costs on such slices are not low-rank.
  # The idea here is that even if the cost matrix on slice `i` is full rank
  # `n_i`, we are better off doing 2 redundant `n_i x n_i` matrix products,
  # because this is the only way to get access to an overall low-rank
  # factorization for the entire cost matrix. To get such an exact
  # decomposition, the parameter `rank` is set to `0`, triggering a full
  # singular value decomposition if needed.
  # Overriding (rather than passing `rank=0` next to `**kwargs`) avoids a
  # `TypeError` for duplicate keywords when the caller supplies `rank`.
  slice_kwargs = {**kwargs, "rank": 0}

  factors_1 = []
  factors_2 = []
  for dimension, geom in enumerate(self.geometries):
    # NOTE(review): `scale` is forwarded to every slice here and is also
    # passed as `scale_factor` below -- confirm the cost is not rescaled
    # twice for slices that fold `scale` into their factors.
    lr_geom = geom.to_LRCGeometry(scale=scale, **slice_kwargs)

    # Number of grid points before / after the current dimension; an empty
    # slice of `grid_size` yields a product of 1.
    n_left = int(np.prod(np.array(self.grid_size[:dimension])))
    n_right = int(np.prod(np.array(self.grid_size[dimension + 1:])))
    ones_left = jnp.ones((n_left, 1))
    ones_right = jnp.ones((n_right, 1))

    # Repeat the slice factors so that they index the flattened grid.
    factors_1.append(
        jnp.kron(ones_left, jnp.kron(lr_geom.cost_1, ones_right))
    )
    factors_2.append(
        jnp.kron(ones_left, jnp.kron(lr_geom.cost_2, ones_right))
    )

  # Stacking factors along columns makes `cost_1 @ cost_2.T` the sum of the
  # per-dimension products.
  cost_1 = jnp.concatenate(factors_1, axis=-1)
  cost_2 = jnp.concatenate(factors_2, axis=-1)

  return low_rank.LRCGeometry(
      cost_1=cost_1,
      cost_2=cost_2,
      scale_factor=scale,
      epsilon=self._epsilon_init,
      relative_epsilon=self._relative_epsilon,
      scale=self._scale_epsilon,
      scale_cost=self._scale_cost,
      src_mask=self.src_mask,
      tgt_mask=self.tgt_mask,
      **self._kwargs
  )
4 changes: 4 additions & 0 deletions ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def to_LRCGeometry(
"""Return self."""
return self

@property
def can_LRC(self):
  """Always ``True`` for this geometry."""
  return True

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
Expand Down
12 changes: 10 additions & 2 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ def _norm_y(self) -> Union[float, jnp.ndarray]:
return self.cost_fn.norm(self.y)
return 0.

@property
def can_LRC(self):
  """Whether casting this geometry as a low-rank cost makes sense.

  Holds when the cost is squared-Euclidean and the dimension check
  ``_check_LRC_dim`` passes; the latter is only evaluated when the former
  is truthy.
  """
  sq_eucl = self.is_squared_euclidean
  return sq_eucl and self._check_LRC_dim

@property
def _check_LRC_dim(self):
  """Whether ``n * m`` exceeds ``(n + m) * d``.

  Here ``(n, m)`` is ``self.shape`` and ``d`` is ``self.x.shape[1]``.
  """
  n, m = self.shape
  d = self.x.shape[1]
  # When this holds, `apply_cost` using an LRCGeometry is preferable.
  return n * m > (n + m) * d

@property
def cost_matrix(self) -> Optional[jnp.ndarray]:
if self.is_online:
Expand Down Expand Up @@ -608,8 +617,7 @@ def to_LRCGeometry(
Otherwise, returns the re-scaled low-rank geometry.
"""
if self.is_squared_euclidean:
(n, m), d = self.shape, self.x.shape[1]
if n * m > (n + m) * d: # here apply_cost using LRCGeometry preferable.
if self._check_LRC_dim:
return self._sqeucl_to_lr(scale)
# we don't update the `scale_factor` because in GW, the linear cost
# is first materialized and then scaled by `fused_penalty` afterwards
Expand Down
Loading

0 comments on commit c00bc12

Please sign in to comment.