Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use more iterations for LR-GW #494

Merged
merged 2 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class LRGromovWasserstein(sinkhorn.Sinkhorn):
lse_mode: Whether to run computations in LSE or kernel mode.
inner_iterations: Number of inner iterations used by the algorithm before
re-evaluating progress.
min_iterations: The minimum number of low-rank Sinkhorn iterations carried
out before the error is computed and monitored.
max_iterations: The maximum number of low-rank Sinkhorn iterations.
use_danskin: Use Danskin's theorem to evaluate the gradient of the objective
w.r.t. input parameters. Only `True` is handled at this moment.
implicit_diff: Whether to use implicit differentiation. Currently, only
Expand Down Expand Up @@ -305,9 +308,11 @@ def __init__(
"generalized-k-means"],
initializers_lr.LRInitializer] = "random",
lse_mode: bool = True,
inner_iterations: int = 10,
use_danskin: bool = True,
implicit_diff: bool = False,
inner_iterations: int = 2_000,
min_iterations: int = 10_000,
max_iterations: int = 100_000,
kwargs_dys: Optional[Mapping[str, Any]] = None,
kwargs_init: Optional[Mapping[str, Any]] = None,
progress_fn: Optional[ProgressCallbackFn_t] = None,
Expand All @@ -317,6 +322,8 @@ def __init__(
super().__init__(
lse_mode=lse_mode,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
use_danskin=use_danskin,
implicit_diff=implicit_diff,
**kwargs
Expand Down
12 changes: 11 additions & 1 deletion tests/initializers/quadratic/gw_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def test_explicit_initializer_lr(self):
rank = 10
q_init = initializers_lr.Rank2Initializer(rank)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank, initializer=q_init
rank=rank,
initializer=q_init,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)

assert solver.create_initializer("not used") is q_init
Expand All @@ -67,11 +71,17 @@ def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float):
rank=rank,
initializer="random",
epsilon=eps,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
solver_kmeans = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer="k-means",
epsilon=eps,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)

out_random = solver_random(problem)
Expand Down
12 changes: 10 additions & 2 deletions tests/solvers/quadratic/fgw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def test_fgw_lr_memory(self, rng: jax.Array, jit: bool):
geom_xy = pointcloud.PointCloud(xx, yy)
prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, geom_xy)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=2, min_iterations=0, inner_iterations=10, max_iterations=2000
)
if jit:
solver = jax.jit(solver)

Expand Down Expand Up @@ -262,7 +264,13 @@ def test_fgw_lr_generic_cost_matrix(
lr_prob = prob.to_low_rank()
assert lr_prob.is_low_rank

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=10.0)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=5,
epsilon=10.0,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
out = solver(prob)

assert solver.rank == 5
Expand Down
42 changes: 35 additions & 7 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def test_gw_pointcloud(self, balanced: bool, rank: int):
)
if rank > 0:
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank, epsilon=0.0, max_iterations=10
rank=rank,
epsilon=0.0,
max_iterations=10,
)
else:
solver = gromov_wasserstein.GromovWasserstein(
Expand Down Expand Up @@ -324,7 +326,13 @@ def test_gw_lr(self, rng: jax.Array):
geom_yy = pointcloud.PointCloud(y)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, a=a, b=b)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=0.2)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=5,
epsilon=0.2,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
ot_gwlr = solver(prob)
solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2)
ot_gw = solver(prob)
Expand Down Expand Up @@ -352,9 +360,17 @@ def test_gw_lr_matches_fused(self, rng: jax.Array):
geom_xx, geom_yy, geom_xy=geom_xy, fused_penalty=1.3, a=a, b=b
)

solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=6, min_iterations=0, inner_iterations=10, max_iterations=2000
)
ot_gwlr = solver(prob)
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6, epsilon=1e-1)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=6,
epsilon=1e-1,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
ot_gwlreps = solver(prob)
solver = gromov_wasserstein.GromovWasserstein(epsilon=5e-2)
ot_gw = solver(prob)
Expand All @@ -372,7 +388,13 @@ def test_gw_lr_apply(self, axis: int):
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, a=self.a, b=self.b
)
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2, epsilon=1e-1)
solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=2,
epsilon=1e-1,
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
out = solver(prob)

arr, matrix = (self.x, out.matrix) if axis == 0 else (self.y, out.matrix.T)
Expand Down Expand Up @@ -430,7 +452,12 @@ def test_gwlr_unbalanced(
)
solver = jax.jit(
gromov_wasserstein_lr.LRGromovWasserstein(
rank=4, epsilon=eps, kwargs_dys={"translation_invariant": ti}
rank=4,
epsilon=eps,
kwargs_dys={"translation_invariant": ti},
min_iterations=0,
inner_iterations=10,
max_iterations=2000
)
)

Expand Down Expand Up @@ -468,8 +495,9 @@ def test_gwlr_unbalanced_matches_balanced(
rank=rank,
epsilon=eps,
initializer="random",
inner_iterations=50,
min_iterations=50,
max_iterations=50
max_iterations=50,
)
)

Expand Down
Loading