diff --git a/src/lib.rs b/src/lib.rs index ef1109b..057d995 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,7 @@ //! - [`Beta`] distribution //! - [`Triangular`] distribution //! - Multivariate probability distributions -//! - [`Dirichlet`] distribution +//! - [`multi::Dirichlet`] distribution //! - [`UnitSphere`] distribution //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution @@ -100,8 +100,6 @@ pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; -#[cfg(feature = "alloc")] -pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; pub use self::fisher_f::{Error as FisherFError, FisherF}; pub use self::frechet::{Error as FrechetError, Frechet}; @@ -130,6 +128,8 @@ pub use student_t::StudentT; pub use num_traits; +#[cfg(feature = "alloc")] +pub mod multi; #[cfg(feature = "alloc")] pub mod weighted; @@ -188,7 +188,6 @@ mod beta; mod binomial; mod cauchy; mod chi_squared; -mod dirichlet; mod exponential; mod fisher_f; mod frechet; diff --git a/src/dirichlet.rs b/src/multi/dirichlet.rs similarity index 92% rename from src/dirichlet.rs rename to src/multi/dirichlet.rs index ac17fa2..558f64e 100644 --- a/src/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -10,7 +10,7 @@ //! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. #![cfg(feature = "alloc")] -use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; +use crate::{multi::MultiDistribution, Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; use core::fmt; use num_traits::{Float, NumCast}; use rand::Rng; @@ -68,26 +68,27 @@ where } } -impl Distribution<[F; N]> for DirichletFromGamma +impl MultiDistribution for DirichletFromGamma where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { let mut sum = F::zero(); - for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { + for (s, g) in output.iter_mut().zip(self.samplers.iter()) { *s = g.sample(rng); sum = sum + *s; } let invacc = F::one() / sum; - for s in samples.iter_mut() { + for s in output.iter_mut() { *s = *s * invacc; } - samples } } @@ -149,24 +150,25 @@ where } } -impl Distribution<[F; N]> for DirichletFromBeta +impl MultiDistribution for DirichletFromBeta where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { let mut acc = F::one(); - for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { + for (s, beta) in output.iter_mut().zip(self.samplers.iter()) { let beta_sample = beta.sample(rng); *s = acc * beta_sample; acc = acc * (F::one() - beta_sample); } - samples[N - 1] = acc; - samples + output[N - 1] = acc; } } @@ -208,7 +210,8 @@ where /// /// ``` /// use rand::prelude::*; -/// use rand_distr::Dirichlet; +/// use rand_distr::multi::Dirichlet; +/// use rand_distr::multi::MultiDistribution; /// /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); /// let samples = dirichlet.sample(&mut rand::rng()); @@ -259,7 +262,7 @@ impl fmt::Display for Error { "failed to create required Gamma distribution for Dirichlet distribution" } Error::FailedToCreateBeta => { - "failed to create required Beta distribition for Dirichlet distribution" + "failed to create required Beta distribution for Dirichlet distribution" } }) } @@ -315,17 +318,20 @@ where } } -impl Distribution<[F; N]> for Dirichlet +impl MultiDistribution for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { match &self.repr { - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_buf(rng, output), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_buf(rng, output), } } } @@ -403,7 +409,7 @@ mod test { let alpha_sum: f64 = alpha.iter().sum(); let expected_mean = alpha.map(|x| x / alpha_sum); for i in 0..N { - assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); + average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); } } diff --git a/src/multi/mod.rs b/src/multi/mod.rs new file mode 100644 index 0000000..b191440 --- /dev/null +++ b/src/multi/mod.rs @@ -0,0 +1,30 @@ +//! Contains Multi-dimensional distributions. +//! +//! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. +//! All multi-dimensional distributions implement `MultiDistribution` instead of the `Distribution` trait. + +use alloc::vec::Vec; +use rand::Rng; + +/// This trait allows to sample from a multi-dimensional distribution without extra allocations. +/// For convenience it also provides a `sample` method which returns the result as a `Vec`. +pub trait MultiDistribution { + /// returns the length of one sample (dimension of the distribution) + fn sample_len(&self) -> usize; + /// samples from the distribution and writes the result to `buf` + fn sample_to_buf(&self, rng: &mut R, buf: &mut [T]); + /// samples from the distribution and returns the result as a `Vec`, to avoid extra allocations use `sample_to_buf` + fn sample(&self, rng: &mut R) -> Vec + where + T: Default, + { + let mut buf = Vec::new(); + buf.resize_with(self.sample_len(), || T::default()); + self.sample_to_buf(rng, &mut buf); + buf + } +} + +pub use dirichlet::Dirichlet; + +mod dirichlet; diff --git a/tests/value_stability.rs b/tests/value_stability.rs index 2eb263e..002b8b4 100644 --- a/tests/value_stability.rs +++ b/tests/value_stability.rs @@ -500,13 +500,17 @@ fn weibull_stability() { #[cfg(feature = "alloc")] #[test] fn dirichlet_stability() { + use rand_distr::multi::MultiDistribution; + let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + multi::Dirichlet::new([1.0, 2.0, 3.0]) + .unwrap() + .sample(&mut rng), [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); assert_eq!( - rng.sample(Dirichlet::new([8.0; 5]).unwrap()), + multi::Dirichlet::new([8.0; 5]).unwrap().sample(&mut rng), [ 0.17684200044809556, 0.29915953935953055, @@ -517,7 +521,9 @@ fn dirichlet_stability() { ); // Test stability for the case where all alphas are less than 0.1. assert_eq!( - rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]) + .unwrap() + .sample(&mut rng), [ 0.00027580456855692104, 2.296135759821706e-20,