Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix heat scaling #481

Merged
merged 8 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/ott/geometry/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ def from_graph(

if eigval is None:
eigval = compute_largest_eigenvalue(laplacian, rng)
scaled_laplacian = jax.lax.cond((eigval > 2.0), lambda l: 2.0 * l / eigval,
lambda l: l, laplacian)

scaled_laplacian, eigval = jax.lax.cond((eigval > 2.0), lambda l:
(2.0 * l / eigval, 2.0), lambda l:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
(l, eigval), laplacian)

# compute the coeffs of the Chebyshev pols approx using Bessel funcs
chebyshev_coeffs = compute_chebychev_coeff_all(
eigval, t, order, laplacian.dtype
0.5 * eigval, t, order, laplacian.dtype
)

return cls(
Expand Down Expand Up @@ -149,7 +151,7 @@ def apply_kernel(
Kernel applied to ``scaling``.
"""
return expm_multiply(
self.scaled_laplacian, scaling, self.chebyshev_coeffs, self.eigval
self.scaled_laplacian, scaling, self.chebyshev_coeffs, 0.5 * self.eigval
)

@property
Expand Down
37 changes: 31 additions & 6 deletions tests/geometry/geodesic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ def gt_geometry(
return geometry.Geometry(cost_matrix=cost, kernel_matrix=kernel, epsilon=1.)


def exact_heat_kernel(G: jnp.ndarray, normalize: bool = False, t: float = 10):
degree = jnp.sum(G, axis=1)
L = jnp.diag(degree) - G
if normalize:
inv_sqrt_deg = jnp.diag(
jnp.where(degree > 0.0, 1.0 / jnp.sqrt(degree), 0.0)
)
L = inv_sqrt_deg @ L @ inv_sqrt_deg

e, v = jnp.linalg.eigh(L)
e = jnp.clip(e, 0)

return v @ jnp.diag(jnp.exp(-t * e)) @ v.T


class TestGeodesic:

def test_kernel_is_symmetric_positive_definite(
Expand Down Expand Up @@ -99,16 +114,14 @@ def test_kernel_is_symmetric_positive_definite(
# check that the negative eigenvalues are all very small
np.testing.assert_array_less(jnp.abs(neg_eigenvalues), 1e-3)
# internally, the axis is ignored because the kernel is symmetric
np.testing.assert_array_equal(vec0, vec1)
np.testing.assert_array_equal(vec_direct0, vec_direct1)
np.testing.assert_allclose(vec0, vec1, rtol=tol, atol=tol)
np.testing.assert_allclose(vec_direct0, vec_direct1, rtol=tol, atol=tol)

np.testing.assert_allclose(vec0, vec_direct0, rtol=tol, atol=tol)
np.testing.assert_allclose(vec1, vec_direct1, rtol=tol, atol=tol)

# compute the distance matrix and check that it is symmetric
cost_matrix = geom.cost_matrix
np.testing.assert_array_equal(cost_matrix, cost_matrix.T)
guillaumehu marked this conversation as resolved.
Show resolved Hide resolved
# and all dissimilarities are positive
np.testing.assert_allclose(cost_matrix, cost_matrix.T, rtol=tol, atol=tol)
np.testing.assert_array_less(0, cost_matrix)

@pytest.mark.fast.with_args(
Expand Down Expand Up @@ -195,7 +208,7 @@ def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
x = jax.random.normal(rng, (n,))

gt_geom = gt_geometry(G, epsilon=eps)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps)
graph_geom = geodesic.Geodesic.from_graph(G, t=eps / 4.0)

fn = jax.jit(callback) if jit else callback
gt_out = fn(gt_geom)
Expand Down Expand Up @@ -240,3 +253,15 @@ def callback(geom: geodesic.Geodesic) -> float:
expected = callback(geom__finite_right) - callback(geom__finite_left)
actual = 2 * jnp.vdot(v_w, grad_sl)
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("normalize", [False, True])
@pytest.mark.parametrize("t", [5, 10, 50])
@pytest.mark.parametrize("order", [20, 30, 40])
def test_heat_approx(self, normalize: bool, t: float, order: int):
G = random_graph(20, p=0.5)
exact = exact_heat_kernel(G, normalize=normalize, t=t)
geom = geodesic.Geodesic.from_graph(
G, t=t, order=order, normalize=normalize
)
approx = geom.apply_kernel(jnp.eye(G.shape[0]))
np.testing.assert_allclose(exact, approx, rtol=1e-1, atol=1e-1)