Skip to content

Commit

Permalink
fix #375, converged flag when predefined number of iterations. (#386)
Browse files Browse the repository at this point in the history
* fix 375

* handle conversion of `threshold` in SinkhornOutput

* fix bug in converged when static_b is True

* take Michal comments into account + pydocs

* shape issue in the convergence[0]
  • Loading branch information
marcocuturi authored Jul 6, 2023
1 parent 5f2a73b commit db9f58a
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 30 deletions.
24 changes: 12 additions & 12 deletions src/ott/math/fixed_point_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,31 @@ def fixpoint_iter(
):
"""Implementation of a fixed point loop.
This fixed point loop iterator applies body_fn to a tuple
(iteration, constants, state, compute_error) to output a new state, using
This fixed point loop iterator applies ``body_fn`` to a tuple
``(iteration, constants, state, compute_error)`` to output a new state, using
context provided in iteration and constants.
body_fn is iterated (inner_iterations -1) times, and one last time with the
compute_error flag indicating that additional computational effort can be
spent on recalculating the latest error (errors are stored as the first
element of the state tuple).
``body_fn`` is iterated (inner_iterations -1) times, and one last time with
the ``compute_error`` flag to ``True``, indicating that additional
computational effort can be spent on recalculating the latest error
(``errors`` are stored as the first element of the state tuple).
upon termination of these inner_iterations, the loop is continued if iteration
is smaller than min_iterations, stopped if equal/larger than max_iterations,
and interrupted if cond_fn returns False.
upon termination of these ``inner_iterations``, the loop is continued if
iteration is smaller than ``min_iterations``, stopped if equal/larger than
``max_iterations``, and interrupted if ``cond_fn`` returns False.
Args:
cond_fn : termination condition function
body_fn : body loop instructions
min_iterations : lower bound on the total amount of fixed point iterations
max_iterations : upper bound on the total amount of fixed point iterations
inner_iterations : number of iterations body_fn will be executed
successively before calling cond_fn.
inner_iterations : number of iterations ``body_fn`` will be executed
successively before calling ``cond_fn``.
constants : constant (during loop) parameters passed on to body
state : state variable
Returns:
outputs state returned by body_fn upon termination.
outputs state returned by ``body_fn`` upon termination.
""" # noqa: D401
# If number of minimal iterations matches maximal number, force a scan instead
# of a while loop.
Expand Down
75 changes: 62 additions & 13 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,47 @@ def ent_reg_cost(


class SinkhornOutput(NamedTuple):
"""Implements the problems.Transport interface, for a Sinkhorn solution."""
"""Holds the output of a Sinkhorn solver applied to a problem.
Objects of this class contain both solutions and problem definition of a
regularized OT problem, along several methods that can be used to access its
content, to, for instance, materialize an OT matrix or apply it to a vector
(without having to materialize it when not needed).
Args:
f: dual variables vector of size ``ot.prob.shape[0]`` returned by Sinkhorn
g: dual variables vector of size ``ot.prob.shape[1]`` returned by Sinkhorn
errors: vector or errors, along iterations. This vector is of size
``max_iterations // inner_iterations`` where those were the parameters
passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver.
For each entry indexed at ``i``, ``errors[i]`` can be either a real
nonnegative value (meaning the algorithm recorded that error at the
``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
algorithm computed that iteration but did not compute its error, because,
for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
meaning that execution was terminated before that iteration, because the
criterion was found to be smaller than ``threshold``.
reg_ot_cost: the regularized optimal transport cost. By default this is
the linear contribution + KL term. See
:meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`,
:meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.primal_cost` and
:meth:`ott.solvers.linear.sinkhorn.SinkhornOutput.dual_cost` for other
objective values.
ot_prob: stores the definition of the OT problem, including geometry,
marginals, unbalanced regularizers, etc.
threshold: convergence threshold used to control the termination of the
algorithm.
converged: whether the output corresponds to a solution whose error is
below the convergence threshold.
"""

f: Optional[jnp.ndarray] = None
g: Optional[jnp.ndarray] = None
errors: Optional[jnp.ndarray] = None
reg_ot_cost: Optional[float] = None
ot_prob: Optional[linear_problem.LinearProblem] = None
threshold: Optional[jnp.ndarray] = None
converged: Optional[bool] = None

def set(self, **kwargs: Any) -> "SinkhornOutput":
"""Return a copy of self, with potential overwrites."""
Expand All @@ -306,15 +340,15 @@ def set_cost( # noqa: D102

@property
def dual_cost(self) -> jnp.ndarray:
"""Return transport cost in dual form of current solution."""
"""Return dual transport cost, without considering regularizer."""
a, b = self.ot_prob.a, self.ot_prob.b
dual_cost = jnp.sum(jnp.where(a > 0.0, a * self.f, 0))
dual_cost += jnp.sum(jnp.where(b > 0.0, b * self.g, 0))
return dual_cost

@property
def primal_cost(self) -> float:
"""Return transport cost of current solution at geometry."""
"""Return transport cost of current transport solution at geometry."""
return self.transport_cost_at_geom(other_geom=self.geom)

@property
Expand Down Expand Up @@ -404,14 +438,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def linear_output(self) -> bool: # noqa: D102
return True

@property
def converged(self) -> bool: # noqa: D102
if self.errors is None:
return False
return jnp.logical_and(
jnp.any(self.errors == -1), jnp.all(jnp.isfinite(self.errors))
)

# TODO(michalk8): this should be always present
@property
def n_iters(self) -> int: # noqa: D102
Expand Down Expand Up @@ -944,7 +970,10 @@ def one_iteration(

# re-computes error if compute_error is True, else set it to inf.
err = jax.lax.cond(
jnp.logical_and(compute_error, iteration >= self.min_iterations),
jnp.logical_or(
iteration == self.max_iterations - 1,
jnp.logical_and(compute_error, iteration >= self.min_iterations)
),
lambda state, prob: state.solution_error(
prob,
self.norm_error,
Expand Down Expand Up @@ -1038,7 +1067,27 @@ def output_from_state(
if self.recenter_potentials:
f, g = state.recenter(f, g, ot_prob=ot_prob)

return SinkhornOutput(f=f, g=g, errors=state.errors[:, 0])
# By convention, the algorithm is said to have converged if the algorithm
# has not nan'ed during iterations (notice some errors might be infinite,
# this convention is used when the error is not recomputed), and if the
# last recorded error is lower than the threshold. Note that this will be
# the case if either the algorithm terminated earlier (in which case the
# last state.errors[-1] = -1 by convention) or if the algorithm carried out
# the maximal number of iterations and its last recorded error (at -1
# position) is lower than the threshold.

converged = jnp.logical_and(
jnp.logical_not(jnp.any(jnp.isnan(state.errors))),
state.errors[-1] < self.threshold
)[0]

return SinkhornOutput(
f=f,
g=g,
errors=state.errors[:, 0],
threshold=jnp.array(self.threshold),
converged=converged
)

@property
def norm_error(self) -> Tuple[int, ...]:
Expand Down
6 changes: 5 additions & 1 deletion src/ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def _sinkhorn_divergence(
out_xy = sinkhorn.solve(geometry_xy, a, b, **kwargs)
out_xx = sinkhorn.solve(geometry_xx, a, a, **kwargs_symmetric)
if geometry_yy is None:
out_yy = sinkhorn.SinkhornOutput(errors=jnp.array([]), reg_ot_cost=0.0)
# Create dummy output, corresponds to scenario where static_b is True.
# This choice ensures that `converged`` of this dummy output is True.
out_yy = sinkhorn.SinkhornOutput(
errors=jnp.array([-jnp.inf]), reg_ot_cost=0.0, threshold=0.0
)
else:
out_yy = sinkhorn.solve(geometry_yy, b, b, **kwargs_symmetric)

Expand Down
11 changes: 7 additions & 4 deletions tests/solvers/linear/sinkhorn_misc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,11 @@ def assert_output_close(
"""Assert SinkhornOutputs are close."""
x = tuple(a for a in x if (a is not None and isinstance(a, jnp.ndarray)))
y = tuple(a for a in y if (a is not None and isinstance(a, jnp.ndarray)))
return chex.assert_tree_all_close(x, y, atol=1e-6, rtol=0)
return chex.assert_trees_all_close(x, y, atol=1e-6, rtol=0)

geom = self.geometry
jitted_result = jax.jit(sinkhorn.solve)(geom, a=self.a, b=self.b)
non_jitted_result = sinkhorn.solve(geom, a=self.a, b=self.b)

assert_output_close(non_jitted_result, jitted_result)

@pytest.mark.parametrize("implicit", [False, True])
Expand All @@ -382,5 +381,9 @@ def val_grad(a: jnp.ndarray, x: jnp.ndarray) -> float:
jitted_loss, jitted_grad = jax.jit(val_grad)(self.a, self.x)
non_jitted_loss, non_jitted_grad = val_grad(self.a, self.x)

chex.assert_tree_all_close(jitted_loss, non_jitted_loss, atol=1e-6, rtol=0.)
chex.assert_tree_all_close(jitted_grad, non_jitted_grad, atol=1e-6, rtol=0.)
chex.assert_trees_all_close(
jitted_loss, non_jitted_loss, atol=1e-6, rtol=0.
)
chex.assert_trees_all_close(
jitted_grad, non_jitted_grad, atol=1e-6, rtol=0.
)
19 changes: 19 additions & 0 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,25 @@ def test_euclidean_point_cloud_min_iter(self):
assert errors[2] == jnp.inf
assert errors[3] > 0

@pytest.mark.fast()
def test_euclidean_point_cloud_scan_loop(self):
"""Testing the scan loop behavior."""
threshold = 1e-3
geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
out = sinkhorn.solve(
geom,
a=self.a,
b=self.b,
threshold=threshold,
min_iterations=50,
max_iterations=50
)
# Test converged flag is True despite running in scan mode
assert out.converged
# Test last error recomputed at the final iteration, and below threshold.
assert out.errors[-1] > 0
assert out.errors[-1] < threshold

def test_geom_vs_point_cloud(self):
"""Two point clouds vs. simple cost_matrix execution of Sinkhorn."""
geom_1 = pointcloud.PointCloud(self.x, self.y)
Expand Down

0 comments on commit db9f58a

Please sign in to comment.