diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 80a7fb8ee..0c71fc443 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -207,7 +207,7 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray: def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: - solver = sinkhorn.Sinkhorn(lse_mode=False, jit=False) + solver = sinkhorn.Sinkhorn(lse_mode=False) problem = linear_problem.LinearProblem(geom) return solver(problem) diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 8512dcdae..6a0fdf6dc 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -157,8 +157,7 @@ def test_autograd_sinkhorn( def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x, y, epsilon=1e-1) prob = linear_problem.LinearProblem(geom, a=a, b=b) - # TODO: fails with `jit=True`, investigate - solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, jit=False) + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) return solver(prob).reg_ot_cost reg_ot_and_grad = jax.jit(jax.grad(reg_ot)) @@ -275,8 +274,6 @@ def loss_fn(x: jnp.ndarray, lse_mode=lse_mode, min_iterations=min_iter, max_iterations=max_iter, - # TODO(cuturi): figure out why implicit diff breaks when `jit=True` - jit=False, implicit_diff=implicit_diff, ) out = solver(prob) diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index bf0ff23d7..926f48442 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -465,8 +465,7 @@ def test_sinkhorn_online_memory_jit(self, batch_size: int): y = jax.random.uniform(rngs[1], (m, 2)) geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) problem = linear_problem.LinearProblem(geom) - solver = sinkhorn.Sinkhorn(jit=False) - solver = jax.jit(solver) + solver = jax.jit(sinkhorn.Sinkhorn()) out = solver(problem) assert out.converged