diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index cb12911bf..e70962149 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -278,6 +278,9 @@ class LRGromovWasserstein(sinkhorn.Sinkhorn): lse_mode: Whether to run computations in LSE or kernel mode. inner_iterations: Number of inner iterations used by the algorithm before re-evaluating progress. + min_iterations: The minimum number of low-rank Sinkhorn iterations carried + out before the error is computed and monitored. + max_iterations: The maximum number of low-rank Sinkhorn iterations. use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t. input parameters. Only `True` handled at this moment. implicit_diff: Whether to use implicit differentiation. Currently, only @@ -305,9 +308,11 @@ def __init__( "generalized-k-means"], initializers_lr.LRInitializer] = "random", lse_mode: bool = True, - inner_iterations: int = 10, use_danskin: bool = True, implicit_diff: bool = False, + inner_iterations: int = 2_000, + min_iterations: int = 10_000, + max_iterations: int = 100_000, kwargs_dys: Optional[Mapping[str, Any]] = None, kwargs_init: Optional[Mapping[str, Any]] = None, progress_fn: Optional[ProgressCallbackFn_t] = None, @@ -317,6 +322,8 @@ def __init__( super().__init__( lse_mode=lse_mode, inner_iterations=inner_iterations, + min_iterations=min_iterations, + max_iterations=max_iterations, use_danskin=use_danskin, implicit_diff=implicit_diff, **kwargs diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 6b8985a0c..09346d2ac 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -40,7 +40,11 @@ def test_explicit_initializer_lr(self): rank = 10 q_init = initializers_lr.Rank2Initializer(rank) solver = gromov_wasserstein_lr.LRGromovWasserstein( - rank=rank, initializer=q_init + rank=rank, + initializer=q_init, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 ) assert solver.create_initializer("not used") is q_init @@ -67,11 +71,17 @@ def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float): rank=rank, initializer="random", epsilon=eps, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 ) solver_kmeans = gromov_wasserstein_lr.LRGromovWasserstein( rank=rank, initializer="k-means", epsilon=eps, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 ) out_random = solver_random(problem) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 3283ae845..58dbb630d 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -227,7 +227,9 @@ def test_fgw_lr_memory(self, rng: jax.Array, jit: bool): geom_xy = pointcloud.PointCloud(xx, yy) prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, geom_xy) - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=2, min_iterations=0, inner_iterations=10, max_iterations=2000 + ) if jit: solver = jax.jit(solver) @@ -262,7 +264,13 @@ def test_fgw_lr_generic_cost_matrix( lr_prob = prob.to_low_rank() assert lr_prob.is_low_rank - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=10.0) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=5, + epsilon=10.0, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 + ) out = solver(prob) assert solver.rank == 5 diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index 2ab7af4a2..9ac86f5bd 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -210,7 +210,9 @@ def test_gw_pointcloud(self, balanced: bool, rank: int): ) if rank > 0: solver = gromov_wasserstein_lr.LRGromovWasserstein( - rank=rank, epsilon=0.0, max_iterations=10 + rank=rank, + epsilon=0.0, + max_iterations=10, ) else: solver = gromov_wasserstein.GromovWasserstein( @@ -324,7 +326,13 @@ def test_gw_lr(self, rng: jax.Array): geom_yy = pointcloud.PointCloud(y) prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy, a=a, b=b) - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=5, epsilon=0.2) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=5, + epsilon=0.2, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 + ) ot_gwlr = solver(prob) solver = gromov_wasserstein.GromovWasserstein(epsilon=0.2) ot_gw = solver(prob) @@ -352,9 +360,17 @@ def test_gw_lr_matches_fused(self, rng: jax.Array): geom_xx, geom_yy, geom_xy=geom_xy, fused_penalty=1.3, a=a, b=b ) - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=6, min_iterations=0, inner_iterations=10, max_iterations=2000 + ) ot_gwlr = solver(prob) - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=6, epsilon=1e-1) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=6, + epsilon=1e-1, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 + ) ot_gwlreps = solver(prob) solver = gromov_wasserstein.GromovWasserstein(epsilon=5e-2) ot_gw = solver(prob) @@ -372,7 +388,13 @@ def test_gw_lr_apply(self, axis: int): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, a=self.a, b=self.b ) - solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=2, epsilon=1e-1) + solver = gromov_wasserstein_lr.LRGromovWasserstein( + rank=2, + epsilon=1e-1, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 + ) out = solver(prob) arr, matrix = (self.x, out.matrix) if axis == 0 else (self.y, out.matrix.T) @@ -430,7 +452,12 @@ def test_gwlr_unbalanced( ) solver = jax.jit( gromov_wasserstein_lr.LRGromovWasserstein( - rank=4, epsilon=eps, kwargs_dys={"translation_invariant": ti} + rank=4, + epsilon=eps, + kwargs_dys={"translation_invariant": ti}, + min_iterations=0, + inner_iterations=10, + max_iterations=2000 ) ) @@ -468,8 +495,9 @@ def test_gwlr_unbalanced_matches_balanced( rank=rank, epsilon=eps, initializer="random", + inner_iterations=50, min_iterations=50, - max_iterations=50 + max_iterations=50, ) )