univariate dual vectors for arbitrary sizes #505

Merged
merged 47 commits into from
Jun 4, 2024
Changes from 38 commits
eab437d
univariate dual vectors for arbitrary sizes
marcocuturi Mar 23, 2024
e194d71
Merge branch 'main' into 1d_dual
michalk8 Apr 24, 2024
2a9427e
Remove old impl.
michalk8 Apr 24, 2024
3507dd4
Simplify outputs
michalk8 Apr 24, 2024
8b8891c
Better type
michalk8 Apr 24, 2024
8751f5c
Add initial north-west solver
michalk8 Apr 24, 2024
0ba9b49
Update NW solver
michalk8 Apr 24, 2024
42ae321
Clean the solver implementation
michalk8 Apr 24, 2024
4cd2676
Remove unnecessary sorting
michalk8 Apr 24, 2024
62d5ab6
Add TODO
michalk8 Apr 24, 2024
29f7068
Remove call to `jnp.minimum`
michalk8 Apr 24, 2024
6ffb79f
Merge branch 'main' into 1d_dual
michalk8 May 28, 2024
9e662ba
Use functional interface
michalk8 May 30, 2024
fc70b5d
Remove `UnivariateSolver` class
michalk8 May 30, 2024
2462043
Update tests
michalk8 May 30, 2024
0c29cde
Use `BCOO`
michalk8 May 30, 2024
bc4b902
Add `dual_costs` property
michalk8 May 30, 2024
3a9e6a9
Fix NW not undoing the sort
michalk8 May 30, 2024
398866a
Fix `test_dual_vectors`
michalk8 May 30, 2024
de3945c
Fix last test
michalk8 May 30, 2024
95c8761
Fix links in docs
michalk8 May 30, 2024
cf1e984
Update citation
michalk8 May 31, 2024
38bc137
Update `UnivariateWasserstein`'s docs
michalk8 Jun 3, 2024
e7e1450
More grad tests
michalk8 Jun 3, 2024
7ed6d87
Polish another test
michalk8 Jun 4, 2024
e5d1dd8
Parametrize the last test
michalk8 Jun 4, 2024
55afb49
Revert back to `eps=1.5e-3`
michalk8 Jun 4, 2024
6a8af0a
Add more LB tests
michalk8 Jun 4, 2024
34b06a7
Make Sinkhorn test more stable
michalk8 Jun 4, 2024
7d3fb41
Start improving docs in `UnivariateOutput`
michalk8 Jun 4, 2024
bdb1751
Update title in the docs
michalk8 Jun 4, 2024
102e975
Continue polishing the docs
michalk8 Jun 4, 2024
fd2c112
Fix remaining TODOs
michalk8 Jun 4, 2024
335433b
Nicer impl.
michalk8 Jun 4, 2024
08d7b60
Better docs
michalk8 Jun 4, 2024
9c221a0
Last pass over the docs
michalk8 Jun 4, 2024
b0732ef
Test different costs in `test_dual_vectors`
michalk8 Jun 4, 2024
d761dad
[ci-skip] Fix typo in docs
michalk8 Jun 4, 2024
42ac525
Rename functions
michalk8 Jun 4, 2024
8064f14
Add `solve_univariate` high-level function
michalk8 Jun 4, 2024
98c4409
Polish docs
michalk8 Jun 4, 2024
5508aa3
Incorporate the last comments
michalk8 Jun 4, 2024
c2d350b
Fix spelling
michalk8 Jun 4, 2024
eb85f4d
Update NW solver docs
michalk8 Jun 4, 2024
83174de
Unify phrasing in the docs
michalk8 Jun 4, 2024
4f6c806
Update `UnivariateWasserstein` docs
michalk8 Jun 4, 2024
2bd0cdb
Remove extra `.`
michalk8 Jun 4, 2024
1 change: 1 addition & 0 deletions docs/geometry.rst
@@ -56,6 +56,7 @@ Cost Functions
:toctree: _autosummary

costs.CostFn
costs.TICost
costs.SqPNorm
costs.PNormP
costs.SqEuclidean
20 changes: 8 additions & 12 deletions docs/references.bib
@@ -169,6 +169,7 @@ @article{vayer:20
 
 @article{demetci:22,
   author = {Demetci, Pinar and Santorella, Rebecca and Sandstede, Björn and Noble, William Stafford and Singh, Ritambhara},
+  doi = {10.1089/cmb.2021.0446},
   journal = {Journal of Computational Biology},
   note = {PMID: 35050714},
   number = {1},
@@ -435,18 +436,13 @@ @article{schmitz:18
   year = {2018},
 }
 
-@article{alvarez-esteban:16,
-  author = {Álvarez-Esteban, Pedro C. and {del Barrio}, E. and Cuesta-Albertos, J.A. and Matrán, C.},
-  url = {https://www.sciencedirect.com/science/article/pii/S0022247X16300907},
-  doi = {10.1016/j.jmaa.2016.04.045},
-  issn = {0022-247X},
-  journal = {Journal of Mathematical Analysis and Applications},
-  keywords = {Mass transportation problem,-Wasserstein distance,Wasserstein barycenter,Fréchet mean,Fixed-point iteration,Location-scatter families},
-  number = {2},
-  pages = {744--762},
-  title = {A fixed-point approach to barycenters in Wasserstein space},
-  volume = {441},
-  year = {2016},
+@misc{alvarez-esteban:16,
+  author = {Álvarez-Esteban, Pedro C. and del Barrio, E. and Cuesta-Albertos, J. A. and Matrán, C.},
+  eprint = {1511.05355},
+  eprintclass = {stat.CO},
+  eprinttype = {arXiv},
+  title = {A fixed-point approach to barycenters in Wasserstein space},
+  year = {2016},
 }
 
 @article{lehmann:21,
9 changes: 6 additions & 3 deletions docs/solvers/linear.rst
@@ -33,12 +33,15 @@ Barycenter Solvers
     discrete_barycenter.FixedBarycenter
     discrete_barycenter.SinkhornBarycenterOutput
 
-Other Solvers
--------------
+Univariate Solvers
+------------------
 .. autosummary::
     :toctree: _autosummary
 
-    univariate.UnivariateSolver
+    univariate.uniform_distance
+    univariate.quantile_distance
+    univariate.north_west_distance
+    univariate.UnivariateOutput
 
 Sinkhorn Acceleration
 ---------------------
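The north-west solver named in the commits and in the new `univariate.north_west_distance` entry can be illustrated with a small standalone sketch. This is not the PR's JAX implementation: the function name `north_west_1d`, its signature, the squared-Euclidean ground cost and the tolerance are all assumptions chosen for illustration. The idea is to sort both supports, then greedily move as much mass as possible between the current smallest unmatched source and target, which also yields a sparse plan with indices mapped back to the original (unsorted) order.

```python
from typing import List, Sequence, Tuple


def north_west_1d(
    x: Sequence[float],
    a: Sequence[float],
    y: Sequence[float],
    b: Sequence[float],
    tol: float = 1e-12,
) -> Tuple[float, List[Tuple[int, int, float]]]:
    """Hypothetical 1D OT sketch (squared cost) using the north-west corner rule.

    Returns the transport cost and the plan as (i, j, mass) triplets, where
    i and j index the original, unsorted inputs.
    """
    # Sort both supports while remembering the original positions.
    ix = sorted(range(len(x)), key=lambda i: x[i])
    iy = sorted(range(len(y)), key=lambda j: y[j])
    ra = [a[i] for i in ix]  # remaining mass of each sorted source point
    rb = [b[j] for j in iy]  # remaining mass of each sorted target point

    i = j = 0
    cost = 0.0
    plan: List[Tuple[int, int, float]] = []
    while i < len(ix) and j < len(iy):
        mass = min(ra[i], rb[j])
        if mass > 0.0:
            cost += mass * (x[ix[i]] - y[iy[j]]) ** 2
            plan.append((ix[i], iy[j], mass))  # undo the sort in the indices
        ra[i] -= mass
        rb[j] -= mass
        # Advance whichever side has been exhausted.
        if ra[i] <= tol:
            i += 1
        else:
            j += 1
    return cost, plan


cost, plan = north_west_1d([3.0, 0.0], [0.5, 0.5], [1.0, 2.0, 4.0], [0.25, 0.25, 0.5])
print(cost, plan)
```

The loop runs at most `len(x) + len(y)` times after sorting, and the plan has at most that many non-zero entries, which is why a sparse representation is a natural fit for the output.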
39 changes: 13 additions & 26 deletions src/ott/geometry/distrib_costs.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Callable, Optional
 
 import jax.numpy as jnp
 import jax.tree_util as jtu
@@ -20,9 +20,7 @@
 from ott.problems.linear import linear_problem
 from ott.solvers.linear import univariate
 
-__all__ = [
-    "UnivariateWasserstein",
-]
+__all__ = ["UnivariateWasserstein"]
 
 
 @jtu.register_pytree_node_class
@@ -34,31 +32,24 @@ class UnivariateWasserstein(costs.CostFn):
   ground cost.
 
   Args:
+    solve_fn: 1D optimal transport solver, e.g.,
+      :func:`~ott.solvers.linear.univariate.uniform_distance`.
     ground_cost: Cost used to compute the 1D optimal transport between vector,
       should be a translation-invariant (TI) cost for correctness.
       If :obj:`None`, defaults to :class:`~ott.geometry.costs.SqEuclidean`.
-    solver: 1D optimal transport solver.
-    kwargs: Arguments passed on when calling the
-      :class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include
-      random key, or specific instructions to subsample or compute using
-      quantiles.
   """
 
   def __init__(
       self,
+      solve_fn: Callable[[linear_problem.LinearProblem],
+                         univariate.UnivariateOutput],
       ground_cost: Optional[costs.TICost] = None,
-      solver: Optional[univariate.UnivariateSolver] = None,
-      **kwargs: Any
   ):
     super().__init__()
-
     self.ground_cost = (
         costs.SqEuclidean() if ground_cost is None else ground_cost
     )
-    self._solver = univariate.UnivariateSolver() if solver is None else solver
-    self._kwargs_solve = kwargs
-    # ensure transport solutions are neither computed nor stored
-    self._kwargs_solve["return_transport"] = False
+    self._solve_fn = solve_fn
 
   def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.
@@ -70,20 +61,16 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     Returns:
       The transport cost.
     """
-    out = self._solver(
-        linear_problem.LinearProblem(
-            pointcloud.PointCloud(
-                x[:, None], y[:, None], cost_fn=self.ground_cost
-            )
-        ), **self._kwargs_solve
+    geom = pointcloud.PointCloud(
+        x[:, None], y[:, None], cost_fn=self.ground_cost
     )
+    prob = linear_problem.LinearProblem(geom)
+    out = self._solve_fn(prob)
     return jnp.squeeze(out.ot_costs)
 
   def tree_flatten(self):  # noqa: D102
-    return (self.ground_cost,), (self._solver, self._kwargs_solve)
+    return (self.ground_cost,), (self._solve_fn,)
 
   @classmethod
   def tree_unflatten(cls, aux_data, children):  # noqa: D102
-    ground_cost, = children
-    solver, solve_kwargs = aux_data
-    return cls(ground_cost, solver, **solve_kwargs)
+    return cls(solve_fn=aux_data[0], ground_cost=children[0])
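The change above replaces a solver object plus stored keyword arguments with a single `solve_fn` callable that is carried as static auxiliary data in `tree_flatten`. The pattern can be sketched in plain Python without importing JAX or OTT. Everything below is hypothetical: `uniform_sq_distance`, `ToyUnivariateCost`, and the `scale` and `power` parameters are stand-ins that are not part of the library. The use of `functools.partial` to bind options is one possible way to replace the removed `**kwargs`, assuming the solver accepts those options as keyword arguments.

```python
import functools
from typing import Callable, Sequence


def uniform_sq_distance(
    x: Sequence[float], y: Sequence[float], power: int = 2
) -> float:
    """Hypothetical stand-in solver: equal sizes, uniform weights."""
    if len(x) != len(y):
        raise ValueError("This toy solver needs equally sized inputs.")
    # For uniform weights and a convex cost, matching sorted values is optimal.
    pairs = zip(sorted(x), sorted(y))
    return sum(abs(u - v) ** power for u, v in pairs) / len(x)


class ToyUnivariateCost:
    """Toy cost mirroring the structure of the diff, not the real class."""

    def __init__(self, solve_fn: Callable[..., float], scale: float = 1.0):
        self.scale = scale          # stands in for array-like state (children)
        self._solve_fn = solve_fn   # static callable (auxiliary data)

    def pairwise(self, x: Sequence[float], y: Sequence[float]) -> float:
        return self.scale * self._solve_fn(x, y)

    def tree_flatten(self):
        return (self.scale,), (self._solve_fn,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(solve_fn=aux_data[0], scale=children[0])


cost = ToyUnivariateCost(uniform_sq_distance)
children, aux_data = cost.tree_flatten()
clone = ToyUnivariateCost.tree_unflatten(aux_data, children)
print(cost.pairwise([2, 0, 1], [1, 3, 2]), clone.pairwise([2, 0, 1], [1, 3, 2]))

# Options are bound up front instead of being stored as keyword arguments.
l1_cost = ToyUnivariateCost(functools.partial(uniform_sq_distance, power=1))
print(l1_cost.pairwise([2, 0, 1], [1, 3, 2]))
```

Keeping the callable in the auxiliary data means the flatten and unflatten round trip only has to carry one item besides the children, which is the simplification visible in the last hunk of the diff.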