Skip to content

Commit

Permalink
Added CDF distance to Univariate Solver (#451)
Browse files Browse the repository at this point in the history
* test commit

* added HistogramTransport distance

* added tests

* removed `test_file`

* renamed `ArbitraryTransportInitializer`
to `FixedCouplingInitializer`

* softness -> epsilon_1d

* epsilon_1d=0.0 for hard sorting

* epsilon_1d=0.0 for hard sorting

* + cost_fn for 1d_wasserstein

* extracted `wasserstein_1d`

* removed `p` argument

* removed `match` statement

* removed `HTOutput` and `HTState` classes

* updated `ht_test`

* fixed `ht_test`

* `wasserstein_1d` -> `univariate`

* fixed indentation issues

* removed `FixedCouplingInitializer` class

* changed `QuadraticInitializer` documentation

* added `solvers.univariate` to documentation

* minor edits to `univariate.py`

* fixed `UnivariateSolver` docstring

* many updates to `univariate.py`

* docstring edits to `histogram_transport`

* added missing type of `univariate`'s `__call__`

* added pytree class to HT and Univariate solvers

* doc changes, code refactoring

* added memoli citation

* parametrized `ht_test`

* fixed spelling

* readded min/max iterations to `univariate`

* fixed  underline

* fixed indentations

* added `init_coupling` as a child

* type ascription for `**kwargs`

* fixed `warning::`, I think?

* docstring edits of `univariate.py`

* `self.cost_fn` to oneliner

* fixed `univariate` children

* fixing `.rst` stuff

* editing `univariate.py` docs

* slightly more documentation

* fixed `ht_test` error

* Use `sort_fn`

* Fewer tests

* Add shape checks

* Add diff tests

* Re-scale when subsampling

* Update grad test

* Rename solver

* Fix indentation

* Refer to the definition in the LowerBoundSolver

* added univariate._cdf_distance

* Histogram Transport -> Lower Bound

* factored out an unnecessary `argsort`

* capitalize CDF

* add univariate test

* remove `p`

* remove `p` edit

* add test `scipy` matches `ott`

* type ascription for `cost_fn`

* one less `rng`

* uniformize in `_cdf_distance`

* fix test parameterization

* add grad testing

* fixed documentation

* fixing tests

* validate probability distribution

* added back 5th `rng`

* fix sinkhorn cost_fn

* fixed boolean error

* fix shape bug

* refactored uniform checking

* jitted sinkhorn solver

* added todo

* removed `fast` mark on sinkhorn test

* fix parametrization

* lowered iterations in sinkhorn test

* fixed `ones_like`

---------

Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>
  • Loading branch information
Daniel-Packer and michalk8 authored Nov 7, 2023
1 parent 62b07e2 commit f139415
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 3 deletions.
47 changes: 45 additions & 2 deletions src/ott/solvers/linear/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class UnivariateSolver:
- `'quantile'` - Take equally spaced quantiles of the distances.
- `'equal'` - No subsampling is performed, requires distributions to have
the same number of points.
- `'wasserstein'` - Compute the distance using the explicit solution
involving inverse CDFs.
n_subsamples: The number of samples to draw for the "quantile" or
"subsample" methods.
Expand All @@ -53,20 +55,31 @@ def __init__(
self,
sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
cost_fn: Optional[costs.CostFn] = None,
method: Literal["subsample", "quantile", "equal"] = "subsample",
method: Literal["subsample", "quantile", "wasserstein",
"equal"] = "subsample",
n_subsamples: int = 100,
):
self.sort_fn = jnp.sort if sort_fn is None else sort_fn
self.cost_fn = costs.PNormP(2) if cost_fn is None else cost_fn
self.method = method
self.n_subsamples = n_subsamples

def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
def __call__(
self,
x: jnp.ndarray,
y: jnp.ndarray,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None
) -> float:
"""Computes the Univariate OT Distance between `x` and `y`.
Args:
x: The first distribution of shape ``[n,]`` or ``[n, 1]``.
y: The second distribution of shape ``[m,]`` or ``[m, 1]``.
a: The first marginals when ``method = 'wasserstein'``. If :obj:`None`,
uniform will be used.
b: The second marginals when ``method = 'wasserstein'``. If :obj:`None`,
uniform will be used.
Returns:
The OT distance.
Expand All @@ -91,12 +104,42 @@ def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y)
xx = jnp.quantile(sorted_x, q=jnp.linspace(0, 1, self.n_subsamples))
yy = jnp.quantile(sorted_y, q=jnp.linspace(0, 1, self.n_subsamples))
elif self.method == "wasserstein":
return self._cdf_distance(x, y, a, b)
else:
raise NotImplementedError(f"Method `{self.method}` not implemented.")

# re-scale when subsampling
return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0])

def _cdf_distance(
    self, x: jnp.ndarray, y: jnp.ndarray, a: Optional[jnp.ndarray],
    b: Optional[jnp.ndarray]
):
  """Transport cost between two weighted 1D samples via inverse CDFs.

  Args:
    x: Support points of the first distribution.
    y: Support points of the second distribution.
    a: Weights attached to ``x``; all-ones (then normalized) if :obj:`None`.
    b: Weights attached to ``y``; all-ones (then normalized) if :obj:`None`.

  Returns:
    The sum, over consecutive quantile levels, of ``self.cost_fn`` evaluated
    between the two inverse-CDF values, weighted by the level increment.
  """
  # Modeled on the implementation of :func:`scipy.stats.wasserstein_distance`.
  if a is None:
    a = jnp.ones_like(x)
  a /= jnp.sum(a)
  if b is None:
    b = jnp.ones_like(y)
  b /= jnp.sum(b)

  # Pool both supports and sort them once; the same permutation is applied
  # to the zero-padded weights so each mass stays attached to its point.
  pooled = jnp.concatenate([x, y])
  order = jnp.argsort(pooled)
  support = pooled[order]
  mass_x = jnp.concatenate([a, jnp.zeros(y.shape)])[order]
  mass_y = jnp.concatenate([jnp.zeros(x.shape), b])[order]

  cdf_x = jnp.cumsum(mass_x)
  cdf_y = jnp.cumsum(mass_y)

  # Every level at which either CDF jumps; between two consecutive levels
  # both inverse CDFs are constant.
  levels = jnp.sort(jnp.concatenate([cdf_x, cdf_y]))
  inv_x = support[jnp.searchsorted(cdf_x, levels)]
  inv_y = support[jnp.searchsorted(cdf_y, levels)]

  # The first level is dropped so the costs line up with `jnp.diff(levels)`.
  pointwise_cost = jax.vmap(self.cost_fn)(inv_y[1:, None], inv_x[1:, None])
  return jnp.sum(pointwise_cost * jnp.diff(levels))

def tree_flatten(self):  # noqa: D102
  # No children are exposed; a shallow copy of every instance attribute is
  # returned as the auxiliary data, so edits to it do not touch `self`.
  return [], vars(self).copy()
Expand Down
131 changes: 131 additions & 0 deletions tests/solvers/linear/univariate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import scipy as sp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, univariate


class TestUnivariate:
  """Tests for the ``'wasserstein'`` method of the univariate solver."""

  @pytest.fixture(autouse=True)
  def initialize(self, rng: jax.Array):
    """Draw two 1D samples of different sizes with normalized weights."""
    self.n, self.m = 17, 29
    self.rng, key_x, key_y, key_a, key_b = jax.random.split(rng, 5)
    self.x = jax.random.uniform(key_x, [self.n])
    self.y = jax.random.uniform(key_y, [self.m])

    # one weight per side is zeroed to test proper handling of empty mass
    raw_a = jax.random.uniform(key_a, [self.n]).at[0].set(0)
    raw_b = jax.random.uniform(key_b, [self.m]).at[3].set(0)
    self.a = raw_a / jnp.sum(raw_a)
    self.b = raw_b / jnp.sum(raw_b)

  @pytest.mark.parametrize(
      "cost_fn", [
          costs.SqEuclidean(),
          costs.PNormP(1.0),
          costs.PNormP(2.0),
          costs.PNormP(1.7)
      ]
  )
  def test_cdf_distance_and_sinkhorn(self, cost_fn: costs.CostFn):
    """The univariate distance is close to the Sinkhorn primal cost."""
    solver = univariate.UnivariateSolver(method="wasserstein", cost_fn=cost_fn)
    distance = solver(self.x, self.y, self.a, self.b)

    geom = pointcloud.PointCloud(
        x=self.x[:, None], y=self.y[:, None], cost_fn=cost_fn, epsilon=5e-3
    )
    prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b)
    out = jax.jit(sinkhorn.Sinkhorn(max_iterations=10_000))(prob)

    np.testing.assert_allclose(out.primal_cost, distance, atol=0, rtol=1e-1)

  @pytest.mark.fast()
  def test_cdf_distance_and_scipy(self):
    """The OTT result matches :func:`scipy.stats.wasserstein_distance`."""
    # The `scipy` solver only has the solution for p=1.0 visible
    solver = univariate.UnivariateSolver(
        method="wasserstein", cost_fn=costs.PNormP(1.0)
    )
    ott_distance = solver(self.x, self.y, self.a, self.b)
    scipy_distance = sp.stats.wasserstein_distance(
        self.x, self.y, self.a, self.b
    )

    np.testing.assert_allclose(scipy_distance, ott_distance, atol=0, rtol=1e-2)

  @pytest.mark.fast()
  def test_cdf_grad(
      self,
      rng: jax.Array,
  ):
    """Autodiff gradients agree with central finite differences."""
    # TODO: Once a `check_grad` function is implemented, replace the loop
    # below with calls to `check_grad`.
    eps, tol = 1e-4, 1e-3
    keys = jax.random.split(rng, 4)
    solver = univariate.UnivariateSolver(
        method="wasserstein", cost_fn=costs.SqEuclidean()
    )

    args = (self.x, self.y, self.a, self.b)
    grads = jax.jit(jax.grad(solver, (0, 1, 2, 3)))(*args)

    # Arguments 0 and 1 are the point locations (geometric grads);
    # arguments 2 and 3 are the weights (probability grads).
    for idx, (key, grad) in enumerate(zip(keys, grads)):
      direction = jax.random.normal(key, shape=args[idx].shape)
      if idx >= 2:
        # weight perturbations are centered so that they sum to zero
        direction -= jnp.mean(direction, axis=-1, keepdims=True)
      direction = (
          direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)
      ) * eps

      shifted_up = args[:idx] + (args[idx] + direction,) + args[idx + 1:]
      shifted_down = args[:idx] + (args[idx] - direction,) + args[idx + 1:]
      expected = solver(*shifted_up) - solver(*shifted_down)
      actual = 2.0 * jnp.vdot(direction, grad)
      np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
2 changes: 1 addition & 1 deletion tests/solvers/quadratic/lower_bound_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def initialize(self, rng: jax.Array):
rngs = jax.random.split(rng, 4)
self.x = jax.random.uniform(rngs[0], (self.n, d_x))
self.y = jax.random.uniform(rngs[1], (self.m, d_y))
# Currently Histogram Transport only supports uniform distributions:
# Currently the Lower Bound only supports uniform distributions:
a = jnp.ones(self.n)
b = jnp.ones(self.m)
self.a = a / jnp.sum(a)
Expand Down

0 comments on commit f139415

Please sign in to comment.