Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix not being able to jit the soft quantile #408

Merged
merged 2 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -333,8 +333,8 @@ def topk_mask(

def quantile(
inputs: jnp.ndarray,
q: jnp.ndarray,
axis: int = -1,
q: Union[float, jnp.ndarray],
axis: Union[int, Tuple[int, ...]] = -1,
weight: Optional[Union[float, jnp.ndarray]] = None,
**kwargs: Any,
) -> jnp.ndarray:
Expand Down Expand Up @@ -440,11 +440,11 @@ def _quantile(
out = 1.0 / weights * ot.apply(jnp.squeeze(inputs), axis=0)

# Recover odd indices corresponding to the desired quantiles.
odds = jnp.concatenate([
jnp.zeros((num_quantiles + 1, 1), dtype=bool),
jnp.ones((num_quantiles + 1, 1), dtype=bool)
odds = np.concatenate([
np.zeros((num_quantiles + 1, 1), dtype=bool),
np.ones((num_quantiles + 1, 1), dtype=bool)
],
axis=1).ravel()[:-1]
axis=1).ravel()[:-1]
return out[odds][idx]

return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs)
Expand Down
66 changes: 45 additions & 21 deletions tests/tools/soft_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int):
np.testing.assert_array_equal(xs.shape, expected_shape)
np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True)

@pytest.mark.fast.with_args("axis", [0, 1], only_fast0=0)
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
@pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0)
def test_ranks(self, axis, rng: jax.random.PRNGKeyArray, jit: bool):
rng1, rng2 = jax.random.split(rng, 2)
num_targets = 13
x = jax.random.uniform(rng1, (8, 5, 2))

expected_ranks = jnp.argsort(
jnp.argsort(x, axis=axis), axis=axis
).astype(float)
# Define a custom version of ranks suited to recover ranks that are
# close to true ranks. This requires notably small epsilon and large # iter.
my_ranks = functools.partial(
Expand All @@ -101,10 +103,11 @@ def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
axis=axis,
max_iterations=5000
)
expected_ranks = jnp.argsort(
jnp.argsort(x, axis=axis), axis=axis
).astype(float)
if jit:
my_ranks = jax.jit(my_ranks, static_argnames="num_targets")

ranks = my_ranks(x)

np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

Expand All @@ -118,8 +121,12 @@ def test_ranks(self, axis, rng: jax.random.PRNGKeyArray):
np.testing.assert_array_equal(x.shape, ranks.shape)
np.testing.assert_allclose(ranks, expected_ranks, atol=0.3, rtol=0.1)

@pytest.mark.fast.with_args("axis", [0, 1], only_fast=0)
def test_topk_mask(self, axis, rng: jax.random.PRNGKeyArray):
@pytest.mark.fast.with_args("axis,jit", [(0, False), (1, True)], only_fast=0)
def test_topk_mask(self, axis, rng: jax.random.PRNGKeyArray, jit: bool):

def boolean_topk_mask(u, k):
return u >= jnp.flip(jax.numpy.sort(u))[k - 1]

k = 3
x = jax.random.uniform(rng, (13, 7, 1))
my_topk_mask = functools.partial(
Expand All @@ -129,13 +136,12 @@ def test_topk_mask(self, axis, rng: jax.random.PRNGKeyArray):
max_iterations=15000, # needed to recover a sharp mask given close ties
axis=axis
)
mask = my_topk_mask(x, k=k, axis=axis)
if jit:
my_topk_mask = jax.jit(my_topk_mask, static_argnames=("k", "axis"))

def boolean_topk_mask(u, k):
return u >= jnp.flip(jax.numpy.sort(u))[k - 1]
mask = my_topk_mask(x, k=k, axis=axis)

expected_mask = soft_sort.apply_on_axis(boolean_topk_mask, x, axis, k)

np.testing.assert_array_equal(x.shape, mask.shape)
np.testing.assert_allclose(mask, expected_mask, atol=0.01, rtol=0.1)

Expand All @@ -160,21 +166,34 @@ def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray):
)

@pytest.mark.fast()
def test_quantiles(self, rng: jax.random.PRNGKeyArray):
@pytest.mark.parametrize("jit", [False, True])
def test_quantiles(self, rng: jax.random.PRNGKeyArray, jit: bool):
inputs = jax.random.uniform(rng, (200, 2, 3))
q = jnp.array([.1, .8, .4])
m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0)
quantile_fn = soft_sort.quantile
if jit:
quantile_fn = jax.jit(quantile_fn, static_argnames="axis")

m1 = quantile_fn(inputs, q=q, weight=None, axis=0)

np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2)

def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray):
@pytest.mark.parametrize("jit", [False, True])
def test_soft_quantile_normalization(
self, rng: jax.random.PRNGKeyArray, jit: bool
):
rngs = jax.random.split(rng, 2)
x = jax.random.uniform(rngs[0], shape=(100,))
mu, sigma = 2.0, 1.2
y = mu + sigma * jax.random.normal(rng, shape=(48,))
mu_target, sigma_target = y.mean(), y.std()
qn = soft_sort.quantile_normalization(x, jnp.sort(y), epsilon=1e-4)
mu_transform, sigma_transform = qn.mean(), qn.std()
quantize_fn = soft_sort.quantile_normalization
if jit:
quantize_fn = jax.jit(quantize_fn)

qn = quantize_fn(x, jnp.sort(y), epsilon=1e-4)

mu_transform, sigma_transform = qn.mean(), qn.std()
np.testing.assert_allclose([mu_transform, sigma_transform],
[mu_target, sigma_target],
rtol=0.05)
Expand All @@ -196,20 +215,25 @@ def test_sort_with(self, rng: jax.random.PRNGKeyArray):
np.testing.assert_allclose(output, inputs[-k:], atol=0.05)

@pytest.mark.fast()
def test_quantize(self):
@pytest.mark.parametrize("jit", [False, True])
def test_quantize(self, jit: bool):
n = 100
inputs = jnp.linspace(0.0, 1.0, n)[..., None]
q = soft_sort.quantize(inputs, num_levels=4, axis=0, epsilon=1e-4)
quantize_fn = soft_sort.quantize
if jit:
quantize_fn = jax.jit(quantize_fn, static_argnames=("num_levels", "axis"))

q = quantize_fn(inputs, num_levels=4, axis=0, epsilon=1e-4)

delta = jnp.abs(q - jnp.array([0.12, 0.34, 0.64, 0.86]))
min_distances = jnp.min(delta, axis=1)

np.testing.assert_allclose(min_distances, min_distances, atol=0.05)

@pytest.mark.parametrize("implicit", [False, True])
def test_soft_sort_jacobian(
self, rng: jax.random.PRNGKeyArray, implicit: bool
):
## Add a ridge when using JAX solvers.
# Add a ridge when using JAX solvers.
try:
from ott.solvers.linear import lineax_implicit # noqa: F401
solver_kwargs = {}
Expand Down
Loading