Skip to content

Commit

Permalink
Add GPU test runner (#242)
Browse files Browse the repository at this point in the history
* Add `rstcheck` and `doc8`

* Pass `CUDA_VISIBLE_DEVICES` to `tox`

* Add GPU CI runner

* Fix Python version for GPU tests

* Try running without `tox`

* Fix not installing jax[cuda]

* Use different Docker image

* Fix escpape

* Use apt-get

* Do not use `{}`

* Fix not installing `git`

* Use personal Docker image

* Pin `jax[cuda]` version

* Mark grad(sqrtm) as CPU only test

* Fix ICNN hessian test on GPU

* Use `eigvalsh` to check for positive-semidefinite

* Adjust tolerance in a test

* Mark Sinkhorn online as CPU

* Run all tests on GPU

* Skip more tests on GPU

* Update tolerances on k-means test

* Always jit in online Sinkhorn test

* Use simple comparison

* Only run fast GPU tests, try other GPU

* Use previous GPU

* [ci skip] Fix test
  • Loading branch information
michalk8 authored Feb 8, 2023
1 parent 4fad710 commit 4caafeb
Show file tree
Hide file tree
Showing 28 changed files with 183 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end_of_line = lf
insert_final_newline = true
charset = utf-8

[*py]
[{*py,*.rst}]
indent_size = 2
indent_style = space
max_line_length = 80
Expand Down
25 changes: 24 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ jobs:
run: |
tox -e py39 --skip-pkg-install -- -m fast --memray -n auto -vv
gpu-tests:
name: Fast GPU tests Python 3.8 on ubuntu-20.04
runs-on: [self-hosted, ott-gpu]
container:
image: docker://michalk8/cuda:11.3.0-ubuntu20.04
options: --gpus="device=12"
steps:
- uses: actions/checkout@v3
- name: Install dependencies
# `jax[cuda]<0.4` because of: https://github.com/google/jax/issues/13758
run: |
python3 -m pip install --upgrade pip
python3 -m pip install -e".[test]"
python3 -m pip install "jax[cuda]<0.4" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- name: Nvidia SMI
run: |
nvidia-smi
- name: Run tests
run: |
python3 -m pytest -m "fast and not cpu" --memray --durations 10 -vv
tests:
name: Python ${{ matrix.python-version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
Expand Down Expand Up @@ -68,7 +91,7 @@ jobs:
run: |
tox -e py${{ matrix.python-version }} --skip-pkg-install
env:
PYTEST_ADDOPTS: --memray --durations 10 -vv
PYTEST_ADDOPTS: --memray -vv

- name: Upload coverage
uses: codecov/codecov-action@v3
Expand Down
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,14 @@ repos:
hooks:
- id: pyupgrade
args: [--py38-plus, --keep-runtime-typing]
- repo: https://github.com/rstcheck/rstcheck
rev: v6.1.1
hooks:
- id: rstcheck
additional_dependencies: [tomli]
args: [--config=pyproject.toml]
- repo: https://github.com/PyCQA/doc8
rev: v1.1.1
hooks:
- id: doc8
args: [--config=pyproject.toml]
49 changes: 26 additions & 23 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
.. _geometry:

ott.geometry package
====================
ott.geometry
============
.. currentmodule:: ott.geometry
.. automodule:: ott.geometry

This package implements several classes to define a geometry, arguably the most influential
ingredient of optimal transport problem. In its full generality, a :class:`~ott.geometry.geometry.Geometry`
defines source points (input measure), target points (target measure) and a ground cost function
(resp. a positive kernel function) that quantifies how expensive (resp. easy) it is to displace
a unit of mass from any of the input points to the target points.
This package implements several classes to define a geometry, arguably the most
influential ingredient of optimal transport problem. In its full generality, a
:class:`~ott.geometry.geometry.Geometry` defines source points (input measure),
target points (target measure) and a ground cost function (resp. a positive
kernel function) that quantifies how expensive (resp. easy) it is to displace a
unit of mass from any of the input points to the target points.

The geometry package proposes a few simple geometries. The simplest of all would
be that for which input and target points coincide, and the geometry between them
simplifies to a symmetric cost or kernel matrix. In the very particular case
where these points happen to lie on grid (a cartesian product in full generality,
e.g. 2 or 3D grids), the :class:`~ott.geometry.grid.Grid` geometry will prove useful.
be that for which input and target points coincide, and the geometry between
them simplifies to a symmetric cost or kernel matrix. In the very particular
case where these points happen to lie on grid (a cartesian product in full
generality, e.g. 2 or 3D grids), the :class:`~ott.geometry.grid.Grid`
geometry will prove useful.

For more general settings where input/target points do not coincide, one can
alternatively instantiate a :class:`~ott.geometry.geometry.Geometry` through a rectangular cost matrix.
alternatively instantiate a :class:`~ott.geometry.geometry.Geometry` through a
rectangular cost matrix.

However, it is often preferable in applications to define ground costs "symbolically",
by listing instead points in the input/target point clouds, to specify directly
a cost *function* between them. Such functions should follow the :class:`~ott.geometry.costs.CostFn`
class description. We provide a few standard cost functions that are meaningful in an
OT context, notably the (unbalanced, regularized) Bures distances between
Gaussians :cite:`janati:20`. That cost can be used for instance to compute a distance between
Gaussian mixtures, as proposed in :cite:`chen:19a` and revisited in :cite:`delon:20`.
However, it is often preferable in applications to define ground costs
"symbolically", by listing instead points in the input/target point clouds, to
specify directly a cost *function* between them. Such functions should follow
the :class:`~ott.geometry.costs.CostFn` class description. We provide a few
standard cost functions that are meaningful in an OT context, notably the
(unbalanced, regularized) Bures distances between Gaussians :cite:`janati:20`.
That cost can be used for instance to compute a distance between Gaussian
mixtures, as proposed in :cite:`chen:19a` and revisited in :cite:`delon:20`.

To be useful with Sinkhorn solvers, ``Geometries`` typically need to provide an
``epsilon`` regularization parameter. We propose either to set that value once for
all, or implement an annealing :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler.
``epsilon`` regularization parameter. We propose either to set that value once
for all, or implement an annealing
:class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler.

Geometries
----------
Expand Down
64 changes: 35 additions & 29 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ Optimal Transport Tools (OTT)

Introduction
------------
``OTT`` is a `JAX <https://jax.readthedocs.io/en/latest/>`_ package that bundles a few utilities to compute,
and differentiate as needed, the solution to optimal transport (OT) problems, taken in a fairly wide sense.
For instance, ``OTT`` can of course compute Wasserstein (or Gromov-Wasserstein) distances between
weighted clouds of points (or histograms) in a wide variety of scenarios,
but also estimate Monge maps, Wasserstein barycenters, and help with simpler tasks
such as differentiable approximations to ranking or even clustering.
``OTT`` is a `JAX <https://jax.readthedocs.io/en/latest/>`_ package that bundles
a few utilities to compute, and differentiate as needed, the solution to optimal
transport (OT) problems, taken in a fairly wide sense. For instance, ``OTT`` can
of course compute Wasserstein (or Gromov-Wasserstein) distances between weighted
clouds of points (or histograms) in a wide variety of scenarios, but also
estimate Monge maps, Wasserstein barycenters, and help with simpler tasks such
as differentiable approximations to ranking or even clustering.

To achieve this, ``OTT`` rests on two families of tools:
The first family consists in *discrete* solvers computing transport between point clouds,
using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms,
and moving up towards Gromov-Wasserstein :cite:`memoli:11,peyre:16`;
the second family consists in *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled
with SGD type estimators :cite:`makkuva:20,korotin:21`.

- the first family consists in *discrete* solvers computing transport between
point clouds, using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn
:cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
:cite:`memoli:11,peyre:16`;
- the second family consists in *continuous* solvers, using suitable neural
architectures :cite:`amos:17` coupled with SGD type estimators
:cite:`makkuva:20,korotin:21`.

Installation
------------
Expand All @@ -27,7 +31,7 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:
pip install ott-jax
or with ``conda`` via `conda-forge <https://anaconda.org/conda-forge/ott-jax>`_ as:
or with ``conda`` via `conda-forge`_ as:

.. code-block:: bash
Expand All @@ -37,40 +41,41 @@ Design Choices
--------------
``OTT`` is designed with the following choices:

- Take advantage whenever possible of JAX features, such as `Just-in-time (JIT) compilation`_,
`auto-vectorization (VMAP)`_ and both `automatic`_ but most importantly `implicit`_ differentiation.
- Take advantage whenever possible of JAX features, such as
`Just-in-time (JIT) compilation`_, `auto-vectorization (VMAP)`_ and both
`automatic`_ but most importantly `implicit`_ differentiation.
- Split geometry from OT solvers in the discrete case: We argue that there
should be one, and one implementation only, of every major OT algorithm
(Sinkhorn, Gromov-Wasserstein, barycenters, etc...), regardless of the
geometric setup that is considered. To give a concrete example, any
speedups one may benefit from by using a specific cost
(e.g. Sinkhorn being faster when run on a separable cost on histograms supported
on a separable grid :cite:`solomon:15`) should not require a separate
reimplementation of a Sinkhorn routine.
speedups one may benefit from by using a specific cost (e.g. Sinkhorn being
faster when run on a separable cost on histograms supported on a separable
grid :cite:`solomon:15`) should not require a separate reimplementation
of a Sinkhorn routine.
- As a consequence, and to minimize code copy/pasting, use as often as possible
object hierarchies, and interleave outer solvers (such as quadratic,
aka Gromov-Wasserstein solvers) with inner solvers (e.g. Low-Rank Sinkhorn).
This choice ensures that speedups achieved at lower computation levels
(e.g. low-rank factorization of squared Euclidean distances) propagate seamlessly and
automatically in higher level calls (e.g. updates in Gromov-Wasserstein),
without requiring any attention from the user.
(e.g. low-rank factorization of squared Euclidean distances) propagate
seamlessly and automatically in higher level calls (e.g. updates in
Gromov-Wasserstein), without requiring any attention from the user.

.. TODO(marcocuturi): add missing package descriptions below
Packages
--------
- :ref:`geometry` contains classes to instantiate objects that describe
- :doc:`geometry` contains classes to instantiate objects that describe
*two point clouds* paired with a *cost* function. Geometry objects are used to
describe OT problems, handled by solvers in the :ref:`solvers`.
- :ref:`problems`
- :ref:`solvers`
- :ref:`initializers`
- :ref:`tools` provides an interface to exploit OT solutions, as produced by
solvers in the :ref:`solvers`. Such tasks include computing approximations
describe OT problems, handled by solvers in the solvers.
- :doc:`problems/index`
- :doc:`solvers/index`
- :doc:`initializers/index`
- :doc:`tools` provides an interface to exploit OT solutions, as produced by
solvers in the solvers. Such tasks include computing approximations
to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT
between GMMs, or computing differentiable sort and quantile operations
:cite:`cuturi:19`.
- :ref:`math`
- :doc:`math`

.. toctree::
:maxdepth: 1
Expand Down Expand Up @@ -116,3 +121,4 @@ Packages
.. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
.. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
.. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp
.. _conda-forge: https://anaconda.org/conda-forge/ott-jax
6 changes: 2 additions & 4 deletions docs/initializers/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
.. _initializers:

ott.initializers package
========================
ott.initializers
================

.. TODO(cuturi): add some nice text here please
Expand Down
4 changes: 2 additions & 2 deletions docs/initializers/linear.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.initializers.linear package
===============================
ott.initializers.linear
=======================
.. currentmodule:: ott.initializers.linear
.. automodule:: ott.initializers.linear

Expand Down
4 changes: 2 additions & 2 deletions docs/initializers/nn.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.initializers.nn package
===========================
ott.initializers.nn
===================
.. currentmodule:: ott.initializers.nn
.. automodule:: ott.initializers.nn

Expand Down
4 changes: 2 additions & 2 deletions docs/initializers/quadratic.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.initializers.quadratic package
==================================
ott.initializers.quadratic
==========================
.. currentmodule:: ott.initializers.quadratic
.. automodule:: ott.initializers.quadratic

Expand Down
6 changes: 2 additions & 4 deletions docs/math.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
.. _math:

ott.math package
================
ott.math
========
.. currentmodule:: ott.math
.. automodule:: ott.math

Expand Down
6 changes: 2 additions & 4 deletions docs/problems/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
.. _problems:

ott.problems package
====================
ott.problems
============

.. TODO(marcocuturi): add some nice text here please
Expand Down
4 changes: 2 additions & 2 deletions docs/problems/linear.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.problems.linear package
===========================
ott.problems.linear
===================
.. currentmodule:: ott.problems.linear
.. automodule:: ott.problems.linear

Expand Down
4 changes: 2 additions & 2 deletions docs/problems/quadratic.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.problems.quadratic package
==============================
ott.problems.quadratic
======================
.. currentmodule:: ott.problems.quadratic
.. automodule:: ott.problems.quadratic

Expand Down
6 changes: 2 additions & 4 deletions docs/solvers/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
.. _solvers:

ott.solvers package
===================
ott.solvers
===========

.. TODO(marcocuturi): add some nice text here please
Expand Down
4 changes: 2 additions & 2 deletions docs/solvers/linear.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.solvers.linear package
==========================
ott.solvers.linear
==================
.. currentmodule:: ott.solvers.linear
.. automodule:: ott.solvers.linear

Expand Down
4 changes: 2 additions & 2 deletions docs/solvers/nn.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.solvers.nn package
======================
ott.solvers.nn
==============
.. currentmodule:: ott.solvers.nn
.. automodule:: ott.solvers.nn

Expand Down
4 changes: 2 additions & 2 deletions docs/solvers/quadratic.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ott.solvers.quadratic package
=============================
ott.solvers.quadratic
=====================
.. currentmodule:: ott.solvers.quadratic
.. automodule:: ott.solvers.quadratic

Expand Down
13 changes: 6 additions & 7 deletions docs/tools.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
.. _tools:

ott.tools package
=================
ott.tools
=========
.. currentmodule:: ott.tools
.. automodule:: ott.tools

The tools package contains high level functions that build on outputs produced by core functions.
They can be used to compute Sinkhorn divergences :cite:`sejourne:19`, instantiate transport matrices,
provide differentiable approximations to ranks and quantile functions :cite:`cuturi:19`, etc.
The tools package contains high level functions that build on outputs produced
by core functions. They can be used to compute Sinkhorn divergences
:cite:`sejourne:19`, instantiate transport matrices, provide differentiable
approximations to ranks and quantile functions :cite:`cuturi:19`, etc.

Segmented Sinkhorn
------------------
Expand Down
Loading

0 comments on commit 4caafeb

Please sign in to comment.