Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove epsilon scheduler in GW #508

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def _create_geometry(
init_transport = jnp.outer(quad_prob.a, quad_prob.b)
marginal_1, marginal_2 = init_transport.sum(1), init_transport.sum(0)

epsilon = quadratic_problem.update_epsilon_unbalanced(
epsilon=epsilon, transport_mass=marginal_1.sum()
)
epsilon *= marginal_1.sum()
unbalanced_correction = quad_prob.cost_unbalanced_correction(
init_transport, marginal_1, marginal_2, epsilon=epsilon
)
Expand Down
27 changes: 8 additions & 19 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.scipy as jsp

from ott import utils
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
from ott.geometry import geometry, low_rank, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_costs
from ott.types import Transport
Expand Down Expand Up @@ -171,7 +171,7 @@ def cost_unbalanced_correction(
transport_matrix: jnp.ndarray,
marginal_1: jnp.ndarray,
marginal_2: jnp.ndarray,
epsilon: epsilon_scheduler.Epsilon,
epsilon: float,
) -> float:
r"""Calculate cost term from the quadratic divergence when unbalanced.

Expand Down Expand Up @@ -204,13 +204,12 @@ def cost_unbalanced_correction(
"""

def regularizer(tau: float) -> float:
return eps * tau / (1.0 - tau)
return epsilon * tau / (1.0 - tau)

eps = epsilon._target_init
marginal_1loga = jsp.special.xlogy(marginal_1, self.a).sum()
marginal_2logb = jsp.special.xlogy(marginal_2, self.b).sum()

cost = eps * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
cost = epsilon * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
if self.tau_a != 1.0:
cost += regularizer(
self.tau_a
Expand Down Expand Up @@ -269,7 +268,7 @@ def update_lr_geom(
def update_linearization(
self,
transport: Transport,
epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
epsilon: Optional[float] = None,
old_transport_mass: float = 1.0,
relative_epsilon: Optional[bool] = None,
) -> linear_problem.LinearProblem:
Expand Down Expand Up @@ -310,11 +309,9 @@ def update_linearization(
transport_matrix = transport.matrix * rescale_factor

if not self.is_balanced:
# Rescales transport for Unbalanced GW according to Sejourne et al. (2021)
transport_mass = jax.lax.stop_gradient(marginal_1.sum())
epsilon = update_epsilon_unbalanced(epsilon, transport_mass)
epsilon *= jax.lax.stop_gradient(marginal_1.sum())
unbalanced_correction = self.cost_unbalanced_correction(
transport_matrix, marginal_1, marginal_2, epsilon
transport_matrix, marginal_1, marginal_2, epsilon=epsilon
)

h1, h2 = self.quad_loss
Expand All @@ -329,7 +326,7 @@ def update_linearization(
geom = geometry.Geometry(
cost_matrix=cost_matrix,
epsilon=epsilon,
relative_epsilon=relative_epsilon
relative_epsilon=relative_epsilon,
)

return linear_problem.LinearProblem(
Expand Down Expand Up @@ -500,14 +497,6 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*geoms, a=a, b=b, **aux_data)


def update_epsilon_unbalanced( # noqa: D103
epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float
) -> epsilon_scheduler.Epsilon:
if not isinstance(epsilon, epsilon_scheduler.Epsilon):
epsilon = epsilon_scheduler.Epsilon(epsilon, scale_epsilon=1.0)
return epsilon.set(scale_epsilon=epsilon._scale_epsilon * transport_mass)


def apply_cost( # noqa: D103
geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int,
fn: quadratic_costs.Loss
Expand Down
Loading