Skip to content

Commit

Permalink
Fix/jax prng deprecation (#455)
Browse files Browse the repository at this point in the history
* Remove `utils.is_jax_array`

* Remove old jax util

* Remove `jax.random.PRNGKeyArray` type

* Remove jax.random.PRNGKeyArray warning from other libs in tests
  • Loading branch information
michalk8 authored Nov 7, 2023
1 parent c7ca827 commit 0a8cf0c
Show file tree
Hide file tree
Showing 69 changed files with 270 additions and 357 deletions.
4 changes: 1 addition & 3 deletions docs/tutorials/GWLRSinkhorn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@
},
"outputs": [],
"source": [
"def create_points(\n",
" rng: jax.random.PRNGKeyArray, n: int, m: int, d1: int, d2: int\n",
"):\n",
"def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):\n",
" rngs = jax.random.split(rng, 5)\n",
" x = jax.random.uniform(rngs[0], (n, d1))\n",
" y = jax.random.uniform(rngs[1], (m, d2))\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/Monge_Gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
" noise: float = 0.01\n",
" scale: float = 1.0\n",
" batch_size: int = 1024\n",
" rng: Optional[jax.random.PRNGKeyArray] = (None,)\n",
" rng: Optional[jax.Array] = (None,)\n",
"\n",
" def __iter__(self) -> Iterator[jnp.ndarray]:\n",
" \"\"\"Random sample generator from Gaussian mixture.\n",
Expand Down Expand Up @@ -151,7 +151,7 @@
" target_kwargs: Mapping[str, Any] = MappingProxyType({}),\n",
" train_batch_size: int = 256,\n",
" valid_batch_size: int = 256,\n",
" rng: Optional[jax.random.PRNGKeyArray] = None,\n",
" rng: Optional[jax.Array] = None,\n",
") -> Tuple[dataset.Dataset, dataset.Dataset, int]:\n",
" \"\"\"Samplers from ``SklearnDistribution``.\"\"\"\n",
" rng = jax.random.PRNGKey(0) if rng is None else rng\n",
Expand Down Expand Up @@ -202,7 +202,7 @@
" num_points: Optional[int] = None,\n",
" title: Optional[str] = None,\n",
" figsize: Tuple[int, int] = (8, 6),\n",
" rng: Optional[jax.random.PRNGKeyArray] = None,\n",
" rng: Optional[jax.Array] = None,\n",
"):\n",
" \"\"\"Plot samples from the source and target measures.\n",
"\n",
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ markers = [
"cpu: Mark tests as CPU only.",
"fast: Mark tests as fast.",
]
filterwarnings = [
"ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
]

[tool.coverage.run]
branch = true
Expand Down
6 changes: 3 additions & 3 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

from ott import utils
from ott.geometry import epsilon_scheduler
Expand Down Expand Up @@ -200,8 +201,7 @@ def is_symmetric(self) -> bool:
@property
def inv_scale_cost(self) -> float:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost,
(int, float)) or utils.is_jax_array(self._scale_cost):
if isinstance(self._scale_cost, (int, float, np.number, jax.Array)):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == "max_cost":
Expand Down Expand Up @@ -625,7 +625,7 @@ def to_LRCGeometry(
self,
rank: int = 0,
tol: float = 1e-2,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
scale: float = 1.
) -> "low_rank.LRCGeometry":
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
Expand Down
6 changes: 2 additions & 4 deletions src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import geometry

__all__ = ["LRCGeometry"]
Expand Down Expand Up @@ -108,8 +107,7 @@ def is_symmetric(self) -> bool: # noqa: D102

@property
def inv_scale_cost(self) -> float: # noqa: D102
if isinstance(self._scale_cost,
(int, float)) or utils.is_jax_array(self._scale_cost):
if isinstance(self._scale_cost, (int, float, jax.Array)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == "max_bound":
Expand Down Expand Up @@ -231,7 +229,7 @@ def to_LRCGeometry(
self,
rank: int = 0,
tol: float = 1e-2,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
scale: float = 1.0,
) -> "LRCGeometry":
"""Return self."""
Expand Down
4 changes: 1 addition & 3 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import costs, geometry, low_rank
from ott.math import utils as mu

Expand Down Expand Up @@ -142,8 +141,7 @@ def cost_rank(self) -> int: # noqa: D102

@property
def inv_scale_cost(self) -> float: # noqa: D102
if isinstance(self._scale_cost,
(int, float)) or utils.is_jax_array(self._scale_cost):
if isinstance(self._scale_cost, (int, float, jax.Array)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == "max_cost":
Expand Down
16 changes: 8 additions & 8 deletions src/ott/initializers/linear/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def init_dual_a(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling f_u.
Expand All @@ -54,7 +54,7 @@ def init_dual_b(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling g_v.
Expand All @@ -73,7 +73,7 @@ def __call__(
a: Optional[jnp.ndarray],
b: Optional[jnp.ndarray],
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Initialize Sinkhorn potentials/scalings f_u and g_v.
Expand Down Expand Up @@ -128,7 +128,7 @@ def init_dual_a( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
del rng
return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a)
Expand All @@ -137,7 +137,7 @@ def init_dual_b( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
del rng
return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b)
Expand All @@ -158,7 +158,7 @@ def init_dual_a( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
# import Gaussian here due to circular imports
from ott.tools.gaussian_mixture import gaussian
Expand Down Expand Up @@ -245,7 +245,7 @@ def init_dual_a(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
init_f: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""Apply DualSort algorithm.
Expand Down Expand Up @@ -324,7 +324,7 @@ def init_dual_a( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
from ott.solvers import linear

Expand Down
30 changes: 15 additions & 15 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, rank: int, **kwargs: Any):
def init_q(
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -88,7 +88,7 @@ def init_q(
def init_r(
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -109,7 +109,7 @@ def init_r(
def init_g(
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
**kwargs: Any,
) -> jnp.ndarray:
"""Initialize the low-rank factor :math:`g`.
Expand Down Expand Up @@ -169,7 +169,7 @@ def __call__(
r: Optional[jnp.ndarray] = None,
g: Optional[jnp.ndarray] = None,
*,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
**kwargs: Any
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Initialize the factors :math:`Q`, :math:`R` and :math:`g`.
Expand Down Expand Up @@ -232,7 +232,7 @@ class RandomInitializer(LRInitializer):
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -245,7 +245,7 @@ def init_q( # noqa: D102
def init_r( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -258,7 +258,7 @@ def init_r( # noqa: D102
def init_g( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs
Expand Down Expand Up @@ -305,7 +305,7 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -316,7 +316,7 @@ def init_q( # noqa: D102
def init_r( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -327,7 +327,7 @@ def init_r( # noqa: D102
def init_g( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
**kwargs: Any,
) -> jnp.ndarray:
del rng, kwargs
Expand Down Expand Up @@ -376,7 +376,7 @@ def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray:
def _compute_factor(
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand Down Expand Up @@ -418,7 +418,7 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -430,7 +430,7 @@ def init_q( # noqa: D102
def init_r( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -442,7 +442,7 @@ def init_r( # noqa: D102
def init_g( # noqa: D102
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
**kwargs: Any,
) -> jnp.ndarray:
del rng, kwargs
Expand Down Expand Up @@ -511,7 +511,7 @@ class State(NamedTuple): # noqa: D106
def _compute_factor(
self,
ot_prob: Problem_t,
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand Down
4 changes: 2 additions & 2 deletions src/ott/initializers/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
meta_model: Optional[nn.Module] = None,
opt: Optional[optax.GradientTransformation
] = optax.adam(learning_rate=1e-3), # noqa: B008
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
state: Optional[train_state.TrainState] = None
):
self.geom = geom
Expand Down Expand Up @@ -145,7 +145,7 @@ def init_dual_a( # noqa: D102
self,
ot_prob: "linear_problem.LinearProblem",
lse_mode: bool,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
del rng
# Detect if the problem is batched.
Expand Down
4 changes: 2 additions & 2 deletions src/ott/problems/nn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GaussianMixture:
"""
name: Name_t
batch_size: int
init_rng: jax.random.PRNGKeyArray
init_rng: jax.Array
scale: float = 5.0
std: float = 0.5

Expand Down Expand Up @@ -110,7 +110,7 @@ def create_gaussian_mixture_samplers(
name_target: Name_t,
train_batch_size: int = 2048,
valid_batch_size: int = 2048,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
"""Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.
Expand Down
2 changes: 1 addition & 1 deletion src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def convertible(geom: geometry.Geometry) -> bool:

def to_low_rank(
self,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> "QuadraticProblem":
"""Convert geometries to low-rank.
Expand Down
6 changes: 3 additions & 3 deletions src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __call__( # noqa: D102
bar_prob: barycenter_problem.FreeBarycenterProblem,
bar_size: int = 100,
x_init: Optional[jnp.ndarray] = None,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> FreeBarycenterState:
# TODO(michalk8): no reason for iterations to be outside this class
rng = utils.default_prng_key(rng)
Expand All @@ -141,7 +141,7 @@ def init_state(
bar_prob: barycenter_problem.FreeBarycenterProblem,
bar_size: int,
x_init: Optional[jnp.ndarray] = None,
rng: Optional[jax.random.PRNGKeyArray] = None,
rng: Optional[jax.Array] = None,
) -> FreeBarycenterState:
"""Initialize the state of the Wasserstein barycenter iterations.
Expand Down Expand Up @@ -196,7 +196,7 @@ def output_from_state( # noqa: D102
def iterations(
solver: FreeWassersteinBarycenter, bar_size: int,
bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray,
rng: jax.random.PRNGKeyArray
rng: jax.Array
) -> FreeBarycenterState:
"""Jittable Wasserstein barycenter outer loop."""

Expand Down
Loading

0 comments on commit 0a8cf0c

Please sign in to comment.