From db9f58a13755901e877ececa6d2388dd9d62d6b1 Mon Sep 17 00:00:00 2001
From: Marco Cuturi
Date: Fri, 7 Jul 2023 00:22:25 +0200
Subject: [PATCH] fix #375, `converged` flag when predefined number of
 iterations. (#386)

* fix 375

* handle conversion of `threshold` in SinkhornOutput

* fix bug in converged when static_b is True

* take Michal comments into account + pydocs

* shape issue in the convergence[0]
---
 src/ott/math/fixed_point_loop.py           | 24 +++----
 src/ott/solvers/linear/sinkhorn.py         | 75 ++++++++++++++++++----
 src/ott/tools/sinkhorn_divergence.py       |  6 +-
 tests/solvers/linear/sinkhorn_misc_test.py | 11 ++--
 tests/solvers/linear/sinkhorn_test.py      | 19 ++++++
 5 files changed, 105 insertions(+), 30 deletions(-)

diff --git a/src/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py
index f27ec835d..9034eba62 100644
--- a/src/ott/math/fixed_point_loop.py
+++ b/src/ott/math/fixed_point_loop.py
@@ -27,31 +27,31 @@ def fixpoint_iter(
 ):
   """Implementation of a fixed point loop.
 
-  This fixed point loop iterator applies body_fn to a tuple
-  (iteration, constants, state, compute_error) to output a new state, using
+  This fixed point loop iterator applies ``body_fn`` to a tuple
+  ``(iteration, constants, state, compute_error)`` to output a new state, using
   context provided in iteration and constants.
 
-  body_fn is iterated (inner_iterations -1) times, and one last time with the
-  compute_error flag indicating that additional computational effort can be
-  spent on recalculating the latest error (errors are stored as the first
-  element of the state tuple).
+  ``body_fn`` is iterated ``inner_iterations - 1`` times, and one last time
+  with the ``compute_error`` flag set to ``True``, indicating that additional
+  computational effort can be spent on recalculating the latest error
+  (``errors`` are stored as the first element of the state tuple).
 
-  upon termination of these inner_iterations, the loop is continued if iteration
-  is smaller than min_iterations, stopped if equal/larger than max_iterations,
-  and interrupted if cond_fn returns False.
+  Upon termination of these ``inner_iterations``, the loop is continued if
+  iteration is smaller than ``min_iterations``, stopped if equal to or larger
+  than ``max_iterations``, and interrupted if ``cond_fn`` returns ``False``.
 
   Args:
     cond_fn : termination condition function
     body_fn : body loop instructions
     min_iterations : lower bound on the total amount of fixed point iterations
     max_iterations : upper bound on the total amount of fixed point iterations
-    inner_iterations : number of iterations body_fn will be executed
-      successively before calling cond_fn.
+    inner_iterations : number of iterations ``body_fn`` will be executed
+      successively before calling ``cond_fn``.
     constants : constant (during loop) parameters passed on to body
     state : state variable
 
   Returns:
-    outputs state returned by body_fn upon termination.
+    the state returned by ``body_fn`` upon termination.
   """  # noqa: D401
   # If number of minimal iterations matches maximal number, force a scan instead
   # of a while loop.
diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py
index d74d065d9..80a995bab 100644
--- a/src/ott/solvers/linear/sinkhorn.py
+++ b/src/ott/solvers/linear/sinkhorn.py
@@ -284,13 +284,47 @@ def ent_reg_cost(
 class SinkhornOutput(NamedTuple):
-  """Implements the problems.Transport interface, for a Sinkhorn solution."""
+  """Holds the output of a Sinkhorn solver applied to a problem.
+
+  Objects of this class contain both solutions and problem definition of a
+  regularized OT problem, along with several methods that can be used to
+  access its content, for instance to materialize an OT matrix or to apply it
+  to a vector (without having to materialize it when not needed).
+
+  Args:
+    f: dual variables vector of size ``ot.prob.shape[0]`` returned by Sinkhorn
+    g: dual variables vector of size ``ot.prob.shape[1]`` returned by Sinkhorn
+    errors: vector of errors, along iterations. This vector is of size
+      ``max_iterations // inner_iterations``, where those are the parameters
+      passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver.
+      For each entry indexed at ``i``, ``errors[i]`` can be either a real
+      nonnegative value (meaning the algorithm recorded that error at the
+      ``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
+      algorithm computed that iteration but did not compute its error, because,
+      for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
+      meaning that execution was terminated before that iteration, because the
+      convergence criterion was found to be smaller than ``threshold``.
+    reg_ot_cost: the regularized optimal transport cost. By default this is
+      the linear contribution + KL term. See
+      :meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`,
+      :meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.primal_cost` and
+      :meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.dual_cost` for other
+      objective values.
+    ot_prob: stores the definition of the OT problem, including geometry,
+      marginals, unbalanced regularizers, etc.
+    threshold: convergence threshold used to control the termination of the
+      algorithm.
+    converged: whether the output corresponds to a solution whose error is
+      below the convergence threshold.
+  """
 
   f: Optional[jnp.ndarray] = None
   g: Optional[jnp.ndarray] = None
   errors: Optional[jnp.ndarray] = None
   reg_ot_cost: Optional[float] = None
   ot_prob: Optional[linear_problem.LinearProblem] = None
+  threshold: Optional[jnp.ndarray] = None
+  converged: Optional[bool] = None
 
   def set(self, **kwargs: Any) -> "SinkhornOutput":
     """Return a copy of self, with potential overwrites."""
@@ -306,7 +340,7 @@ def set_cost(  # noqa: D102
 
   @property
   def dual_cost(self) -> jnp.ndarray:
-    """Return transport cost in dual form of current solution."""
+    """Return dual transport cost, without considering the regularizer."""
     a, b = self.ot_prob.a, self.ot_prob.b
     dual_cost = jnp.sum(jnp.where(a > 0.0, a * self.f, 0))
     dual_cost += jnp.sum(jnp.where(b > 0.0, b * self.g, 0))
@@ -314,7 +348,7 @@ def dual_cost(self) -> jnp.ndarray:
 
   @property
   def primal_cost(self) -> float:
-    """Return transport cost of current solution at geometry."""
+    """Return transport cost of the current solution at the geometry."""
     return self.transport_cost_at_geom(other_geom=self.geom)
 
   @property
@@ -404,14 +438,6 @@ def b(self) -> jnp.ndarray:  # noqa: D102
   def linear_output(self) -> bool:  # noqa: D102
     return True
 
-  @property
-  def converged(self) -> bool:  # noqa: D102
-    if self.errors is None:
-      return False
-    return jnp.logical_and(
-        jnp.any(self.errors == -1), jnp.all(jnp.isfinite(self.errors))
-    )
-
   # TODO(michalk8): this should be always present
   @property
   def n_iters(self) -> int:  # noqa: D102
@@ -944,7 +970,10 @@ def one_iteration(
 
     # re-computes error if compute_error is True, else set it to inf.
     err = jax.lax.cond(
-        jnp.logical_and(compute_error, iteration >= self.min_iterations),
+        jnp.logical_or(
+            iteration == self.max_iterations - 1,
+            jnp.logical_and(compute_error, iteration >= self.min_iterations)
+        ),
         lambda state, prob: state.solution_error(
             prob,
             self.norm_error,
@@ -1038,7 +1067,27 @@ def output_from_state(
     if self.recenter_potentials:
       f, g = state.recenter(f, g, ot_prob=ot_prob)
 
-    return SinkhornOutput(f=f, g=g, errors=state.errors[:, 0])
+    # By convention, the algorithm is said to have converged if it has not
+    # NaN'ed during iterations (note that some errors might be infinite; this
+    # convention is used when the error is not recomputed), and if the last
+    # recorded error is lower than the threshold. This will be the case if
+    # either the algorithm terminated earlier (in which case the last
+    # state.errors[-1] = -1 by convention) or if the algorithm carried out
+    # the maximal number of iterations and its last recorded error (at -1
+    # position) is lower than the threshold.
+    converged = jnp.logical_and(
+        jnp.logical_not(jnp.any(jnp.isnan(state.errors))),
+        state.errors[-1] < self.threshold
+    )[0]
+
+    return SinkhornOutput(
+        f=f,
+        g=g,
+        errors=state.errors[:, 0],
+        threshold=jnp.array(self.threshold),
+        converged=converged
+    )
 
   @property
   def norm_error(self) -> Tuple[int, ...]:
diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py
index 11cfacc6e..d822fa0cf 100644
--- a/src/ott/tools/sinkhorn_divergence.py
+++ b/src/ott/tools/sinkhorn_divergence.py
@@ -166,7 +166,11 @@ def _sinkhorn_divergence(
   out_xy = sinkhorn.solve(geometry_xy, a, b, **kwargs)
   out_xx = sinkhorn.solve(geometry_xx, a, a, **kwargs_symmetric)
   if geometry_yy is None:
-    out_yy = sinkhorn.SinkhornOutput(errors=jnp.array([]), reg_ot_cost=0.0)
+    # Create a dummy output; corresponds to the scenario where static_b is
+    # True. This choice ensures that ``converged`` of this dummy output is True.
+    out_yy = sinkhorn.SinkhornOutput(
+        errors=jnp.array([-jnp.inf]), reg_ot_cost=0.0, threshold=0.0
+    )
   else:
     out_yy = sinkhorn.solve(geometry_yy, b, b, **kwargs_symmetric)
diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py
index 58dd320b8..860af106a 100644
--- a/tests/solvers/linear/sinkhorn_misc_test.py
+++ b/tests/solvers/linear/sinkhorn_misc_test.py
@@ -351,12 +351,11 @@ def assert_output_close(
       """Assert SinkhornOutputs are close."""
       x = tuple(a for a in x if (a is not None and isinstance(a, jnp.ndarray)))
       y = tuple(a for a in y if (a is not None and isinstance(a, jnp.ndarray)))
-      return chex.assert_tree_all_close(x, y, atol=1e-6, rtol=0)
+      return chex.assert_trees_all_close(x, y, atol=1e-6, rtol=0)
 
     geom = self.geometry
     jitted_result = jax.jit(sinkhorn.solve)(geom, a=self.a, b=self.b)
     non_jitted_result = sinkhorn.solve(geom, a=self.a, b=self.b)
-
     assert_output_close(non_jitted_result, jitted_result)
 
   @pytest.mark.parametrize("implicit", [False, True])
@@ -382,5 +381,9 @@ def val_grad(a: jnp.ndarray, x: jnp.ndarray) -> float:
 
     jitted_loss, jitted_grad = jax.jit(val_grad)(self.a, self.x)
     non_jitted_loss, non_jitted_grad = val_grad(self.a, self.x)
-    chex.assert_tree_all_close(jitted_loss, non_jitted_loss, atol=1e-6, rtol=0.)
-    chex.assert_tree_all_close(jitted_grad, non_jitted_grad, atol=1e-6, rtol=0.)
+    chex.assert_trees_all_close(
+        jitted_loss, non_jitted_loss, atol=1e-6, rtol=0.
+    )
+    chex.assert_trees_all_close(
+        jitted_grad, non_jitted_grad, atol=1e-6, rtol=0.
+ ) diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index a91ac4108..51d141065 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -187,6 +187,25 @@ def test_euclidean_point_cloud_min_iter(self): assert errors[2] == jnp.inf assert errors[3] > 0 + @pytest.mark.fast() + def test_euclidean_point_cloud_scan_loop(self): + """Testing the scan loop behavior.""" + threshold = 1e-3 + geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1) + out = sinkhorn.solve( + geom, + a=self.a, + b=self.b, + threshold=threshold, + min_iterations=50, + max_iterations=50 + ) + # Test converged flag is True despite running in scan mode + assert out.converged + # Test last error recomputed at the final iteration, and below threshold. + assert out.errors[-1] > 0 + assert out.errors[-1] < threshold + def test_geom_vs_point_cloud(self): """Two point clouds vs. simple cost_matrix execution of Sinkhorn.""" geom_1 = pointcloud.PointCloud(self.x, self.y)