From 423eff9aa9cc4e59d19519d58856a27eff3536b3 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Mon, 23 Oct 2023 17:34:52 -0400 Subject: [PATCH] nn.dataset.GaussianMixture: minor fix to the docstring and standard deviation parameter name (#445) --- src/ott/problems/nn/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index eaea1cae1..7c6f94ebd 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -52,14 +52,14 @@ class GaussianMixture: batch_size: batch size of the samples init_rng: initial PRNG key - scale: scale of the individual Gaussian samples - variance: the variance of the individual Gaussian samples + scale: scale of the Gaussian means + std: the standard deviation of the individual Gaussian samples """ name: Name_t batch_size: int init_rng: jax.random.PRNGKeyArray scale: float = 5.0 - variance: float = 0.5 + std: float = 0.5 def __post_init__(self): gaussian_centers = { @@ -101,7 +101,7 @@ def _create_sample_generators(self) -> Iterator[jnp.array]: rng1, rng2, rng = jax.random.split(rng, 3) means = jax.random.choice(rng1, self.centers, (self.batch_size,)) normal_samples = jax.random.normal(rng2, (self.batch_size, 2)) - samples = self.scale * means + self.variance ** 2 * normal_samples + samples = self.scale * means + (self.std ** 2) * normal_samples yield samples