Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/epsilon regularization #310

Merged
merged 5 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,32 @@ class Epsilon:
decay: geometric decay factor, smaller than 1.
"""

# TODO(michalk8): directly use the defaults instead of `None`
def __init__(
self,
target: Optional[float] = None,
scale_epsilon: Optional[float] = None,
init: Optional[float] = None,
decay: Optional[float] = None
init: float = 1.0,
decay: float = 1.0
):
self._target_init = .01 if target is None else target
self._scale_epsilon = 1.0 if scale_epsilon is None else scale_epsilon
self._init = 1.0 if init is None else init
self._decay = 1.0 if decay is None else decay
self._target_init = target
self._scale_epsilon = scale_epsilon
self._init = init
self._decay = decay

@property
def target(self) -> float:
"""Return the final regularizer value of scheduler."""
return self._target_init * self._scale_epsilon
target = 5e-2 if self._target_init is None else self._target_init
if self._scale_epsilon is None:
return target
return target * self._scale_epsilon

def at(self, iteration: Optional[int] = 1) -> float:
"""Return (intermediate) regularizer value at a given iteration."""
if iteration is None:
return self.target
# check the decay is smaller than 1.0.
decay = jnp.where(self._decay < 1.0, self._decay, 1.0)
decay = jnp.minimum(self._decay, 1.0)
# the multiple is either 1.0 or a larger init value that is decayed.
multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
return multiple * self.target
Expand All @@ -78,6 +80,17 @@ def done_at(self, iteration: Optional[int]) -> bool:
"""Return whether the scheduler is done at a given iteration."""
return self.done(self.at(iteration))

def set(self, **kwargs: Any) -> "Epsilon":
"""TODO."""
kwargs = {
"target": self._target_init,
"scale_epsilon": self._scale_epsilon,
"init": self._init,
"decay": self._decay,
**kwargs
}
return Epsilon(**kwargs)

def tree_flatten(self): # noqa: D102
return (
self._target_init, self._scale_epsilon, self._init, self._decay
Expand All @@ -87,11 +100,3 @@ def tree_flatten(self): # noqa: D102
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)

@classmethod
def make(cls, *args: Any, **kwargs: Any) -> "Epsilon":
"""Create or return an Epsilon instance."""
if isinstance(args[0], cls):
return args[0]
else:
return cls(*args, **kwargs)
110 changes: 53 additions & 57 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,15 @@ class Geometry:
costs.
kernel_matrix: jnp.ndarray<float>[num_a, num_b]: a kernel matrix storing n
x m kernel values.
epsilon: a regularization parameter.
epsilon: a regularization parameter. TODO(michalk8): update the docstring
If a :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler is passed,
other parameters below are ignored in practice. If the
parameter is a float, then this is understood to be the regularization
that is needed, unless ``relative_epsilon`` below is ``True``, in which
case ``epsilon`` is understood as a normalized quantity, to be scaled by
the mean value of the :attr:`cost_matrix`.
the :attr:`mean_cost_matrix`.
relative_epsilon: whether epsilon is passed relative to scale of problem,
here understood as mean value of :attr:`cost_matrix`.
scale_epsilon: the scale multiplier for epsilon.
here understood the value of :attr:`mean_cost_matrix`.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
Expand All @@ -67,7 +66,6 @@ class Geometry:
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
:attr:`cost_matrix`, see :attr:`tgt_mask`.
kwargs: additional kwargs for epsilon scheduler.

Note:
When defining a ``Geometry`` through a ``cost_matrix``, it is important to
Expand All @@ -80,58 +78,32 @@ def __init__(
self,
cost_matrix: Optional[jnp.ndarray] = None,
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Union[epsilon_scheduler.Epsilon, float, None] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[bool] = None,
scale_epsilon: Optional[float] = None,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
scale_cost: Union[bool, int, float, Literal['mean', 'max_cost',
'median']] = 1.0,
**kwargs: Any,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix
self._epsilon_init = epsilon

# needed for `copy_epsilon`, because of the `isinstance` check
self._epsilon_init = epsilon if isinstance(
epsilon, epsilon_scheduler.Epsilon
) else epsilon_scheduler.Epsilon(epsilon)
self._relative_epsilon = relative_epsilon
self._scale_epsilon = scale_epsilon

self._scale_cost = "mean" if scale_cost is True else scale_cost

self._src_mask = src_mask
self._tgt_mask = tgt_mask
# Define default dictionary and update it with user's values.
self._kwargs = {**{'init': None, 'decay': None}, **kwargs}

@property
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
return None

@property
def scale_epsilon(self) -> float:
"""Compute the scale of the epsilon, potentially based on data."""
if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return 1.0

rel = self._relative_epsilon
trigger = ((self._scale_epsilon is None) and
((rel is None and self._epsilon_init is None) or rel))

if (self._scale_epsilon is None) and (trigger is not None): # for dry run
return jnp.where(
trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0
)
else:
return self._scale_epsilon

@property
def _epsilon(self) -> epsilon_scheduler.Epsilon:
"""Return epsilon scheduler, either passed directly or by building it."""
if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return self._epsilon_init
eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
return epsilon_scheduler.Epsilon.make(
eps, scale_epsilon=self.scale_epsilon, **self._kwargs
)

@property
def cost_matrix(self) -> jnp.ndarray:
"""Cost matrix, recomputed from kernel if only kernel was specified."""
Expand Down Expand Up @@ -164,6 +136,22 @@ def kernel_matrix(self) -> jnp.ndarray:
return jnp.exp(-(self._cost_matrix * self.inv_scale_cost / self.epsilon))
return self._kernel_matrix ** self.inv_scale_cost

@property
def _epsilon(self) -> epsilon_scheduler.Epsilon:
(target, scale_eps, _, _), _ = self._epsilon_init.tree_flatten()
rel = self._relative_epsilon

use_mean_scale = rel is True or (rel is None and target is None)
if scale_eps is None and use_mean_scale:
scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)

if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return self._epsilon_init.set(scale_epsilon=scale_eps)

return epsilon_scheduler.Epsilon(
target=5e-2 if target is None else target, scale_epsilon=scale_eps
)

@property
def epsilon(self) -> float:
"""Epsilon regularization value."""
Expand Down Expand Up @@ -232,11 +220,20 @@ def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":

def copy_epsilon(self, other: 'Geometry') -> "Geometry":
"""Copy the epsilon parameters from another geometry."""
scheduler = other._epsilon
self._epsilon_init = scheduler._target_init
self._relative_epsilon = False
self._scale_epsilon = other.scale_epsilon
return self
other_epsilon = other._epsilon
children, aux_data = self.tree_flatten()

new_children = []
for child in children:
if isinstance(child, epsilon_scheduler.Epsilon):
child = child.set(
target=other_epsilon._target_init,
scale_epsilon=other_epsilon._scale_epsilon
)
new_children.append(child)

aux_data["relative_epsilon"] = False
return type(self).tree_unflatten(aux_data, new_children)

# The functions below are at the core of Sinkhorn iterations, they
# are implemented here in their default form, either in lse (using directly
Expand Down Expand Up @@ -724,10 +721,8 @@ def to_LRCGeometry(
cost_2=cost_2,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
scale_cost=self._scale_cost,
scale_factor=scale,
**self._kwargs
)

def subset(
Expand Down Expand Up @@ -759,7 +754,7 @@ def subset_fn(
return arr

return self._mask_subset_helper(
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True
)

def mask(
Expand Down Expand Up @@ -819,8 +814,7 @@ def _mask_subset_helper(
propagate_mask: bool,
**kwargs: Any,
) -> "Geometry":
(cost, kernel, *children, src_mask, tgt_mask,
kws), aux_data = self.tree_flatten()
(cost, kernel, eps, src_mask, tgt_mask), aux_data = self.tree_flatten()
cost = fn(cost, src_ixs, tgt_ixs)
kernel = fn(kernel, src_ixs, tgt_ixs)
if propagate_mask:
Expand All @@ -831,7 +825,7 @@ def _mask_subset_helper(

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(
aux_data, [cost, kernel] + children + [src_mask, tgt_mask, kws]
aux_data, [cost, kernel, eps, src_mask, tgt_mask]
)

@property
Expand Down Expand Up @@ -903,16 +897,18 @@ def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._relative_epsilon, self._scale_epsilon, self._src_mask,
self._tgt_mask, self._kwargs
self._src_mask, self._tgt_mask
), {
'scale_cost': self._scale_cost
"scale_cost": self._scale_cost,
"relative_epsilon": self._relative_epsilon
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
*args, kwargs = children
return cls(*args, **kwargs, **aux_data)
cost, kernel, eps, src_mask, tgt_mask = children
return cls(
cost, kernel, eps, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)


def is_affine(fn) -> bool:
Expand Down
11 changes: 4 additions & 7 deletions src/ott/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def __init__(
"crank_nicolson"] = "backward_euler",
directed: bool = False,
normalize: bool = False,
tol: float = -1.,
tol: float = -1.0,
**kwargs: Any
):
assert ((graph is None and laplacian is not None) or
(laplacian is None and graph is not None)), \
"Please provide a graph or a symmetric graph Laplacian."
# arbitrary epsilon; can't use `None` as `mean_cost_matrix` would be used
super().__init__(epsilon=1., **kwargs)
self._graph = graph
self._lap = laplacian
Expand Down Expand Up @@ -354,21 +353,19 @@ def marginal_from_potentials(
raise ValueError("Not implemented.")

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [self._graph, self._lap, self.solver], {
"t": self._t,
return [self._graph, self._lap, self.solver, self._t], {
"n_steps": self.n_steps,
"numerical_scheme": self.numerical_scheme,
"directed": self.directed,
"normalize": self.normalize,
"tol": self._tol,
**self._kwargs,
}

@classmethod
def tree_unflatten( # noqa: D102
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "Graph":
graph, laplacian, solver = children
obj = cls(graph=graph, laplacian=laplacian, **aux_data)
graph, laplacian, solver, t = children
obj = cls(graph=graph, laplacian=laplacian, t=t, **aux_data)
obj._solver = solver
return obj
12 changes: 5 additions & 7 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def geometries(self) -> List[geometry.Geometry]:
):
x_values = self.x[dimension][:, jnp.newaxis]
geom = pointcloud.PointCloud(
x_values, cost_fn=cost_fn, epsilon=self._epsilon_init
x_values,
cost_fn=cost_fn,
epsilon=self._epsilon_init,
)
geometries.append(geom)
return geometries
Expand Down Expand Up @@ -345,7 +347,7 @@ def dtype(self) -> jnp.dtype: # noqa: D102
return self.x[0].dtype

def tree_flatten(self): # noqa: D102
return (self.x, self.cost_fns, self._epsilon), self.kwargs
return (self.x, self.cost_fns, self._epsilon_init), self.kwargs

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
Expand Down Expand Up @@ -376,8 +378,7 @@ def to_LRCGeometry(
Returns:
:class:`~ott.geometry.low_rank.LRCGeometry` object.
"""
cost_1 = []
cost_2 = []
cost_1, cost_2 = [], []
for dimension, geom in enumerate(self.geometries):
# An overall low-rank conversion of the cost matrix on a grid, to an
# object of :class:`~ott.geometry.low_rank.LRCGeometry`, necesitates an
Expand Down Expand Up @@ -408,10 +409,7 @@ def to_LRCGeometry(
cost_2=cost_2,
scale_factor=scale,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale=self._scale_epsilon,
scale_cost=self._scale_cost,
src_mask=self.src_mask,
tgt_mask=self.tgt_mask,
**self._kwargs
)
Loading