diff --git a/deeptrack/backend/array_api_compat_ext/__init__.py b/deeptrack/backend/array_api_compat_ext/__init__.py
index e67f772af..12e2309f1 100644
--- a/deeptrack/backend/array_api_compat_ext/__init__.py
+++ b/deeptrack/backend/array_api_compat_ext/__init__.py
@@ -1,8 +1,10 @@
-from array_api_compat import torch as apctorch
-from deeptrack.backend.array_api_compat_ext.torch import random
+from deeptrack import TORCH_AVAILABLE
+
+if TORCH_AVAILABLE:
+    from array_api_compat import torch as apctorch
+    from deeptrack.backend.array_api_compat_ext.torch import random
 
-# NumPy and PyTorch random functions are incompatible with each other.
-# The current array_api_compat module does not fix this incompatibility.
-# So we implement our own patch, which implements a numpy-compatible interface
-# for the torch random functions.
-apctorch.random = random
+    # NumPy and PyTorch random functions are incompatible with each other.
+    # The current array_api_compat module does not fix this incompatibility,
+    # so we apply our own patch, which provides a numpy-compatible interface
+    # to the torch random functions.
+    apctorch.random = random
diff --git a/deeptrack/backend/array_api_compat_ext/torch/__init__.py b/deeptrack/backend/array_api_compat_ext/torch/__init__.py
index bc9dc47f2..2d2c1dc3d 100644
--- a/deeptrack/backend/array_api_compat_ext/torch/__init__.py
+++ b/deeptrack/backend/array_api_compat_ext/torch/__init__.py
@@ -1,4 +1,7 @@
-from deeptrack.backend.array_api_compat_ext.torch import random
+from deeptrack import TORCH_AVAILABLE
+
+if TORCH_AVAILABLE:
+    from deeptrack.backend.array_api_compat_ext.torch import random
 
 __all__ = ["random"]
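The two guarded __init__ files above let `import deeptrack` succeed when PyTorch is absent, while still patching `apctorch.random` when it is present. For orientation, a minimal usage sketch of the symmetry the `apctorch.random = random` assignment is after, assuming torch and array_api_compat are installed and that importing the extension package (as below) is what applies the patch:

    import numpy as np

    import deeptrack.backend.array_api_compat_ext  # applies the patch above
    from array_api_compat import torch as apctorch

    a = np.random.uniform(0.0, 1.0, size=(3, 3))        # numpy ndarray
    b = apctorch.random.uniform(0.0, 1.0, size=(3, 3))  # torch.Tensor, same call syntax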
diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py
index 2ec864f17..5a75825d3 100644
--- a/deeptrack/backend/array_api_compat_ext/torch/random.py
+++ b/deeptrack/backend/array_api_compat_ext/torch/random.py
@@ -1,12 +1,39 @@
+"""Compatibility module for NumPy random functions.
+
+This module contains helper functions that use the same syntax as the
+equivalent numpy.random functions. All functions return torch.Tensor
+objects and accept optional `dtype` and `device` arguments, which
+default to `float32` (`int64` for `randint`) and `cpu`.
+
+Examples
+--------
+Sample the beta distribution on the GPU:
+
+>>> import torch
+>>> if torch.cuda.is_available():
+...     print(beta(1, 2, dtype=torch.float16, device="cuda"))
+tensor(0.3315, device='cuda:0', dtype=torch.float16)
+
+"""
+
 from __future__ import annotations
 
-import torch
+import numpy as np
+
+from deeptrack import TORCH_AVAILABLE
+
+if TORCH_AVAILABLE:
+    import torch
+else:
+    # The default arguments below (torch.float32, torch.device("cpu")) are
+    # evaluated at definition time, so this module is unusable without torch.
+    raise ImportError(
+        "deeptrack.backend.array_api_compat_ext.torch.random requires PyTorch"
+    )
 
 __all__ = [
     "rand",
     "random",
     "random_sample",
     "randn",
+    "standard_normal",
     "beta",
     "binomial",
     "choice",
@@ -16,93 +43,222 @@
     "uniform",
     "normal",
     "poisson",
+    "gamma",
+    "exponential",
+    "multivariate_normal",
+    "geometric",
+    "dirichlet",
 ]
 
 
-def rand(*args: int) -> torch.Tensor:
-    return torch.rand(*args)
+def rand(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    # tuple(args) also covers rand(), which returns a 0-d tensor,
+    # mirroring numpy's scalar draw.
+    return torch.rand(tuple(args), dtype=dtype, device=device)
 
 
-def random(size: tuple[int, ...] | None = None) -> torch.Tensor:
-    return torch.rand(*size) if size else torch.rand()
+def random(
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.rand(size if size is not None else (), dtype=dtype, device=device)
 
 
-def random_sample(size: tuple[int, ...] | None = None) -> torch.Tensor:
-    return torch.rand(*size) if size else torch.rand()
+def random_sample(
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    # numpy's random_sample is an alias of random; keep them in sync.
+    return random(size, dtype=dtype, device=device)
 
 
-def randn(*args: int) -> torch.Tensor:
-    return torch.randn(*args)
+def randn(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.randn(tuple(args), dtype=dtype, device=device)
+
+
+def standard_normal(
+    *args: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    return torch.randn(tuple(args), dtype=dtype, device=device)
 
 
 def beta(
     a: float,
     b: float,
-    size: tuple[int, ...] | None = None,
+    size: int | tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    raise NotImplementedError("the beta distribution is not implemented in torch")
+    # Sample with numpy and wrap the result in a torch tensor.
+    return torch.tensor(np.random.beta(a, b, size), dtype=dtype, device=device)
 
 
 def binomial(
     n: int,
     p: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.bernoulli(torch.full(size, p))
+    return torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device)
 
 
 def choice(
-    a: torch.Tensor,
+    a: torch.Tensor | np.ndarray,
     size: tuple[int, ...] | None = None,
     replace: bool = True,
     p: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    raise NotImplementedError(
-        "the choice function is not implemented in torch"
-    )
+    a_numpy = a.cpu().numpy() if isinstance(a, torch.Tensor) else np.asarray(a)
+    p_numpy = p.cpu().numpy() if p is not None else None
+    return torch.tensor(
+        np.random.choice(a_numpy, size=size, replace=replace, p=p_numpy),
+        dtype=dtype,
+        device=device,
+    )
 
 
 def multinomial(
     n: int,
     pvals: torch.Tensor,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.multinomial(pvals, n, size)
+    # torch.multinomial draws indices rather than counts and accepts neither
+    # `size` nor `dtype`/`device`; fall back to numpy for numpy semantics.
+    pvals_numpy = (
+        pvals.cpu().numpy() if isinstance(pvals, torch.Tensor) else np.asarray(pvals)
+    )
+    return torch.tensor(
+        np.random.multinomial(n, pvals_numpy, size), dtype=dtype, device=device
+    )
 
 
 def randint(
     low: int,
     high: int,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.int64,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.randint(low, high, size)
+    return torch.randint(
+        low, high, size if size is not None else (), dtype=dtype, device=device
+    )
 
 
 def shuffle(x: torch.Tensor) -> torch.Tensor:
-    return x[torch.randperm(x.shape[0])]
+    # Unlike numpy's in-place shuffle, this returns a shuffled copy.
+    return x[torch.randperm(x.shape[0], device=x.device)]
 
 
 def uniform(
     low: float,
     high: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.rand(*size) * (high - low) + low
+    values = torch.rand(size if size is not None else (), dtype=dtype, device=device)
+    return values * (high - low) + low
 
 
 def normal(
     loc: float,
     scale: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.randn(*size) * scale + loc
+    values = torch.randn(size if size is not None else (), dtype=dtype, device=device)
+    return values * scale + loc
 
 
 def poisson(
     lam: float,
     size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
 ) -> torch.Tensor:
-    return torch.poisson(torch.full(size, lam))
+    return torch.poisson(
+        torch.full(size if size is not None else (), lam, dtype=dtype, device=device)
+    )
+
+
+def gamma(
+    shape: float | torch.Tensor,
+    scale: float | torch.Tensor = 1.0,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = torch.device("cpu"),
+) -> torch.Tensor:
+    concentration = torch.as_tensor(shape, dtype=dtype, device=device)
+    # numpy parametrizes the gamma distribution with a scale, torch with a
+    # rate, which is the reciprocal of the scale.
+    rate = 1.0 / torch.as_tensor(scale, dtype=dtype, device=device)
+    if size is not None:
+        concentration = concentration.expand(size)
+        rate = rate.expand(size)
+    return torch.distributions.Gamma(concentration, rate).sample()
+
+
+def exponential(
+    scale: float | torch.Tensor = 1.0,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    rate = torch.as_tensor(1.0 / scale, dtype=dtype, device=device)
+    if size is None:
+        return torch.distributions.Exponential(rate).sample()
+    return torch.distributions.Exponential(rate).sample(size)
+
+
+def multivariate_normal(
+    mean: torch.Tensor,
+    cov: torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    mean = mean.to(dtype=dtype, device=device)
+    cov = cov.to(dtype=dtype, device=device)
+    distribution = torch.distributions.MultivariateNormal(
+        mean, covariance_matrix=cov
+    )
+    if size is None:
+        return distribution.sample()
+    return distribution.sample(size)
+
+
+def geometric(
+    p: float | torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    p = torch.as_tensor(p, dtype=torch.float32, device=device)
+    distribution = torch.distributions.Geometric(probs=p)
+    # numpy counts the trials up to and including the first success
+    # (support {1, 2, ...}), torch the failures before it (support
+    # {0, 1, ...}), hence the shift by one.
+    sample = distribution.sample() if size is None else distribution.sample(size)
+    return (sample + 1).to(dtype)
+
+
+def dirichlet(
+    alpha: torch.Tensor,
+    size: tuple[int, ...] | None = None,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cpu",
+) -> torch.Tensor:
+    alpha = alpha.to(dtype=dtype, device=device)
+    distribution = torch.distributions.Dirichlet(alpha)
+    if size is None:
+        return distribution.sample()
+    return distribution.sample(size)
 
 
 # TODO: implement the rest of the functions as they are needed
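The gamma and geometric additions above hinge on parametrization conversions that are easy to get wrong, so here is a self-contained sanity check using plain torch (nothing from this patch; the expected values follow from mean = shape * scale for gamma and mean = 1 / p for numpy-style geometric):

    import torch

    # numpy's (shape, scale) maps to torch's (concentration, rate), rate = 1/scale.
    g = torch.distributions.Gamma(torch.tensor(2.0), torch.tensor(1.0 / 3.0))
    print(g.sample((100_000,)).mean())  # approximately shape * scale = 6.0

    # torch's Geometric counts failures before the first success; numpy counts
    # trials, so shifting the torch samples by one recovers numpy's convention.
    geo = torch.distributions.Geometric(probs=torch.tensor(0.25))
    print((geo.sample((100_000,)) + 1).mean())  # approximately 1 / p = 4.0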
diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py
new file mode 100644
index 000000000..4f461956a
--- /dev/null
+++ b/deeptrack/tests/backend/test_random.py
@@ -0,0 +1,47 @@
+import unittest
+
+import numpy as np
+
+from deeptrack.backend import TORCH_AVAILABLE
+
+# TODO: Implement tests for these functions to start with:
+#     rand, random, random_sample, randn, beta, binomial, choice,
+#     multinomial, randint, shuffle, uniform, normal, poisson
+
+if TORCH_AVAILABLE:
+    import torch
+
+    # Import inside the guard: random.py needs torch at import time.
+    from deeptrack.backend.array_api_compat_ext.torch import random
+
+    class TestRandom(unittest.TestCase):
+
+        def test_rand(self):
+            shapes = [(2,), (3, 4)]
+            dtypes = [torch.float32, torch.float64]
+            devices = [torch.device("cpu"), "cpu"]
+
+            for shape, dtype, device in zip(shapes, dtypes, devices):
+                expected = np.random.rand(*shape)
+                generated = random.rand(*shape, dtype=dtype, device=device)
+                self.assertEqual(generated.shape, expected.shape)
+                self.assertEqual(generated.dtype, dtype)
+
+            # Both draws are uniform on [0, 1), so the means of 100 samples
+            # should agree within a loose tolerance.
+            a = random.rand(100, dtype=torch.float32, device="cpu")
+            b = np.random.rand(100)
+            self.assertAlmostEqual(a.mean().item(), np.mean(b), delta=0.15)
+
+
+if __name__ == "__main__":
+    unittest.main()
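Only `rand` is covered so far. One low-flake pattern for the distribution functions is to assert on moments with a generous tolerance rather than on exact values; a hypothetical sketch of such a test (the method name and tolerance are illustrative, not part of this patch):

    # Would live inside TestRandom, next to test_rand above.
    def test_gamma(self):
        samples = random.gamma(2.0, 3.0, size=(100_000,))
        self.assertEqual(samples.shape, (100_000,))
        # Gamma(shape=2, scale=3) has mean 6; the sample mean of 1e5 draws
        # has a standard error of about 0.013, so delta=0.2 is conservative.
        self.assertAlmostEqual(samples.mean().item(), 6.0, delta=0.2)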
+ "rand", + "random", + "random_sample", + "randn", + "beta", + "binomial", + "choice", + "multinomial", + "randint", + "shuffle", + "uniform", + "normal", + "poisson", + +""" +if TORCH_AVAILABLE: + import torch + class TestRandom(unittest.TestCase): + + def test_rand(self): + shapes = [(2, ), (3, 4)] + dtypes = [torch.float32, torch.float64] + devices = [torch.device("cpu"), "cpu"] + + for shape, dtype, device in zip(shapes, dtypes, devices): + + expected = np.random.rand(*shape) + generated = random.rand(*shape, dtype=dtype, device=device) + self.assertEqual(generated.shape, expected.shape) + self.assertEqual(generated.dtype, dtype) + + a = random.rand(100, dtype=torch.float32, device="cpu") + b = np.random.rand(100) + self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand + + if __name__ == "__main__": + unittest.main() diff --git a/deeptrack/utils.py b/deeptrack/utils.py index 10051b8e3..6b6449339 100644 --- a/deeptrack/utils.py +++ b/deeptrack/utils.py @@ -69,6 +69,7 @@ def safe_call( Examples -------- + Check if a method exists in an object: >>> from deeptrack.utils import hasmethod