From 62b07e28f25992accf36c664751b6cfbac380296 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 7 Nov 2023 14:08:20 +0100 Subject: [PATCH] Fix PyTree serialization for SinkhornDivergence --- src/ott/tools/sinkhorn_divergence.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 653eaa26d..51de97613 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import MappingProxyType -from typing import Any, Mapping, NamedTuple, Optional, Tuple, Type +from typing import Any, Mapping, Optional, Tuple, Type import jax.numpy as jnp +from ott import utils from ott.geometry import costs, geometry, pointcloud, segment from ott.problems.linear import linear_problem, potentials from ott.solvers import linear @@ -29,7 +30,8 @@ Potentials_t = Tuple[jnp.ndarray, jnp.ndarray] -class SinkhornDivergenceOutput(NamedTuple): # noqa: D101 +@utils.register_pytree_node +class SinkhornDivergenceOutput: # noqa: D101 divergence: float potentials: Tuple[Potentials_t, Potentials_t, Potentials_t] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] @@ -49,6 +51,24 @@ def to_dual_potentials(self) -> "potentials.EntropicPotentials": f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y ) + def tree_flatten_foo(self): # noqa: D102 + return [ + self.divergence, + self.potentials, + self.geoms, + self.a, + self.b, + ], { + "n_iters": self.n_iters, + "converged": self.converged, + "errors": self.errors + } + + @classmethod + def tree_unflatten_foo(cls, aux_data, children): # noqa: D102 + div, pots, geoms, a, b = children + return cls(div, pots, geoms, a=a, b=b, **aux_data) + def sinkhorn_divergence( geom: Type[geometry.Geometry],