Skip to content

Commit

Permalink
Remove jit from tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Mar 15, 2023
1 parent b84f084 commit 27bf935
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tests/geometry/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool):

def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
solver = sinkhorn.Sinkhorn(lse_mode=False, jit=False)
solver = sinkhorn.Sinkhorn(lse_mode=False)
problem = linear_problem.LinearProblem(geom)
return solver(problem)

Expand Down
5 changes: 1 addition & 4 deletions tests/solvers/linear/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ def test_autograd_sinkhorn(
def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float:
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
# TODO: fails with `jit=True`, investigate
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, jit=False)
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode)
return solver(prob).reg_ot_cost

reg_ot_and_grad = jax.jit(jax.grad(reg_ot))
Expand Down Expand Up @@ -275,8 +274,6 @@ def loss_fn(x: jnp.ndarray,
lse_mode=lse_mode,
min_iterations=min_iter,
max_iterations=max_iter,
# TODO(cuturi): figure out why implicit diff breaks when `jit=True`
jit=False,
implicit_diff=implicit_diff,
)
out = solver(prob)
Expand Down
3 changes: 1 addition & 2 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,7 @@ def test_sinkhorn_online_memory_jit(self, batch_size: int):
y = jax.random.uniform(rngs[1], (m, 2))
geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1)
problem = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn(jit=False)
solver = jax.jit(solver)
solver = jax.jit(sinkhorn.Sinkhorn())

out = solver(problem)
assert out.converged
Expand Down

0 comments on commit 27bf935

Please sign in to comment.