Skip to content

Commit

Permalink
Use more iterations for LR-GW (#494)
Browse files (browse the repository at this point in the history)
* Use more iterations for LR-GW

* Fix tests
  • Loading branch information
michalk8 authored Feb 24, 2024
1 parent 7ef2235 commit 81fe3be
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
9 changes: 8 additions & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class LRGromovWasserstein(sinkhorn.Sinkhorn):
lse_mode: Whether to run computations in LSE or kernel mode.
inner_iterations: Number of inner iterations used by the algorithm before
re-evaluating progress.
min_iterations: The minimum number of low-rank Sinkhorn iterations carried
out before the error is computed and monitored.
max_iterations: The maximum number of low-rank Sinkhorn iterations.
use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t.
input parameters. Only `True` handled at this moment.
implicit_diff: Whether to use implicit differentiation. Currently, only
Expand Down Expand Up @@ -305,9 +308,11 @@ def __init__(
"generalized-k-means"],
initializers_lr.LRInitializer] = "random",
lse_mode: bool = True,
inner_iterations: int = 10,
use_danskin: bool = True,
implicit_diff: bool = False,
inner_iterations: int = 2_000,
min_iterations: int = 10_000,
max_iterations: int = 100_000,
kwargs_dys: Optional[Mapping[str, Any]] = None,
kwargs_init: Optional[Mapping[str, Any]] = None,
progress_fn: Optional[ProgressCallbackFn_t] = None,
Expand All @@ -317,6 +322,8 @@ def __init__(
super().__init__(
lse_mode=lse_mode,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
**kwargs
Expand Down
12 changes: 11 additions & 1 deletion tests/initializers/quadratic/gw_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def test_explicit_initializer_lr(self):
rank = 10
q_init = initializers_lr.Rank2Initializer(rank)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank, initializer=q_init
rank=rank,
initializer=q_init,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)

assert solver.create_initializer("not used") is q_init
Expand All @@ -67,11 +71,17 @@ def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float):
rank=rank,
initializer="random",
epsilon=eps,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
solver_kmeans = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer="k-means",
epsilon=eps,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)

out_random = solver_random(problem)
Expand Down
12 changes: 10 additions & 2 deletions tests/solvers/quadratic/fgw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def test_fgw_lr_memory(self, rng: jax.Array, jit: bool):
geom_xy = pointcloud.PointCloud(xx, yy)
prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, geom_xy)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=2, min_iterations=0, inner_iterations=10, max_iterations=2000
)
if jit:
solver = jax.jit(solver)

Expand Down Expand Up @@ -262,7 +264,13 @@ def test_fgw_lr_generic_cost_matrix(
lr_prob = prob.to_low_rank()
assert lr_prob.is_low_rank

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=10.0)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=5,
epsilon=10.0,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
out = solver(prob)

assert solver.rank == 5
Expand Down
42 changes: 35 additions & 7 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def test_gw_pointcloud(self, balanced: bool, rank: int):
)
if rank > 0:
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank, epsilon=0.0, max_iterations=10
rank=rank,
epsilon=0.0,
max_iterations=10,
)
else:
solver = gromov_wasserstein.GromovWasserstein(
Expand Down Expand Up @@ -324,7 +326,13 @@ def test_gw_lr(self, rng: jax.Array):
geom_yy = pointcloud.PointCloud(y)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, a=a, b=b)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=0.2)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=5,
epsilon=0.2,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
ot_gwlr = solver(prob)
solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2)
ot_gw = solver(prob)
Expand Down Expand Up @@ -352,9 +360,17 @@ def test_gw_lr_matches_fused(self, rng: jax.Array):
geom_xx, geom_yy, geom_xy=geom_xy, fused_penalty=1.3, a=a, b=b
)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=6, min_iterations=0, inner_iterations=10, max_iterations=2000
)
ot_gwlr = solver(prob)
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6, epsilon=1e-1)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=6,
epsilon=1e-1,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
ot_gwlreps = solver(prob)
solver = gromov_wasserstein.GromovWasserstein(epsilon=5e-2)
ot_gw = solver(prob)
Expand All @@ -372,7 +388,13 @@ def test_gw_lr_apply(self, axis: int):
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, a=self.a, b=self.b
)
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2, epsilon=1e-1)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=2,
epsilon=1e-1,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
out = solver(prob)

arr, matrix = (self.x, out.matrix) if axis == 0 else (self.y, out.matrix.T)
Expand Down Expand Up @@ -430,7 +452,12 @@ def test_gwlr_unbalanced(
)
solver = jax.jit(
gromov_wasserstein_lr.LRGromovWasserstein(
rank=4, epsilon=eps, kwargs_dys={"translation_invariant": ti}
rank=4,
epsilon=eps,
kwargs_dys={"translation_invariant": ti},
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
)

Expand Down Expand Up @@ -468,8 +495,9 @@ def test_gwlr_unbalanced_matches_balanced(
rank=rank,
epsilon=eps,
initializer="random",
inner_iterations=50,
min_iterations=50,
max_iterations=50
max_iterations=50,
)
)

Expand Down

0 comments on commit 81fe3be

Please sign in to comment.