Skip to content

Commit

Permalink
Fix/io callback (#506)
Browse files Browse the repository at this point in the history
* Use `jax.debug.callback`

* Update tutorial
  • Loading branch information
michalk8 authored Mar 26, 2024
1 parent 9eed305 commit 14d4b81
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 26 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/tracking_progress.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"\n",
"{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit`, and applies to both the functional interface and the class interface.\n",
"\n",
"The solvers {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`, {class}`low-rank Sinkhorn <ott.solvers.linear.sinkhorn_lr.LRSinkhorn>`, and {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` only report progress if we pass a callback function with a specific signature. The callback is then called at each iteration using {func}`~jax.experimental.io_callback`."
"The solvers {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`, {class}`low-rank Sinkhorn <ott.solvers.linear.sinkhorn_lr.LRSinkhorn>`, and {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` only report progress if we pass a callback function with a specific signature. The callback is then called at each iteration using {func}`~jax.debug.callback`."
]
},
{
Expand Down Expand Up @@ -522,7 +522,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
5 changes: 2 additions & 3 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)

import jax
import jax.experimental
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
Expand Down Expand Up @@ -1008,8 +1007,8 @@ def one_iteration(
state = state.set(errors=errors)

if self.progress_fn is not None:
jax.experimental.io_callback(
self.progress_fn, None,
jax.debug.callback(
self.progress_fn,
(iteration, self.inner_iterations, self.max_iterations, state)
)
return state
Expand Down
5 changes: 2 additions & 3 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)

import jax
import jax.experimental
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
Expand Down Expand Up @@ -703,8 +702,8 @@ def one_iteration(
)

if self.progress_fn is not None:
jax.experimental.io_callback(
self.progress_fn, None,
jax.debug.callback(
self.progress_fn,
(iteration, self.inner_iterations, self.max_iterations, state)
)

Expand Down
5 changes: 2 additions & 3 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)

import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -405,8 +404,8 @@ def body_fn(
# Inner iterations is currently fixed to 1.
inner_iterations = 1
if solver.progress_fn is not None:
jax.experimental.io_callback(
solver.progress_fn, None,
jax.debug.callback(
solver.progress_fn,
(iteration, inner_iterations, solver.max_iterations, state)
)

Expand Down
5 changes: 2 additions & 3 deletions src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)

import jax
import jax.experimental
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
Expand Down Expand Up @@ -734,8 +733,8 @@ def one_iteration(
)

if self.progress_fn is not None:
jax.experimental.io_callback(
self.progress_fn, None,
jax.debug.callback(
self.progress_fn,
(iteration, self.inner_iterations, self.max_iterations, state)
)

Expand Down
4 changes: 0 additions & 4 deletions tests/solvers/linear/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ def test_output_apply_batch_size(self, axis: int):
)

@pytest.mark.fast()
@pytest.mark.skipif(
jax.__version_info__ < (0, 4, 0),
reason="`jax.experimental.io_callback` doesn't exist"
)
def test_progress_fn(self):
"""Check that the callback function is actually called."""
num_iterations = 37
Expand Down
8 changes: 0 additions & 8 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,6 @@ def test_f_potential_is_zero_centered(self, lse_mode: bool):
@pytest.mark.parametrize(("use_tqdm", "custom_buffer"), [(False, False),
(False, True),
(True, False)])
@pytest.mark.skipif(
jax.__version_info__ < (0, 4, 0),
reason="`jax.experimental.io_callback` doesn't exist"
)
def test_progress_fn(self, capsys, use_tqdm: bool, custom_buffer: bool):
geom = pointcloud.PointCloud(self.x, self.y, epsilon=1e-1)

Expand Down Expand Up @@ -595,10 +591,6 @@ def test_progress_fn(self, capsys, use_tqdm: bool, custom_buffer: bool):
assert captured.out.startswith("foo 7/14"), captured

@pytest.mark.fast()
@pytest.mark.skipif(
jax.__version_info__ < (0, 4, 0),
reason="`jax.experimental.io_callback` doesn't exist"
)
def test_custom_progress_fn(self):
"""Check that the callback function is actually called."""
num_iterations = 30
Expand Down
24 changes: 24 additions & 0 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from ott import utils
from ott.geometry import geometry, low_rank, pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import implicit_differentiation as implicit_lib
Expand Down Expand Up @@ -511,3 +512,26 @@ def test_gwlr_unbalanced_matches_balanced(
np.testing.assert_allclose(
res.primal_cost, res_unbal.primal_cost, rtol=1e-3, atol=1e-3
)

@pytest.mark.parametrize("grad", [False, True])
def test_gw_progress_fn(self, grad: bool):

def callback(x: jnp.ndarray, y: jnp.ndarray):
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

lin_solver = sinkhorn.Sinkhorn(progress_fn=utils.default_progress_fn(),)
quad_solver = gromov_wasserstein.GromovWasserstein(
linear_ot_solver=lin_solver,
progress_fn=utils.default_progress_fn(),
# needs to be explicitly set
store_inner_errors=True,
)

return quad_solver(prob).reg_gw_cost

fn = jax.grad(callback) if grad else callback
res = fn(self.x, self.y)

np.testing.assert_array_equal(jnp.isfinite(res), True)

0 comments on commit 14d4b81

Please sign in to comment.