
Torch/numpy random functions #353


Draft · wants to merge 53 commits into base: develop

Changes from all commits (53 commits)
2c81251
u
giovannivolpe May 29, 2025
abdfc82
Update _config.py
giovannivolpe May 29, 2025
b681268
Update core.py
giovannivolpe May 29, 2025
a737f83
Update test_core.py
giovannivolpe May 29, 2025
64fbdda
Update _config.py
giovannivolpe May 29, 2025
99d180d
Update core.py
giovannivolpe May 29, 2025
2f4d251
Update _config.py
giovannivolpe May 29, 2025
3ae3962
Update core.py
giovannivolpe May 29, 2025
740f962
Update core.py
giovannivolpe May 29, 2025
8c28553
Update core.py
giovannivolpe May 29, 2025
7c152bf
Update test_core.py
giovannivolpe May 29, 2025
aded038
Update core.py
giovannivolpe May 29, 2025
2fe8dd1
Update test_core.py
giovannivolpe May 29, 2025
203709d
Update core.py
giovannivolpe May 29, 2025
fa3ba34
Update core.py
giovannivolpe May 29, 2025
bf82050
Update core.py
giovannivolpe May 29, 2025
8411100
Update core.py
giovannivolpe May 29, 2025
530ad75
Update DTAT399F_backend._config.ipynb
giovannivolpe May 29, 2025
0113458
Update DTAT399A_backend.core.ipynb
giovannivolpe May 29, 2025
1801ec9
Update pint_definition.py
giovannivolpe May 29, 2025
ed99efa
u
giovannivolpe May 29, 2025
22264a2
Update utils.py
giovannivolpe May 29, 2025
ececad7
Update utils.py
giovannivolpe May 29, 2025
d9d9d00
Update utils.py
giovannivolpe May 29, 2025
28d29f0
Update utils.py
giovannivolpe May 29, 2025
2b92865
Update test_utils.py
giovannivolpe May 29, 2025
04d965b
Implemented torch.rand()
Pwhsky May 30, 2025
ab09fb4
Unit test for random.rand()
Pwhsky May 30, 2025
f8017df
typo
Pwhsky May 30, 2025
6494c44
added torch to requirements
Pwhsky May 30, 2025
b24081e
test_random: mean test and numpy comparison
Pwhsky May 30, 2025
c99c22e
Import torch
Pwhsky May 30, 2025
39ef3b4
Update test_random.py
Pwhsky May 30, 2025
0193668
implemented random.beta()
Pwhsky May 30, 2025
f8c14b5
Update test_random.py
Pwhsky Jun 5, 2025
ef3a53e
Merge branch 'develop' into AL/torch/random
Pwhsky Jun 27, 2025
6c1eca9
binomial added
Pwhsky Jun 28, 2025
2bab7cc
Several functions implemented.
Pwhsky Jun 30, 2025
f0b3179
Update test_random.py
Pwhsky Jun 30, 2025
6a646d5
Implemented Gamma distribution
Pwhsky Jul 1, 2025
5782d55
added exponential, geometric, multivar, dirichlet
Pwhsky Jul 1, 2025
1b7307f
syntax
Pwhsky Jul 1, 2025
cefbc93
docs
Pwhsky Jul 1, 2025
063b44a
formatting, linebreaks
Pwhsky Jul 1, 2025
9aa82b8
docs
Pwhsky Jul 1, 2025
0a6d561
removed torch from requirements
Pwhsky Jul 3, 2025
0e8aa7c
check torch availability in tests
Pwhsky Jul 3, 2025
296ca89
Update test_random.py
Pwhsky Jul 3, 2025
fa4f6cd
Update __init__.py
Pwhsky Jul 3, 2025
fa30c97
Update __init__.py
Pwhsky Jul 3, 2025
d2a313f
Update __init__.py
Pwhsky Jul 3, 2025
328a69b
Update random.py
Pwhsky Jul 3, 2025
a4fb70a
Update __init__.py
Pwhsky Jul 3, 2025
8 changes: 5 additions & 3 deletions deeptrack/backend/array_api_compat_ext/__init__.py
@@ -1,8 +1,10 @@
-from array_api_compat import torch as apctorch
-from deeptrack.backend.array_api_compat_ext.torch import random
+from deeptrack import TORCH_AVAILABLE
+if TORCH_AVAILABLE:
+    from array_api_compat import torch as apctorch
+    from deeptrack.backend.array_api_compat_ext.torch import random

 # NumPy and PyTorch random functions are incompatible with each other.
 # The current array_api_compat module does not fix this incompatibility.
 # So we implement our own patch, which implements a numpy-compatible interface
 # for the torch random functions.
-apctorch.random = random
+    apctorch.random = random
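The comment being indented here is the crux of the whole PR: `numpy.random` and torch draw from the same distributions through different names and signatures. A minimal illustration of the mismatch, using only stock NumPy and PyTorch APIs (this snippet is explanatory, not part of the PR):

import numpy as np
import torch

# NumPy: a scalar rate plus a `size` tuple.
counts_np = np.random.poisson(lam=3.0, size=(2, 5))

# torch: torch.poisson() instead takes a tensor of rates, one per output
# element, and has no `size` argument.
counts_torch = torch.poisson(torch.full((2, 5), 3.0))

# torch.distributions spells it differently again.
counts_dist = torch.distributions.Poisson(3.0).sample((2, 5))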
5 changes: 4 additions & 1 deletion deeptrack/backend/array_api_compat_ext/torch/__init__.py
@@ -1,4 +1,7 @@
-from deeptrack.backend.array_api_compat_ext.torch import random
+from deeptrack import TORCH_AVAILABLE
+
+if TORCH_AVAILABLE:
+    from deeptrack.backend.array_api_compat_ext.torch import random


 __all__ = ["random"]
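The same `TORCH_AVAILABLE` guard now gates both `__init__` files, so importing deeptrack without torch installed no longer fails. A hypothetical consumer-side sketch (the names come from this PR; the if/else shape is illustrative, not code from the repository):

from deeptrack import TORCH_AVAILABLE

if TORCH_AVAILABLE:
    # Resolves only when torch is installed; apctorch.random is the
    # numpy-compatible patch applied in the file above.
    from array_api_compat import torch as apctorch
    samples = apctorch.random.uniform(0.0, 1.0, size=(4,))
else:
    import numpy as np
    samples = np.random.uniform(0.0, 1.0, size=(4,))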
202 changes: 179 additions & 23 deletions deeptrack/backend/array_api_compat_ext/torch/random.py
@@ -1,12 +1,39 @@
"""Compatibility module for Numpy functions

This module contains helper functions that use the same syntax as
the equivalent numpy.random functions. All functions return
torch.Tensors when used and accept optional `dtype` and `device`
arguments that default to `float32` and `cpu`.


Examples
--------
Sample the `beta` distribution:

>>> from torch import cuda, float16

>>> if cuda.is_available():
... print(beta(1, 2, dtype=torch.float16, device="cuda"))

tensor(0.3315, device='cuda:0', dtype=torch.float16)


"""

from __future__ import annotations
from deeptrack import TORCH_AVAILABLE

import torch
import numpy as np

if TORCH_AVAILABLE:
import torch

 __all__ = [
     "rand",
     "random",
     "random_sample",
     "randn",
+    "standard_normal",
     "beta",
     "binomial",
     "choice",
@@ -16,93 +43,222 @@
"uniform",
"normal",
"poisson",
"gamma",
]


-def rand(*args: int) -> torch.Tensor:
-    return torch.rand(*args)
+def rand(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.rand(*args, dtype=dtype, device=device)


-def random(size: tuple[int, ...] | None = None) -> torch.Tensor:
-    return torch.rand(*size) if size else torch.rand()
+def random(
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    # With no size, return a 0-dim tensor, the torch analogue of
+    # NumPy's scalar return value.
+    return (
+        torch.rand(*size, dtype=dtype, device=device)
+        if size else torch.rand((), dtype=dtype, device=device)
+    )


-def random_sample(size: tuple[int, ...] | None = None) -> torch.Tensor:
-    return torch.rand(*size) if size else torch.rand()
+def random_sample(
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return (
+        torch.rand(*size, dtype=dtype, device=device)
+        if size else torch.rand((), dtype=dtype, device=device)
+    )


-def randn(*args: int) -> torch.Tensor:
-    return torch.randn(*args)
+def randn(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.randn(*args, dtype=dtype, device=device)
+
+
+def standard_normal(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.randn(*args, dtype=dtype, device=device)


 def beta(
     a: float,
     b: float,
-    size: tuple[int, ...] | None = None,
+    size: int | tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    raise NotImplementedError("the beta distribution is not implemented in torch")
+    # torch has no np.random.beta equivalent, so sample with NumPy
+    # and convert.
+    return torch.tensor(
+        np.random.beta(a, b, size), dtype=dtype, device=device
+    )


 def binomial(
     n: int,
     p: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.bernoulli(torch.full(size, p))
+    # np.random.binomial counts successes over n trials; torch.bernoulli
+    # only covers the n == 1 case, so fall back on NumPy.
+    return torch.tensor(
+        np.random.binomial(n, p, size), dtype=dtype, device=device
+    )


 def choice(
-    a: torch.Tensor,
+    a: torch.Tensor | np.ndarray,
     size: tuple[int, ...] | None = None,
     replace: bool = True,
     p: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    raise NotImplementedError(
-        "the choice function is not implemented in torch"
-    )
+    a_numpy = a.cpu().numpy() if isinstance(a, torch.Tensor) else np.asarray(a)
+    p_numpy = p.cpu().numpy() if p is not None else None
+    return torch.tensor(
+        np.random.choice(
+            a_numpy, size=size, replace=replace, p=p_numpy
+        ), dtype=dtype, device=device
+    )


 def multinomial(
     n: int,
     pvals: torch.Tensor,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.multinomial(pvals, n, size)
+    # torch.multinomial() samples indices rather than per-category counts
+    # and takes no size/dtype/device here, so fall back on NumPy.
+    pvals_numpy = (
+        pvals.cpu().numpy() if isinstance(pvals, torch.Tensor) else pvals
+    )
+    return torch.tensor(
+        np.random.multinomial(n, pvals_numpy, size), dtype=dtype, device=device
+    )


 def randint(
     low: int,
     high: int,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.int64,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.randint(low, high, size)
+    # NumPy's randint returns integers, hence the integer default dtype.
+    return torch.randint(
+        low, high, size if size is not None else (), dtype=dtype, device=device
+    )


 def shuffle(x: torch.Tensor) -> torch.Tensor:
-    return x[torch.randperm(x.shape[0])]
+    return x[torch.randperm(x.shape[0], device=x.device)]


 def uniform(
     low: float,
     high: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.rand(*size) * (high - low) + low
+    size = size if size is not None else ()
+    return torch.rand(*size, dtype=dtype, device=device) * (high - low) + low


 def normal(
     loc: float,
     scale: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.randn(*size) * scale + loc
+    size = size if size is not None else ()
+    return torch.randn(*size, dtype=dtype, device=device) * scale + loc


 def poisson(
     lam: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.poisson(torch.full(size, lam))
+    size = size if size is not None else ()
+    return torch.poisson(torch.full(size, lam, dtype=dtype, device=device))


+def gamma(
+    shape: float | torch.Tensor,
+    scale: float | torch.Tensor = 1.0,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    # NumPy parametrizes gamma(shape, scale); torch.distributions.Gamma
+    # takes (concentration, rate) with rate = 1 / scale.
+    concentration = torch.as_tensor(shape, dtype=dtype, device=device)
+    rate = 1.0 / torch.as_tensor(scale, dtype=dtype, device=device)
+    if size is not None:
+        concentration = concentration.expand(size)
+        rate = rate.expand(size)
+    return torch.distributions.Gamma(concentration, rate).sample()


+def exponential(
+    scale: float | torch.Tensor = 1.0,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    rate = torch.as_tensor(1.0 / scale, dtype=dtype, device=device)
+    if size is None:
+        return torch.distributions.Exponential(rate).sample()
+    return torch.distributions.Exponential(rate).sample(size)


+def multivariate_normal(
+    mean: torch.Tensor,
+    cov: torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    mean = mean.to(dtype=dtype, device=device)
+    cov = cov.to(dtype=dtype, device=device)
+    if size is None:
+        return torch.distributions.MultivariateNormal(
+            mean, covariance_matrix=cov
+        ).sample()
+    return torch.distributions.MultivariateNormal(
+        mean, covariance_matrix=cov
+    ).sample(size)


+def geometric(
+    p: float | torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    # Sample in float32, then cast to the requested dtype.
+    p = torch.as_tensor(p, dtype=torch.float32, device=device)
+    # np.random.geometric counts trials up to and including the first
+    # success (support {1, 2, ...}); torch.distributions.Geometric counts
+    # failures before it (support {0, 1, ...}), so shift by one.
+    if size is None:
+        return (torch.distributions.Geometric(probs=p).sample() + 1).to(dtype)
+    return (torch.distributions.Geometric(probs=p).sample(size) + 1).to(dtype)


+def dirichlet(
+    alpha: torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    alpha = alpha.to(dtype=dtype, device=device)
+    if size is None:
+        return torch.distributions.Dirichlet(alpha).sample()
+    return torch.distributions.Dirichlet(alpha).sample(size)
+
+
+# TODO: implement the rest of the functions as they are needed.
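Taken together, the wrappers let the torch backend answer numpy.random-style calls. A rough usage sketch, assuming torch is installed and the module resolves at the path this PR adds (the shapes and parameters are arbitrary, not from the PR):

import numpy as np
import torch

from deeptrack.backend.array_api_compat_ext.torch import random

# Same call shape as numpy.random.normal(), but a torch.Tensor comes back.
a = random.normal(0.0, 2.0, size=(3, 4))
b = np.random.normal(0.0, 2.0, size=(3, 4))
assert isinstance(a, torch.Tensor) and tuple(a.shape) == b.shape

# dtype and device are the torch-only extras every wrapper accepts.
g = random.gamma(2.0, 0.5, size=(10,), dtype=torch.float64)
assert g.dtype == torch.float64 and g.device.type == "cpu"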
47 changes: 47 additions & 0 deletions deeptrack/tests/backend/test_random.py
@@ -0,0 +1,47 @@
+import unittest
+
+import numpy as np
+
+from deeptrack.backend import TORCH_AVAILABLE
+from deeptrack.backend.array_api_compat_ext.torch import random
+
+
+"""
+TODO: Implement tests for all of these functions to start with.
+    "rand",
+    "random",
+    "random_sample",
+    "randn",
+    "beta",
+    "binomial",
+    "choice",
+    "multinomial",
+    "randint",
+    "shuffle",
+    "uniform",
+    "normal",
+    "poisson",
+"""
+
+if TORCH_AVAILABLE:
+    import torch
+
+    class TestRandom(unittest.TestCase):
+
+        def test_rand(self):
+            shapes = [(2,), (3, 4)]
+            dtypes = [torch.float32, torch.float64]
+            devices = [torch.device("cpu"), "cpu"]
+
+            for shape, dtype, device in zip(shapes, dtypes, devices):
+                expected = np.random.rand(*shape)
+                generated = random.rand(*shape, dtype=dtype, device=device)
+                self.assertEqual(generated.shape, expected.shape)
+                self.assertEqual(generated.dtype, dtype)
+
+            # Both means should be near 0.5; delta=0.2 is about five
+            # standard errors for 100 draws, so loose but meaningful.
+            a = random.rand(100, dtype=torch.float32, device="cpu")
+            b = np.random.rand(100)
+            self.assertAlmostEqual(a.mean().item(), b.mean(), delta=0.2)
+
+
+if __name__ == "__main__":
+    unittest.main()
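The TODO above still covers most of the list. A sketch of what a companion test for `normal` could look like in the same style (the sample size and tolerances are arbitrary choices, not from the PR):

if TORCH_AVAILABLE:

    class TestNormal(unittest.TestCase):

        def test_normal(self):
            loc, scale = 2.0, 0.5
            sample = random.normal(loc, scale, size=(10_000,))

            self.assertEqual(sample.shape, (10_000,))
            self.assertEqual(sample.dtype, torch.float32)

            # With 10k draws the sample moments should sit within a few
            # standard errors of the distribution parameters.
            self.assertAlmostEqual(sample.mean().item(), loc, delta=0.05)
            self.assertAlmostEqual(sample.std().item(), scale, delta=0.05)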
1 change: 1 addition & 0 deletions deeptrack/utils.py
@@ -69,6 +69,7 @@ def safe_call(

     Examples
     --------
+
     Check if a method exists in an object:
 
     >>> from deeptrack.utils import hasmethod