Skip to content

Commit

Permalink
Remove epsilon scheduler in GW (#508)
Browse files (browse the repository at this point in the history)
  • Loading branch information
michalk8 authored Mar 28, 2024
1 parent 37250c4 commit 986f049
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 22 deletions.
4 changes: 1 addition & 3 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def _create_geometry(
init_transport = jnp.outer(quad_prob.a, quad_prob.b)
marginal_1, marginal_2 = init_transport.sum(1), init_transport.sum(0)

epsilon = quadratic_problem.update_epsilon_unbalanced(
epsilon=epsilon, transport_mass=marginal_1.sum()
)
epsilon *= marginal_1.sum()
unbalanced_correction = quad_prob.cost_unbalanced_correction(
init_transport, marginal_1, marginal_2, epsilon=epsilon
)
Expand Down
27 changes: 8 additions & 19 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.scipy as jsp

from ott import utils
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
from ott.geometry import geometry, low_rank, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_costs
from ott.types import Transport
Expand Down Expand Up @@ -171,7 +171,7 @@ def cost_unbalanced_correction(
transport_matrix: jnp.ndarray,
marginal_1: jnp.ndarray,
marginal_2: jnp.ndarray,
epsilon: epsilon_scheduler.Epsilon,
epsilon: float,
) -> float:
r"""Calculate cost term from the quadratic divergence when unbalanced.
Expand Down Expand Up @@ -204,13 +204,12 @@ def cost_unbalanced_correction(
"""

def regularizer(tau: float) -> float:
return eps * tau / (1.0 - tau)
return epsilon * tau / (1.0 - tau)

eps = epsilon._target_init
marginal_1loga = jsp.special.xlogy(marginal_1, self.a).sum()
marginal_2logb = jsp.special.xlogy(marginal_2, self.b).sum()

cost = eps * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
cost = epsilon * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
if self.tau_a != 1.0:
cost += regularizer(
self.tau_a
Expand Down Expand Up @@ -269,7 +268,7 @@ def update_lr_geom(
def update_linearization(
self,
transport: Transport,
epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
epsilon: Optional[float] = None,
old_transport_mass: float = 1.0,
relative_epsilon: Optional[bool] = None,
) -> linear_problem.LinearProblem:
Expand Down Expand Up @@ -310,11 +309,9 @@ def update_linearization(
transport_matrix = transport.matrix * rescale_factor

if not self.is_balanced:
# Rescales transport for Unbalanced GW according to Sejourne et al. (2021)
transport_mass = jax.lax.stop_gradient(marginal_1.sum())
epsilon = update_epsilon_unbalanced(epsilon, transport_mass)
epsilon *= jax.lax.stop_gradient(marginal_1.sum())
unbalanced_correction = self.cost_unbalanced_correction(
transport_matrix, marginal_1, marginal_2, epsilon
transport_matrix, marginal_1, marginal_2, epsilon=epsilon
)

h1, h2 = self.quad_loss
Expand All @@ -329,7 +326,7 @@ def update_linearization(
geom = geometry.Geometry(
cost_matrix=cost_matrix,
epsilon=epsilon,
relative_epsilon=relative_epsilon
relative_epsilon=relative_epsilon,
)

return linear_problem.LinearProblem(
Expand Down Expand Up @@ -500,14 +497,6 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*geoms, a=a, b=b, **aux_data)


def update_epsilon_unbalanced(
    epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float
) -> epsilon_scheduler.Epsilon:
  """Rescale the epsilon scaling factor by the transported mass.

  Args:
    epsilon: Either an existing :class:`epsilon_scheduler.Epsilon` or a plain
      value. A plain value is first wrapped in an ``Epsilon`` constructed with
      ``scale_epsilon=1.0``.
    transport_mass: Factor by which the current ``_scale_epsilon`` is
      multiplied.

  Returns:
    The result of ``Epsilon.set`` called with the multiplied scale.
    NOTE(review): ``Epsilon.set`` is defined elsewhere; presumably it returns
    an updated scheduler rather than mutating in place -- confirm.
  """
  if isinstance(epsilon, epsilon_scheduler.Epsilon):
    scheduler = epsilon
  else:
    # Plain values start from a neutral scale so that only `transport_mass`
    # contributes to the final scaling factor.
    scheduler = epsilon_scheduler.Epsilon(epsilon, scale_epsilon=1.0)

  rescaled = scheduler._scale_epsilon * transport_mass
  return scheduler.set(scale_epsilon=rescaled)


def apply_cost( # noqa: D103
geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int,
fn: quadratic_costs.Loss
Expand Down

0 comments on commit 986f049

Please sign in to comment.