Skip to content

Commit

Permalink
Re-add jit argument (#221)
Browse files Browse the repository at this point in the history
* Re-add `jit` argument

* [ci skip] Add `environment.yml` for `binder`

* Fix missing `static_argnums`

* Adjust test to have `jit=False`

* Fix tests

* Fix `typing_extensions` in tests

* Fix linter
  • Loading branch information
michalk8 authored Feb 3, 2023
1 parent 1cb1f9b commit 4cebd07
Show file tree
Hide file tree
Showing 16 changed files with 100 additions and 100 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v2.0.0
rev: v2.0.1
hooks:
- id: autoflake
args:
Expand All @@ -71,4 +71,4 @@ repos:
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py37-plus, --keep-runtime-typing]
args: [--py38-plus, --keep-runtime-typing]
1 change: 1 addition & 0 deletions docs/tutorials/notebooks/gromov_wasserstein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16152,6 +16152,7 @@
") # IDs are ordered from center to outer part\n",
"plt.colorbar()\n",
"\n",
"\n",
"# Initialization function\n",
"def init():\n",
" im.set_data(np.zeros(transport.shape))\n",
Expand Down
6 changes: 3 additions & 3 deletions src/ott/problems/linear/barycenter_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
return self._add_slice_for_debiased(y, b)
return y, b

@staticmethod
def _add_slice_for_debiased(
self, y: jnp.ndarray, b: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
y, b = self._y, self._b
y: jnp.ndarray, b: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
_, n, ndim = y.shape # (num_measures, max_measure_size, ndim)
# yapf: disable
y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0)
Expand Down
1 change: 0 additions & 1 deletion src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def __init__(
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
):
assert fused_penalty > 0, fused_penalty
self._geom_xx = geom_xx._set_scale_cost(scale_cost)
self._geom_yy = geom_yy._set_scale_cost(scale_cost)
self._geom_xy = (
Expand Down
3 changes: 2 additions & 1 deletion src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def __call__(
rng: int = 0
) -> BarycenterState:
# TODO(michalk8): no reason for iterations to be outside this class
return iterations(self, bar_size, bar_prob, x_init, rng)
run_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations
return run_fn(self, bar_size, bar_prob, x_init, rng)

def init_state(
self,
Expand Down
6 changes: 5 additions & 1 deletion src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ class Sinkhorn:
gradients have been stopped. This is useful when carrying out first order
differentiation, and is only valid (as with ``implicit_differentiation``)
when the algorithm has converged with a low tolerance.
jit: Whether to jit the iteration loop.
initializer: how to compute the initial potentials/scalings.
kwargs_init: keyword arguments when creating the initializer.
"""
Expand All @@ -697,6 +698,7 @@ def __init__(
parallel_dual_updates: bool = False,
recenter_potentials: bool = False,
use_danskin: Optional[bool] = None,
jit: bool = True,
implicit_diff: Optional[implicit_lib.ImplicitDiff
] = implicit_lib.ImplicitDiff(), # noqa: E124
initializer: Union[Literal["default", "gaussian", "sorting"],
Expand All @@ -711,6 +713,7 @@ def __init__(
self._norm_error = norm_error
self.anderson = anderson
self.implicit_diff = implicit_diff
self.jit = jit

if momentum is not None:
self.momentum = acceleration.Momentum(
Expand Down Expand Up @@ -767,7 +770,8 @@ def __call__(
init_dual_a, init_dual_b = initializer(
ot_prob, *init, lse_mode=self.lse_mode
)
return run(ot_prob, self, (init_dual_a, init_dual_b))
run_fn = jax.jit(run) if self.jit else run
return run_fn(ot_prob, self, (init_dual_a, init_dual_b))

def lse_step(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
Expand Down
3 changes: 2 additions & 1 deletion src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def __call__(
assert ot_prob.is_balanced, "Unbalanced case is not implemented."
initializer = self.create_initializer(ot_prob)
init = initializer(ot_prob, *init, key=key, **kwargs)
return run(ot_prob, self, init)
run_fn = jax.jit(run) if self.jit else run
return run_fn(ot_prob, self, init)

def _lr_costs(
self,
Expand Down
3 changes: 2 additions & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def __call__(
initializer = self.create_initializer(prob)
init = initializer(prob, epsilon=self.epsilon, key=key1, **kwargs)

out = iterations(self, prob, init, key2)
run_fn = jax.jit(iterations) if self.jit else iterations
out = run_fn(self, prob, init, key2)
# TODO(lpapaxanthos): remove stop_gradient when using backprop
if self.is_low_rank:
linearization = prob.update_lr_linearization(
Expand Down
11 changes: 8 additions & 3 deletions src/ott/solvers/quadratic/gw_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
min_iterations: Minimum number of iterations.
max_iterations: Maximum number of outermost iterations.
threshold: Convergence threshold.
jit: Whether to jit the iteration loop.
store_inner_errors: Whether to store the errors of the GW solver, as well
as its linear solver, at each iteration for each measure.
quad_solver: The GW solver.
Expand All @@ -66,6 +67,7 @@ def __init__(
min_iterations: int = 5,
max_iterations: int = 50,
threshold: float = 1e-3,
jit: bool = True,
store_inner_errors: bool = False,
quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None,
# TODO(michalk8): maintain the API compatibility with `was_solver`
Expand All @@ -79,14 +81,16 @@ def __init__(
min_iterations=min_iterations,
max_iterations=max_iterations,
threshold=threshold,
store_inner_errors=store_inner_errors
store_inner_errors=store_inner_errors,
jit=jit,
)
self._quad_solver = quad_solver
if quad_solver is None:
kwargs["epsilon"] = epsilon
# TODO(michalk8): store only GW errors?
kwargs["store_inner_errors"] = store_inner_errors
self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs)
else:
self._quad_solver = quad_solver

def __call__(
self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int,
Expand All @@ -103,7 +107,8 @@ def __call__(
The solution.
"""
state = self.init_state(problem, bar_size, **kwargs)
state = iterations(solver=self, problem=problem, init_state=state)
run_fn = jax.jit(iterations) if self.jit else iterations
state = run_fn(self, problem, state)
return self.output_from_state(state)

def init_state(
Expand Down
3 changes: 3 additions & 0 deletions src/ott/solvers/was_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
min_iterations: int = 5,
max_iterations: int = 50,
threshold: float = 1e-3,
jit: bool = True,
store_inner_errors: bool = False,
**kwargs: Any,
):
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
self.min_iterations = min_iterations
self.max_iterations = max_iterations
self.threshold = threshold
self.jit = jit
self.store_inner_errors = store_inner_errors
self._kwargs = kwargs

Expand All @@ -86,6 +88,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
"min_iterations": self.min_iterations,
"max_iterations": self.max_iterations,
"rank": self.rank,
"jit": self.jit,
"store_inner_errors": self.store_inner_errors,
**self._kwargs
})
Expand Down
3 changes: 1 addition & 2 deletions tests/geometry/graph_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import time
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Literal, Optional, Tuple, Union

import networkx as nx
import pytest
from networkx.algorithms import shortest_paths
from networkx.generators import balanced_tree, random_graphs
from typing_extensions import Literal

import jax
import jax.experimental.sparse as jesp
Expand Down
7 changes: 5 additions & 2 deletions tests/solvers/linear/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def test_autograd_sinkhorn(
def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float:
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode)
# TODO: fails with `jit=True`, investigate
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, jit=False)
return solver(prob).reg_ot_cost

reg_ot_and_grad = jax.jit(jax.grad(reg_ot))
Expand Down Expand Up @@ -277,6 +278,8 @@ def loss_fn(x: jnp.ndarray,
lse_mode=lse_mode,
min_iterations=min_iter,
max_iterations=max_iter,
# TODO(cuturi): figure out why implicit diff breaks when `jit=True`
jit=False,
implicit_diff=implicit_diff,
)
out = solver(prob)
Expand All @@ -287,7 +290,7 @@ def loss_fn(x: jnp.ndarray,
eps = 1e-5 # perturbation magnitude

# first calculation of gradient
loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
(loss_value, out), grad_loss = loss_and_grad(x, y)
custom_grad = jnp.sum(delta * grad_loss)

Expand Down
3 changes: 1 addition & 2 deletions tests/solvers/nn/neuraldual_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for implementation of ICNN-based Kantorovich dual by Makkuva+(2020)."""
from typing import Iterator, Sequence, Tuple
from typing import Iterator, Literal, Sequence, Tuple

import pytest
from typing_extensions import Literal

import jax
import jax.numpy as jnp
Expand Down
79 changes: 0 additions & 79 deletions tests/solvers/quadratic/fgw_barycenter_test.py

This file was deleted.

64 changes: 64 additions & 0 deletions tests/solvers/quadratic/gw_barycenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,67 @@ def test_gw_barycenter(
assert out_pc.cost.shape == (bar_size, bar_size)
np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol)
np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol)

@pytest.mark.fast(
"jit,fused_penalty,scale_cost", [(False, 1.5, "mean"),
(True, 3.1, "max_cost")],
only_fast=0
)
def test_fgw_barycenter(
self,
rng: jnp.ndarray,
jit: bool,
fused_penalty: float,
scale_cost: str,
):

def barycenter(
y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...]
) -> gwb_solver.GWBarycenterState:
prob = gwb.GWBarycenterProblem(
y=y,
y_fused=y_fused,
num_per_segment=num_per_segment,
fused_penalty=fused_penalty,
scale_cost=scale_cost,
)
assert prob.is_fused
assert prob.fused_penalty == fused_penalty
assert not prob._y_as_costs
assert prob.max_measure_size == max(num_per_segment)
assert prob.num_measures == len(num_per_segment)
assert prob.ndim == self.ndim
assert prob.ndim_fused == self.ndim_f

solver = gwb_solver.GromovWassersteinBarycenter(
store_inner_errors=True, epsilon=epsilon
)

x_init = jax.random.normal(rng, (bar_size, self.ndim_f))
cost_init = pointcloud.PointCloud(x_init).cost_matrix

return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init))

bar_size, epsilon, = 10, 1e-1
num_per_segment = (7, 12)

key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1)
y = jnp.concatenate([
self.random_pc(n, d=self.ndim, rng=rng).x
for n, rng in zip(num_per_segment, rngs)
])
rngs = jax.random.split(key1, len(num_per_segment))
y_fused = jnp.concatenate([
self.random_pc(n, d=self.ndim_f, rng=rng).x
for n, rng in zip(num_per_segment, rngs)
])

fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter
out = fn(y, y_fused, num_per_segment)

assert out.cost.shape == (bar_size, bar_size)
assert out.x.shape == (bar_size, self.ndim_f)
np.testing.assert_array_equal(jnp.isfinite(out.cost), True)
np.testing.assert_array_equal(jnp.isfinite(out.x), True)
np.testing.assert_array_equal(jnp.isfinite(out.costs), True)
np.testing.assert_array_equal(jnp.isfinite(out.errors), True)
3 changes: 1 addition & 2 deletions tests/tools/k_means_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import sys
from typing import Any, Optional, Tuple, Union
from typing import Any, Literal, Optional, Tuple, Union

import pytest
from typing_extensions import Literal

import jax
import jax.numpy as jnp
Expand Down

0 comments on commit 4cebd07

Please sign in to comment.