Skip to content

Commit

Permalink
Polish the docs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Jul 11, 2024
1 parent b87de33 commit 2c108b1
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,20 @@ def twist_operator(
return vec + jax.grad(self.h_legendre)(-dual_vec)
return vec - jax.grad(self.h_legendre)(dual_vec)

def transport_map(self, f: Func) -> Callable[[jnp.ndarray], jnp.ndarray]:
r"""Get an optimal transport map.
def transport_map(self,
f: Func) -> Callable[[jnp.ndarray, bool, Any], jnp.ndarray]:
r"""Get an optimal transport map for a concave function :math:`f`.
Uses Theorem 1.17 from :cite:`santambrogio:15` to define an OT map, e.g. in
the forward case :math:`x - (\nabla h^*) \circ \nabla f^h(x)`, where
:math:`h^*` is the Legendre transform of :math:`h` and :math:`f^h`
is the h-transform of a concave function :math:`f`.
the forward case :math:`x - (\nabla h^*) \circ \nabla \bar f^h(x)`, where
:math:`h^*` is the Legendre transform of :math:`h` and :math:`\bar f^h`
is the :meth:`h_transform` of a concave function :math:`f`.
Args:
f: Concave function.
Returns:
The transport map.
The transport map with a signature ``(x, forward, **kwargs)``.
"""

def transport(
Expand Down

0 comments on commit 2c108b1

Please sign in to comment.