-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added CDF distance to Univariate Solver (#451)
* test commit * added HistogramTransport distance * added tests * removed `test_file` * renamed `ArbitraryTransportInitializer` to `FixedCouplingInitializer` * softness -> epsilon_1d * epsilon_1d=0.0 for hard sorting * epsilon_1d=0.0 for hard sorting * + cost_fn for 1d_wasserstein * extracted `wasserstein_1d` * removed `p` argument * removed `match` statement * removed `HTOutput` and `HTState` classes * updated `ht_test` * fixed `ht_test` * `wasserstein_1d` -> `univariate` * fixed indentation issues * removed `FixedCouplingInitializer` class * changed `QuadraticInitializer` documentation * added `solvers.univariate` to documentation * minor edits to `univariate.py` * fixed `UnivariateSolver` docstring * many updates to `univariate.py` * docstring edits to `histogram_transport` * added missing type of `univariate`'s `__call__` * added pytree class to HT and Univariate solvers * doc changes, code refactoring * added memoli citation * parametrized `ht_test` * fixed spelling * readded min/max iterations to `univariate` * fixed underline * fixed indentations * added `init_coupling` as a child * type ascription for `**kwargs` * fixed `warning::`, I think? 
* docstring edits of `univariate.py` * `self.cost_fn` to oneliner * fixed `univariate` children * fixing `.rst` stuff * editing `univariate.py` docs * slightly more documentation * fixed `ht_test` error * Use `sort_fn` * Fewer tests * Add shape checks * Add diff tests * Re-scale when subsampling * Update grad test * Rename solver * Fix indentation * Refer to the definition in the LowerBoundSolver * added univariate._cdf_distance * Histogram Transport -> Lower Bound * factored out an unnecessary `argsort` * capitalize CDF * add univariate test * remove `p` * remove `p` edit * add test `scipy` matches `ott` * type ascription for `cost_fn` * one less `rng` * uniformize in `_cdf_distance` * fix test parameterization * add grad testing * fixed documentation * fixing tests * validate probability distribution * added back 5th `rng` * fix sinkhorn cost_fn * fixed boolean error * fix shape bug * refactored uniform checking * jitted sinkhorn solver * added todo * removed `fast` mark on sinkhorn test * fix parametrization * lowered iterations in sinkhorn test * fixed `ones_like` --------- Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>
- Loading branch information
1 parent
62b07e2
commit f139415
Showing
3 changed files
with
177 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import pytest | ||
import scipy as sp | ||
from ott.geometry import costs, pointcloud | ||
from ott.problems.linear import linear_problem | ||
from ott.solvers.linear import sinkhorn, univariate | ||
|
||
|
||
class TestUnivariate:
  """Tests for ``univariate.UnivariateSolver`` with ``method="wasserstein"``.

  Every test operates on the two weighted one-dimensional samples that
  :meth:`initialize` stores on the instance.
  """

  @pytest.fixture(autouse=True)
  def initialize(self, rng: jax.Array) -> None:
    """Draw the samples and weights shared by all tests.

    Sets ``self.x`` (``n = 17`` points) and ``self.y`` (``m = 29`` points),
    both uniform random, together with weight vectors ``self.a`` and
    ``self.b`` that are normalized to sum to one.

    Args:
      rng: Random key; presumably supplied by a ``rng`` fixture defined
        outside this class.
    """
    self.rng = rng
    self.n = 17
    self.m = 29
    # Five keys: the first is kept for later use, the other four are consumed.
    self.rng, *rngs = jax.random.split(self.rng, 5)
    self.x = jax.random.uniform(rngs[0], [self.n])
    self.y = jax.random.uniform(rngs[1], [self.m])
    a = jax.random.uniform(rngs[2], [self.n])
    b = jax.random.uniform(rngs[3], [self.m])

    # adding zero weights to test proper handling
    a = a.at[0].set(0)
    b = b.at[3].set(0)
    self.a = a / jnp.sum(a)
    self.b = b / jnp.sum(b)

  @pytest.mark.parametrize(
      "cost_fn", [
          costs.SqEuclidean(),
          costs.PNormP(1.0),
          costs.PNormP(2.0),
          costs.PNormP(1.7)
      ]
  )
  def test_cdf_distance_and_sinkhorn(self, cost_fn: costs.CostFn) -> None:
    """The univariate distance coincides with the Sinkhorn solver.

    The value returned by the univariate solver is compared with the
    ``primal_cost`` of a Sinkhorn solution computed with ``epsilon=5e-3``,
    using a relative tolerance of 10%.

    Args:
      cost_fn: Ground cost passed to both solvers.
    """
    univariate_solver = univariate.UnivariateSolver(
        method="wasserstein", cost_fn=cost_fn
    )
    distance = univariate_solver(self.x, self.y, self.a, self.b)

    # The point cloud geometry receives the samples as column vectors.
    geom = pointcloud.PointCloud(
        x=self.x[:, None], y=self.y[:, None], cost_fn=cost_fn, epsilon=5e-3
    )
    prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b)
    sinkhorn_solver = jax.jit(sinkhorn.Sinkhorn(max_iterations=10_000))
    sinkhorn_soln = sinkhorn_solver(prob)

    np.testing.assert_allclose(
        sinkhorn_soln.primal_cost, distance, atol=0, rtol=1e-1
    )

  @pytest.mark.fast()
  def test_cdf_distance_and_scipy(self) -> None:
    """The OTT solver coincides with the SciPy solver."""
    # Import the submodule explicitly: `import scipy as sp` on its own is not
    # guaranteed to make `sp.stats` available on every SciPy version.
    from scipy import stats

    # The `scipy` solver only has the solution for p=1.0 visible
    univariate_solver = univariate.UnivariateSolver(
        method="wasserstein", cost_fn=costs.PNormP(1.0)
    )
    ott_distance = univariate_solver(self.x, self.y, self.a, self.b)

    scipy_distance = stats.wasserstein_distance(
        self.x, self.y, u_weights=self.a, v_weights=self.b
    )

    np.testing.assert_allclose(scipy_distance, ott_distance, atol=0, rtol=1e-2)

  def _assert_directional_grad(
      self,
      solver,
      argnum: int,
      grad: jax.Array,
      direction: jax.Array,
      tol: float,
  ) -> None:
    """Check ``grad`` against a central finite difference along ``direction``.

    Args:
      solver: Callable invoked as ``solver(x, y, a, b)``.
      argnum: Position (0-3) of the argument that is perturbed.
      grad: Gradient of ``solver`` with respect to that argument.
      direction: Perturbation added to / subtracted from the argument.
      tol: Used as both relative and absolute tolerance.
    """
    base = (self.x, self.y, self.a, self.b)
    forward, backward = list(base), list(base)
    forward[argnum] = base[argnum] + direction
    backward[argnum] = base[argnum] - direction

    # f(z + v) - f(z - v) should match 2 * <v, grad f(z)> to first order.
    expected = solver(*forward) - solver(*backward)
    actual = 2.0 * jnp.vdot(direction, grad)
    np.testing.assert_allclose(
        actual,
        expected,
        rtol=tol,
        atol=tol,
        err_msg=f"Gradient mismatch for positional argument {argnum}."
    )

  @pytest.mark.fast()
  def test_cdf_grad(
      self,
      rng: jax.Array,
  ) -> None:
    """Gradients w.r.t. ``x``, ``y``, ``a`` and ``b`` match finite differences.

    Args:
      rng: Random key used to draw the four perturbation directions.
    """
    # TODO: Once a `check_grad` function is implemented, replace
    # `_assert_directional_grad` with it.
    cost_fn = costs.SqEuclidean()
    rngs = jax.random.split(rng, 4)
    # NOTE(review): every direction has norm `eps`, so both compared
    # quantities scale with `eps`; `atol=tol` looks loose next to that -
    # confirm that the tolerance is intended.
    eps, tol = 1e-4, 1e-3

    solver = univariate.UnivariateSolver(method="wasserstein", cost_fn=cost_fn)

    grads = jax.jit(jax.grad(solver, argnums=(0, 1, 2, 3))
                   )(self.x, self.y, self.a, self.b)

    # (reference array, center the direction?) in positional-argument order.
    # Geometric arguments take an arbitrary direction. Weight directions are
    # centered so that the perturbed weights keep the same total mass.
    # NOTE(review): `a[0]` and `b[3]` are zero, so one side of the finite
    # difference sees a slightly negative weight - confirm this is acceptable.
    targets = (
        (self.x, False),
        (self.y, False),
        (self.a, True),
        (self.b, True),
    )
    for argnum, (key, grad, (ref, centered)) in enumerate(
        zip(rngs, grads, targets)
    ):
      direction = jax.random.normal(key, shape=ref.shape)
      if centered:
        direction -= jnp.mean(direction, axis=-1, keepdims=True)
      direction = (
          direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)
      ) * eps
      self._assert_directional_grad(solver, argnum, grad, direction, tol)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters