Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Jul 20, 2023
1 parent c500a7f commit e7e2318
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,13 @@ def sort(
x_sorted = jax.numpy.sort(x)
Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
the number of target points to which the slice of the ``inputs`` tensor
will be mapped to will be equal to ``topk+1``.
will be mapped to will be equal to ``topk + 1``.
num_targets: if ``topk`` is not specified, a vector of size``num_targets``
is returned. This defines the number of (composite) sorted values computed
from the inputs (each value is a convex combination of values recorded in
Expand All @@ -186,7 +185,7 @@ def sort(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -245,9 +244,9 @@ def ranks(
axis: the axis on which to apply the soft-sorting operator.
target_weights: This vector contains weights (summing to 1) that describe
amount of mass shipped to targets.
num_targets: If `target_weights` is ``None``, ``num_targets`` is considered
to define the number of targets used to rank inputs. Each normalized rank
returned in the output will be a convex combination of
num_targets: If ``target_weights` is ``None``, ``num_targets`` is
considered to define the number of targets used to rank inputs. Each
normalized rank in the output will be a convex combination of
``{1, .., num_targets}/num_targets``. The weight of each of these points
is assumed to be uniform. If neither ``num_targets`` nor
``target_weights`` are specified, ``num_targets`` defaults to the size
Expand All @@ -259,13 +258,14 @@ def ranks(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
An Array of the same shape as the input with soft-rank values
normalized to be in `[0, n-1]` where `n` is `inputs.shape[axis]`.
normalized to be in :math:`[0, n-1]` where :math:`n` is
``inputs.shape[axis]``.
"""
return apply_on_axis(
_ranks, inputs, axis, num_targets, target_weights, **kwargs
Expand All @@ -278,7 +278,7 @@ def topk_mask(
k: int = 1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Soft top-$k$ selection mask.
r"""Soft :math:`\text{top-}k` selection mask.
For instance:
Expand Down Expand Up @@ -308,12 +308,12 @@ def topk_mask(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
The soft mask.
The soft :math:`\text{top-}k` selection mask.
"""
num_points = inputs.shape[axis]
assert k < num_points, (
Expand Down Expand Up @@ -379,7 +379,7 @@ def quantile(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -467,7 +467,7 @@ def quantile_normalization(
axis: int = -1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Renormalize inputs so that its quantiles match those of targets/weights.
r"""Re-normalize inputs so that its quantiles match those of targets/weights.
Quantile normalization rearranges the values in inputs to values that match
the distribution of values described in the discrete distribution ``targets``
Expand All @@ -491,7 +491,7 @@ def quantile_normalization(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -544,7 +544,7 @@ def sort_with(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -587,7 +587,7 @@ def quantize(
axis: int = -1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Soft quantizes an input according using num_levels values along axis.
r"""Soft quantizes an input according using ``num_levels`` values along axis.
The quantization operator consists in concentrating several values around
a few predefined ``num_levels``. The soft quantization operator proposed here
Expand All @@ -612,7 +612,7 @@ def quantize(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down

0 comments on commit e7e2318

Please sign in to comment.