Skip to content

Commit

Permalink
Fix heat scaling (#481)
Browse files Browse the repository at this point in the history
* scale for the eigval

* eigval for comb. Lap.

* add test for heat kernel

* `t=eps/4.0`

* formatting

* atol rtol test `test_geometry_differentiability`

* formatting & test `order` and `t`
  • Loading branch information
guillaumehu authored Dec 27, 2023
1 parent 38a1625 commit a421e86
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
10 changes: 6 additions & 4 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ def from_graph(

if eigval is None:
eigval = compute_largest_eigenvalue(laplacian, rng)
scaled_laplacian = jax.lax.cond((eigval > 2.0), lambda l: 2.0 * l / eigval,
lambda l: l, laplacian)

scaled_laplacian, eigval = jax.lax.cond((eigval > 2.0), lambda l:
(2.0 * l / eigval, 2.0), lambda l:
(l, eigval), laplacian)

# compute the coeffs of the Chebyshev pols approx using Bessel funcs
chebyshev_coeffs = compute_chebychev_coeff_all(
eigval, t, order, laplacian.dtype
0.5 * eigval, t, order, laplacian.dtype
)

return cls(
Expand Down Expand Up @@ -149,7 +151,7 @@ def apply_kernel(
Kernel applied to ``scaling``.
"""
return expm_multiply(
self.scaled_laplacian, scaling, self.chebyshev_coeffs, self.eigval
self.scaled_laplacian, scaling, self.chebyshev_coeffs, 0.5 * self.eigval
)

@property
Expand Down
37 changes: 31 additions & 6 deletions tests/geometry/geodesic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ def gt_geometry(
return geometry.Geometry(cost_matrix=cost, kernel_matrix=kernel, epsilon=1.)


def exact_heat_kernel(G: jnp.ndarray, normalize: bool = False, t: float = 10):
degree = jnp.sum(G, axis=1)
L = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
L = inv_sqrt_deg @ L @ inv_sqrt_deg

e, v = jnp.linalg.eigh(L)
e = jnp.clip(e, 0)

return v @ jnp.diag(jnp.exp(-t * e)) @ v.T


class TestGeodesic:

def test_kernel_is_symmetric_positive_definite(
Expand Down Expand Up @@ -99,16 +114,14 @@ def test_kernel_is_symmetric_positive_definite(
# check that the negative eigenvalues are all very small
np.testing.assert_array_less(jnp.abs(neg_eigenvalues), 1e-3)
# internally, the axis is ignored because the kernel is symmetric
np.testing.assert_array_equal(vec0, vec1)
np.testing.assert_array_equal(vec_direct0, vec_direct1)
np.testing.assert_allclose(vec0, vec1, rtol=tol, atol=tol)
np.testing.assert_allclose(vec_direct0, vec_direct1, rtol=tol, atol=tol)

np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol)
np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol)

# compute the distance matrix and check that it is symmetric
cost_matrix = geom.cost_matrix
np.testing.assert_array_equal(cost_matrix, cost_matrix.T)
# and all dissimilarities are positive
np.testing.assert_allclose(cost_matrix, cost_matrix.T, rtol=tol, atol=tol)
np.testing.assert_array_less(0, cost_matrix)

@pytest.mark.fast.with_args(
Expand Down Expand Up @@ -195,7 +208,7 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
x = jax.random.normal(rng, (n,))

gt_geom = gt_geometry(G, epsilon=eps)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)

fn = jax.jit(callback) if jit else callback
gt_out = fn(gt_geom)
Expand Down Expand Up @@ -240,3 +253,15 @@ def callback(geom: geodesic.Geodesic) -> float:
expected = callback(geom__finite_right) - callback(geom__finite_left)
actual = 2 * jnp.vdot(v_w, grad_sl)
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("t", [5, 10, 50])
@pytest.mark.parametrize("order", [20, 30, 40])
def test_heat_approx(self, normalize: bool, t: float, order: int):
G = random_graph(20, p=0.5)
exact = exact_heat_kernel(G, normalize=normalize, t=t)
geom = geodesic.Geodesic.from_graph(
G, t=t, order=order, normalize=normalize
)
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)

0 comments on commit a421e86

Please sign in to comment.