Skip to content

Commit

Permalink
more precise pydocs for #376 (#377)
Browse files Browse the repository at this point in the history
* pydocs: more precise unbalanced

* typo

* typo
  • Loading branch information
marcocuturi authored Jun 27, 2023
1 parent dcb6155 commit dd81c00
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,18 @@ def primal_cost(self) -> float:
def ent_reg_cost(self) -> float:
r"""Entropy regularized cost.
This outputs :math:`\langle P^{\star},C\rangle - \varepsilon H(P^{\star}),`
where :math:`P^{\star}` is the coupling returned by the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
This outputs
.. math::
\langle P^{\star},C\rangle - \varepsilon H(P^{\star}) +
\rho_a\text{KL}(P^{\star} 1|a) + \rho_b\text{KL}(1^T P^{\star}|b),
where :math:`P^{\star}, a, b` is the coupling returned by the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` and the two marginal weight
vectors; :math:`\rho_a=\varepsilon \tau_a / (1-\tau_a)` and
:math:`\rho_b=\varepsilon \tau_b / (1-\tau_b)` are obtained when the problem
is unbalanced from parameters ``tau_a`` and ``tau_b``. Note that the last
two terms vanish in the balanced case, when ``tau_a==tau_b==1``.
"""
ent_a = jnp.sum(jsp.special.entr(self.ot_prob.a))
ent_b = jnp.sum(jsp.special.entr(self.ot_prob.b))
Expand All @@ -333,11 +342,21 @@ def ent_reg_cost(self) -> float:
def kl_reg_cost(self) -> float:
r"""KL regularized OT transport cost.
This outputs :math:`\langle P^{\star}, C \rangle +
\varepsilon KL(P^{\star},ab^T),` where :math:`P^{\star}, a, b` are the
coupling returned by the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`
algorithm and the two marginal weight vectors, respectively.
This coincides with :attr:`reg_ot_cost`.
This outputs
.. math::
\langle P^{\star}, C \rangle + \varepsilon KL(P^{\star},ab^T) +
\rho_a\text{KL}(P^{\star} 1|a) + \rho_b\text{KL}(1^T P^{\star}|b),
where :math:`P^{\star}, a, b` are the coupling returned by the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm and the two
marginal weight vectors, respectively, and
:math:`\rho_a=\varepsilon \tau_a / (1-\tau_a)` and
:math:`\rho_b=\varepsilon \tau_b / (1-\tau_b)` are obtained when the problem
is unbalanced from parameters ``tau_a`` and ``tau_b``. Note that the last
two terms vanish in the balanced case, when ``tau_a==tau_b==1``. This
quantity coincides with :attr:`reg_ot_cost`, which is computed using
dual variables.
"""
return self.reg_ot_cost

Expand Down

0 comments on commit dd81c00

Please sign in to comment.