From d51af7b127d280a46351bb9a596d2f369f545dc6 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Wed, 11 Jun 2025 20:01:04 +0200 Subject: [PATCH 01/11] =?UTF-8?q?=F0=9F=90=9B=20fix(stats.py)=20correct=20?= =?UTF-8?q?shape=20condition=20in=20add=5Fsamples?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plaid/utils/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index 9685328..362d6a9 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -82,7 +82,7 @@ def add_samples(self, x: np.ndarray) -> None: """ if x.ndim == 1: if self.min is not None: - if self.min.size == 1: + if self.min.shape[1] == 1: # n_samples x 1 x = x.reshape((-1, 1)) else: From 1f4dd101d2c949a0d5ef6245cffc43e10e01d88a Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Wed, 11 Jun 2025 20:04:56 +0200 Subject: [PATCH 02/11] fix(test_sample.py) update after ruff formatting --- tests/containers/test_sample.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/containers/test_sample.py b/tests/containers/test_sample.py index 7ae1221..ca50ed3 100644 --- a/tests/containers/test_sample.py +++ b/tests/containers/test_sample.py @@ -549,7 +549,9 @@ def test_get_zone_names(self, sample, base_name): base_name=base_name, ) assert sample.get_zone_names(base_name) == ["zone_name_1", "zone_name_2"] - assert sorted(sample.get_zone_names(base_name, unique = True)) == sorted(["zone_name_1", "zone_name_2"]) + assert sorted(sample.get_zone_names(base_name, unique=True)) == sorted( + ["zone_name_1", "zone_name_2"] + ) assert sample.get_zone_names(base_name, full_path=True) == [ f"{base_name}/zone_name_1", f"{base_name}/zone_name_2", From d7544b951c273ffca4db5d2c4be686e5a9e93e69 Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Tue, 17 Jun 2025 23:01:19 +0200 Subject: [PATCH 03/11] (stats) improve OnlineStatistics and Stats classes, better docstrings, better type checking, better modularity --- src/plaid/utils/stats.py | 238 ++++++++++++++++++++++++++++++--------- 1 file changed, 187 insertions(+), 51 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index 362d6a9..8febb5c 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -10,7 +10,7 @@ # %% Imports import logging -from typing import Union +from typing import List, Union import numpy as np @@ -60,7 +60,18 @@ def aggregate_stats( class OnlineStatistics(object): - """OnlineStatistics is a class for computing online statistics (e.g., min, max, mean, variance, and standard deviation) of numpy arrays.""" + """OnlineStatistics is a class for computing online statistics of numpy arrays. + + This class computes running statistics (min, max, mean, variance, std) for streaming data + without storing all samples in memory. + + Example: + >>> stats = OnlineStatistics() + >>> stats.add_samples(np.array([[1, 2], [3, 4]])) + >>> stats.add_samples(np.array([[5, 6]])) + >>> print(stats.get_stats()['mean']) + [[3. 4.]] + """ def __init__(self) -> None: """Initialize an empty OnlineStatistics object.""" @@ -70,6 +81,7 @@ def __init__(self) -> None: self.mean: np.ndarray = None self.var: np.ndarray = None self.std: np.ndarray = None + self._n_features: int = None # Add feature dimension tracking def add_samples(self, x: np.ndarray) -> None: """Add samples to compute statistics for. 
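# [editor note, not part of the patch] A minimal numpy self-check of the
# streaming update rule that OnlineStatistics and aggregate_stats rely on,
# assuming population variance (ddof=0); batch shapes and the rng seed are
# illustrative only. Merging per-batch (size, mean, var) triples this way
# reproduces the statistics of the concatenated data exactly.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(100, 3))  # first batch
b = rng.normal(size=(50, 3))   # second batch

n1, n2 = len(a), len(b)
m1, m2 = a.mean(axis=0), b.mean(axis=0)
v1, v2 = a.var(axis=0), b.var(axis=0)

# weighted mean, then variance via the law of total variance
n = n1 + n2
mean = (n1 * m1 + n2 * m2) / n
var = (n1 * (v1 + (mean - m1) ** 2) + n2 * (v2 + (mean - m2) ** 2)) / n

full = np.concatenate((a, b), axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))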
@@ -79,24 +91,37 @@ def add_samples(self, x: np.ndarray) -> None: Raises: ShapeError: Raised when there is an inconsistency in the shape of the input array. + ValueError: Raised when input contains NaN or Inf values. """ + # Validate input + if not isinstance(x, np.ndarray): + raise TypeError("Input must be a numpy array") + + if np.any(~np.isfinite(x)): + raise ValueError("Input contains NaN or Inf values") + + # Handle 1D arrays if x.ndim == 1: if self.min is not None: if self.min.shape[1] == 1: - # n_samples x 1 x = x.reshape((-1, 1)) else: - # 1 x n_features x = x.reshape((1, -1)) - else: # pragma: no cover - raise ShapeError( - "can't determine if input array with ndim=1, is 1 x n_features or n_samples x 1" - ) + else: + x = x.reshape((-1, 1)) # Default to column vector + + # Handle n-dimensional arrays elif x.ndim > 2: - # suppose last dim is features dim, all previous dims are space - # dims and are aggregated x = x.reshape((-1, x.shape[-1])) + # Validate feature dimension + if self._n_features is not None and x.shape[1] != self._n_features: + raise ShapeError( + f"Input has {x.shape[1]} features but expected {self._n_features}" + ) + else: + self._n_features = x.shape[1] + added_n_samples = len(x) added_min = np.min(x, axis=0, keepdims=True) added_max = np.max(x, axis=0, keepdims=True) @@ -166,12 +191,19 @@ def get_stats(self) -> dict[str, np.ndarray]: } -class Stats(object): - """Stats is a class for aggregating and computing statistics for datasets.""" +class Stats: + """Class for aggregating and computing statistics across datasets. + + The Stats class processes both scalar and field data from samples or datasets, + computing running statistics like min, max, mean, variance and standard deviation. + + Attributes: + _stats (dict[str, OnlineStatistics]): Dictionary mapping data identifiers to their statistics + """ def __init__(self): """Initialize an empty Stats object.""" - self._stats = {} + self._stats: dict[str, OnlineStatistics] = {} def add_dataset(self, dset: Dataset) -> None: """Add a dataset to compute statistics for. @@ -181,23 +213,32 @@ def add_dataset(self, dset: Dataset) -> None: """ self.add_samples(dset) - def add_samples(self, samples: Union[list[Sample], Dataset]) -> None: - """Add samples (or a dataset) to compute statistics for. + def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: + """Add samples or a dataset to compute statistics for. + + Processes both scalar and field data from the provided samples, + computing running statistics for each data identifier. Args: - samples (Union[list[Sample],Dataset]): The list of samples or a dataset to add. 
+ samples (Union[List[Sample], Dataset]): List of samples or dataset to process + + Raises: + TypeError: If samples is not a List[Sample] or Dataset + ValueError: If a sample contains invalid data """ - # ---# Aggregate - new_data = {} + # Input validation + if not isinstance(samples, (list, Dataset)): + raise TypeError("samples must be a List[Sample] or Dataset") + + # Process each sample + new_data: dict[str, list] = {} + for sample in samples: - # ---# Scalars - for s_name in sample.get_scalar_names(): - if s_name not in new_data: - new_data[s_name] = [] - new_data[s_name].append(sample.get_scalar(s_name)) + # Process scalars + self._process_scalar_data(sample, new_data) - # ---# Fields - # TODO + # Process fields + self._process_field_data(sample, new_data) # ---# Categorical # TODO @@ -208,42 +249,137 @@ def add_samples(self, samples: Union[list[Sample], Dataset]) -> None: # ---# TemporalSupport # TODO - # ---# Process - for name in new_data: - # new_shapes = [value.shape for value in new_data[name] if value.shape!=new_data[name][0].shape] - # has_same_shape = (len(new_shapes)==0) - has_same_shape = True + # Update statistics + self._update_statistics(new_data) + + def get_stats( + self, identifiers: list[str] = None + ) -> dict[str, dict[str, np.ndarray]]: + """Get computed statistics for specified data identifiers. - if has_same_shape: - new_data[name] = np.array(new_data[name]) - else: # pragma: no cover ### remove "no cover" when "has_same_shape = True" is no longer used - if name in self._stats: - self._stats[name].flatten_array() - new_data[name] = np.concatenate( - [np.ravel(value) for value in new_data[name]] - ) + Args: + identifiers (list[str], optional): List of data identifiers to retrieve. + If None, returns statistics for all identifiers. - if new_data[name].ndim == 1: - new_data[name] = new_data[name].reshape((-1, 1)) + Returns: + dict[str, dict[str, np.ndarray]]: Dictionary mapping identifiers to their statistics + """ + if identifiers is None: + identifiers = self.get_available_statistics() + stats = {} + for identifier in identifiers: + if identifier in self._stats: + stats[identifier] = {} + for stat_name, stat_value in ( + self._stats[identifier].get_stats().items() + ): + stats[identifier][stat_name] = np.squeeze(stat_value) + + return stats + + def get_available_statistics(self) -> list[str]: + """Get list of data identifiers with computed statistics. + + Returns: + list[str]: List of data identifiers + """ + return sorted(self._stats.keys()) + + def clear_statistics(self) -> None: + """Clear all computed statistics.""" + self._stats.clear() + + def merge_stats(self, other: "Stats") -> None: + """Merge statistics from another Stats object. + + Args: + other (Stats): Stats object to merge with + """ + for name, stats in other._stats.items(): if name not in self._stats: self._stats[name] = OnlineStatistics() + for time_series in stats.get_stats().values(): + self._stats[name].add_samples(time_series) - self._stats[name].add_samples(new_data[name]) + def _process_scalar_data(self, sample: Sample, data_dict: dict[str, list]) -> None: + """Process scalar data from a sample. - def get_stats(self) -> dict[str, dict[str, np.ndarray]]: - """Get computed statistics for different data identifiers. 
+ Args: + sample (Sample): Sample containing scalar data + data_dict (dict[str, list]): Dictionary to store processed data + """ + for name in sample.get_scalar_names(): + if name not in data_dict: + data_dict[name] = [] + value = sample.get_scalar(name) + if value is not None: + data_dict[name].append(value) - Returns: - dict[str,dict[str,np.ndarray]]: A dictionary containing computed statistics for different data identifiers. + def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> None: + """Process field data from a sample. + + Args: + sample (Sample): Sample containing field data + data_dict (dict[str, list]): Dictionary to store processed data """ - stats = {} - for identifier in self._stats: - stats[identifier] = {} - for stat_name, stat_value in self._stats[identifier].get_stats().items(): - stats[identifier][stat_name] = np.squeeze(stat_value) + for time in sample.get_all_mesh_times(): + for base in sample.get_base_names(time=time): + for zone in sample.get_zone_names(base_name=base, time=time): + for field_name in sample.get_field_names( + zone_name=zone, base_name=base, time=time + ): + if field_name not in data_dict: + data_dict[field_name] = [] + field = sample.get_field( + field_name, zone_name=zone, base_name=base, time=time + ) + if field is not None: + data_dict[field_name].append(field) + + def _update_statistics(self, new_data: dict[str, list]) -> None: + """Update running statistics with new data. - return stats + Args: + new_data (dict[str, list]): Dictionary containing new data to process + """ + for name, values in new_data.items(): + if len(values) > 0: + if name not in self._stats: + self._stats[name] = OnlineStatistics() + + # Convert to numpy array and reshape if needed + data = np.asarray(values) + if data.ndim == 1: + data = data.reshape(-1, 1) + + try: + self._stats[name].add_samples(data) + except Exception as e: + logging.warning(f"Failed to process {name}: {str(e)}") + + # # old version of _update_statistics logic + # for name in new_data: + # # new_shapes = [value.shape for value in new_data[name] if value.shape!=new_data[name][0].shape] + # # has_same_shape = (len(new_shapes)==0) + # has_same_shape = True + + # if has_same_shape: + # new_data[name] = np.array(new_data[name]) + # else: # pragma: no cover ### remove "no cover" when "has_same_shape = True" is no longer used + # if name in self._stats: + # self._stats[name].flatten_array() + # new_data[name] = np.concatenate( + # [np.ravel(value) for value in new_data[name]] + # ) + + # if new_data[name].ndim == 1: + # new_data[name] = new_data[name].reshape((-1, 1)) + + # if name not in self._stats: + # self._stats[name] = OnlineStatistics() + + # self._stats[name].add_samples(new_data[name]) # TODO : FAIRE DEUX FONCTIONS : # - compute_stats(samples) -> stats From 150b55dce053216ccd38fc545ac9d597aff3faa8 Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Sat, 28 Jun 2025 00:19:44 +0200 Subject: [PATCH 04/11] (stats) implement logic to check field shapes and flatten if incoherent between samples + improve tests for type checking + tests for new methods --- examples/utils/stats_example.py | 53 +++++-- src/plaid/utils/stats.py | 239 +++++++++++++++++++++++--------- tests/utils/test_stats.py | 179 ++++++++++++++++++++++++ 3 files changed, 394 insertions(+), 77 deletions(-) diff --git a/examples/utils/stats_example.py b/examples/utils/stats_example.py index a270102..1155f17 100644 --- a/examples/utils/stats_example.py +++ b/examples/utils/stats_example.py @@ -33,8 +33,8 @@ # %% def 
sprint(stats: dict): print("Stats:") - for k in stats: - print(" - {} -> {}".format(k, stats[k])) + for k,v in stats.items(): + print(f" - {k:>9} -> {v}") # %% [markdown] @@ -58,7 +58,8 @@ def sprint(stats: dict): # %% # First batch of samples first_batch_samples = 3.0 * np.random.randn(100, 3) + 10.0 -print(f"{first_batch_samples.shape = }") +print() +print(f"=== {first_batch_samples.shape = }") stats_computer.add_samples(first_batch_samples) stats = stats_computer.get_stats() @@ -67,7 +68,8 @@ def sprint(stats: dict): # %% second_batch_samples = 10.0 * np.random.randn(1000, 3) - 1.0 -print(f"{second_batch_samples.shape = }") +print() +print(f"=== {second_batch_samples.shape = }") stats_computer.add_samples(second_batch_samples) stats = stats_computer.get_stats() @@ -79,7 +81,8 @@ def sprint(stats: dict): # %% total_samples = np.concatenate((first_batch_samples, second_batch_samples), axis=0) -print(f"{total_samples.shape = }") +print() +print(f"=== {total_samples.shape = }") new_stats_computer = OnlineStatistics() new_stats_computer.add_samples(total_samples) @@ -110,14 +113,23 @@ def sprint(stats: dict): nb_samples = 11 samples = [Sample() for _ in range(nb_samples)] -spatial_shape_max = 20 +spatial_shape_max = 5 # for sample in samples: sample.add_scalar("test_scalar", np.random.randn()) - sample.init_base(2, 3, "test_base") + sample.add_scalar("test_ND_scalar", np.random.randn(3)) + sample.init_base(2, 3,) zone_shape = np.array([0, 0, 0]) - sample.init_zone(zone_shape, zone_name="test_zone") + sample.init_zone(zone_shape) sample.add_field("test_field", np.random.randn(spatial_shape_max)) + sample.init_zone(zone_shape, zone_name="test_zone_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2') + sample.init_base(2, 3, "test_base_2") + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape, zone_name="test_zone_1", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), base_name="test_base_2") + sample.init_zone(zone_shape, zone_name="test_zone_2", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2', base_name="test_base_2") stats.add_samples(samples) @@ -132,21 +144,38 @@ def sprint(stats: dict): # ### Feed Stats with more Samples # %% -nb_samples = 11 -spatial_shape_max = 20 +nb_samples = 13 +# spatial_shape_max = 20 samples = [Sample() for _ in range(nb_samples)] for sample in samples: sample.add_scalar("test_scalar", np.random.randn()) + sample.init_base(2, 3,) + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape) + sample.add_field("test_field", np.random.randn(spatial_shape_max)) + sample.init_zone(zone_shape, zone_name="test_zone_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2') + sample.init_base(2, 3, "test_base_2") + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape, zone_name="test_zone_1", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), base_name="test_base_2") + sample.init_zone(zone_shape, zone_name="test_zone_2", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2', base_name="test_base_2") + sample.init_base(2, 3, "test_base") zone_shape = np.array([0, 0, 0]) - sample.init_zone(zone_shape, zone_name="test_zone") - sample.add_field("test_field_same_size", np.random.randn(7)) + sample.init_zone(zone_shape, zone_name="test_zone", 
base_name="test_base") + sample.add_field("test_field_same_size", np.random.randn(spatial_shape_max), zone_name="test_zone", base_name="test_base") sample.add_field( "test_field", np.random.randn(np.random.randint(spatial_shape_max // 2, spatial_shape_max)), + zone_name="test_zone", base_name="test_base" ) +for sample in samples: + print(sample) +print(f"{len(samples)=}") stats.add_samples(samples) # %% [markdown] diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index 8febb5c..c8f2a41 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -9,14 +9,19 @@ # %% Imports +import copy import logging from typing import List, Union +try: # pragma: no cover + from typing import Self +except ImportError: # pragma: no cover + from typing import Any as Self + import numpy as np from plaid.containers.dataset import Dataset from plaid.containers.sample import Sample -from plaid.utils.base import ShapeError logger = logging.getLogger(__name__) logging.basicConfig( @@ -34,12 +39,12 @@ def aggregate_stats( This function calculates aggregated statistics, such as the total number of samples, mean, and variance, by taking into account the statistics computed for each batch of data. - cf: https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques) + cf: [Variance from (cardinal,mean,variance) of several statistical series](https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques)#Formules) Args: - sizes (np.ndarray): An array containing the sizes (number of samples) of each batch. - means (np.ndarray): An array containing the means of each batch. - vars (np.ndarray): An array containing the variances of each batch. + sizes (np.ndarray): An array containing the sizes (number of samples) of each batch. Expect shape (n_batches,1). + means (np.ndarray): An array containing the means of each batch. Expect shape (n_batches, n_features). + vars (np.ndarray): An array containing the variances of each batch. Expect shape (n_batches, n_features). Returns: tuple[np.ndarray,np.ndarray,np.ndarray]: A tuple containing the aggregated statistics in the following order: @@ -47,10 +52,15 @@ def aggregate_stats( - Weighted mean calculated from the batch means. - Weighted variance calculated from the batch variances, considering the means. """ - total_n_samples = np.sum(sizes, keepdims=True) - total_mean = np.sum(sizes * means, keepdims=True) / total_n_samples + assert sizes.ndim == 1 + assert means.ndim == 2 + assert len(sizes) == len(means) + assert means.shape == vars.shape + sizes = sizes.reshape((-1, 1)) + total_n_samples = np.sum(sizes) + total_mean = np.sum(sizes * means, axis=0, keepdims=True) / total_n_samples total_var = ( - np.sum(sizes * (vars + (total_mean - means) ** 2), keepdims=True) + np.sum(sizes * (vars + (total_mean - means) ** 2), axis=0, keepdims=True) / total_n_samples ) return total_n_samples, total_mean, total_var @@ -76,18 +86,20 @@ class OnlineStatistics(object): def __init__(self) -> None: """Initialize an empty OnlineStatistics object.""" self.n_samples: int = 0 + self.n_features: int = None + self.n_points: int = None self.min: np.ndarray = None self.max: np.ndarray = None self.mean: np.ndarray = None self.var: np.ndarray = None self.std: np.ndarray = None - self._n_features: int = None # Add feature dimension tracking - def add_samples(self, x: np.ndarray) -> None: + def add_samples(self, x: np.ndarray, n_samples: int = None) -> None: """Add samples to compute statistics for. Args: - x (np.ndarray): The input numpy array containing samples data. 
+ x (np.ndarray): The input numpy array containing samples data. Expect 2D arrays with shape (n_samples, n_features). + n_samples (int, optional): The number of samples in the input array. If not provided, it will be inferred from the shape of `x`. Use this argument when the input array has already been flattened because of shape inconsistencies. Raises: ShapeError: Raised when there is an inconsistency in the shape of the input array. @@ -112,17 +124,23 @@ def add_samples(self, x: np.ndarray) -> None: # Handle n-dimensional arrays elif x.ndim > 2: + # if we have array of shape (n_samples, n_points, n_features) + # it will be reshaped to (n_samples * n_points, n_features) x = x.reshape((-1, x.shape[-1])) - # Validate feature dimension - if self._n_features is not None and x.shape[1] != self._n_features: - raise ShapeError( - f"Input has {x.shape[1]} features but expected {self._n_features}" - ) - else: - self._n_features = x.shape[1] + if self.n_features is None: + self.n_features = x.shape[1] + + if x.shape[1] != self.n_features: + # it means that stats where previously on a per-point mode, + # but it is no longer possible as the new added samples have a different shape + # so we need to shift the stats to a per-sample mode, and then flatten the stats array + self.flatten_array() + n_samples = x.shape[0] + x = x.reshape((-1, 1)) - added_n_samples = len(x) + added_n_samples = len(x) if n_samples is None else n_samples + added_n_points = x.size added_min = np.min(x, axis=0, keepdims=True) added_max = np.max(x, axis=0, keepdims=True) added_mean = np.mean(x, axis=0, keepdims=True) @@ -136,46 +154,78 @@ def add_samples(self, x: np.ndarray) -> None: or (self.var is None) ): self.n_samples = added_n_samples + self.n_points = added_n_points self.min = added_min self.max = added_max self.mean = added_mean self.var = added_var else: - self.min = np.min(np.concatenate((self.min, added_min), axis=0), axis=0) - self.max = np.max(np.concatenate((self.max, added_max), axis=0), axis=0) - # new_n_samples = self.n_samples + added_n_samples - # new_mean = ( - # self.n_samples * self.mean + added_n_samples * added_mean - # ) / new_n_samples - self.n_samples, self.mean, self.var = aggregate_stats( - np.concatenate( - [ - self.n_samples + np.zeros(self.mean.shape, dtype=int), - added_n_samples + np.zeros(added_mean.shape, dtype=int), - ] - ), - np.concatenate([self.mean, added_mean]), - np.concatenate([self.var, added_var]), + self.min = np.min( + np.concatenate((self.min, added_min), axis=0), axis=0, keepdims=True ) + self.max = np.max( + np.concatenate((self.max, added_max), axis=0), axis=0, keepdims=True + ) + if self.n_features > 1: + # feature not flattened, we are on a per-sample mode + self.n_points += added_n_points + self.n_samples, self.mean, self.var = aggregate_stats( + np.array([self.n_samples, added_n_samples]), + np.concatenate([self.mean, added_mean]), + np.concatenate([self.var, added_var]), + ) + else: + # feature flattened, we are on a per-point mode + self.n_samples += added_n_samples + self.n_points, self.mean, self.var = aggregate_stats( + np.array([self.n_points, added_n_points]), + np.concatenate([self.mean, added_mean]), + np.concatenate([self.var, added_var]), + ) + + self.std = np.sqrt(self.var) - # # cf: https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques) - # self.var = (self.n_samples * (self.var + (new_mean - self.mean)**2) + added_n_samples*(added_var + (new_mean - added_mean)**2)) / new_n_samples - # self.n_samples = new_n_samples - # self.mean = new_mean + def 
merge_stats(self, other: Self) -> None: + """Merge statistics from another instance. + Args: + other (Self): The other instance to merge statistics from. + """ + if not isinstance(other, self.__class__): + raise TypeError("Can only merge with another instance of the same class") + + if self.n_features != other.n_features: + # flatten both + self.flatten_array() + other = copy.deepcopy(other) + other.flatten_array() + + self.min = np.min(np.concatenate((self.min, other.min), axis=0), axis=0) + self.max = np.max(np.concatenate((self.max, other.max), axis=0), axis=0) + self.n_points += other.n_points + self.n_samples, self.mean, self.var = aggregate_stats( + np.array([self.n_samples, other.n_samples]), + np.concatenate([self.mean, other.mean]), + np.concatenate([self.var, other.var]), + ) self.std = np.sqrt(self.var) def flatten_array(self) -> None: """When a shape incoherence is detected, you should call this function.""" self.min = np.min(self.min, keepdims=True) self.max = np.max(self.max, keepdims=True) + self.n_points = self.n_samples * self.n_features assert self.mean.shape == self.var.shape - self.n_samples, self.mean, self.var = aggregate_stats( - np.zeros(self.mean.shape, dtype=int) + self.n_samples, self.mean, self.var + self.n_points, self.mean, self.var = aggregate_stats( + np.array([self.n_samples] * self.n_features), + self.mean.reshape(-1, 1), + self.var.reshape(-1, 1), ) self.std = np.sqrt(self.var) - def get_stats(self) -> dict[str, np.ndarray]: + self.n_features = 1 + + def get_stats(self) -> dict[str, Union[int, np.ndarray[float]]]: """Get computed statistics. Returns: @@ -183,6 +233,8 @@ def get_stats(self) -> dict[str, np.ndarray]: """ return { "n_samples": self.n_samples, + "n_points": self.n_points, + "n_features": self.n_features, "min": self.min, "max": self.max, "mean": self.mean, @@ -204,6 +256,7 @@ class Stats: def __init__(self): """Initialize an empty Stats object.""" self._stats: dict[str, OnlineStatistics] = {} + self._field_is_flattened: dict[str, bool] = {} def add_dataset(self, dset: Dataset) -> None: """Add a dataset to compute statistics for. @@ -240,8 +293,8 @@ def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: # Process fields self._process_field_data(sample, new_data) - # ---# Categorical - # TODO + # ---# Time-Series + self._process_time_series_data(sample, new_data) # ---# SpatialSupport (Meshes) # TODO @@ -249,6 +302,9 @@ def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: # ---# TemporalSupport # TODO + # ---# Categorical + # TODO + # Update statistics self._update_statistics(new_data) @@ -290,17 +346,20 @@ def clear_statistics(self) -> None: """Clear all computed statistics.""" self._stats.clear() - def merge_stats(self, other: "Stats") -> None: + def merge_stats(self, other: Self) -> None: """Merge statistics from another Stats object. Args: other (Stats): Stats object to merge with """ + for name, stats in self._stats.items(): + print(f"=== self {name=} -> {stats.get_stats()=}") for name, stats in other._stats.items(): + print(f"=== other {name=} -> {stats.get_stats()=}") if name not in self._stats: - self._stats[name] = OnlineStatistics() - for time_series in stats.get_stats().values(): - self._stats[name].add_samples(time_series) + self._stats[name] = copy.deepcopy(stats) + else: + self._stats[name].merge_stats(stats) def _process_scalar_data(self, sample: Sample, data_dict: dict[str, list]) -> None: """Process scalar data from a sample. 
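# [editor note, not part of the patch] A numpy sketch of the per-sample ->
# per-point collapse that flatten_array() performs above, assuming equal
# batch sizes (each feature column is treated as one batch of points); the
# array shape and seed are illustrative only.
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(100, 3))          # per-sample mode: 3 features

means, variances = x.mean(axis=0), x.var(axis=0)

# collapse: each feature column becomes a batch of 100 points
total_mean = means.mean()
total_var = (variances + (total_mean - means) ** 2).mean()

assert np.isclose(total_mean, x.ravel().mean())
assert np.isclose(total_var, x.ravel().var())
# the tracker is now in per-point mode: n_features == 1 and
# n_points == 100 * 3, so later batches may have any width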
@@ -314,7 +373,25 @@ def _process_scalar_data(self, sample: Sample, data_dict: dict[str, list]) -> No data_dict[name] = [] value = sample.get_scalar(name) if value is not None: - data_dict[name].append(value) + data_dict[name].append(np.array(value).reshape((1, -1))) + + def _process_time_series_data( + self, sample: Sample, data_dict: dict[str, list] + ) -> None: + """Process time series data from a sample. + + Args: + sample (Sample): Sample containing time series data + data_dict (dict[str, list]): Dictionary to store processed data + """ + for name in sample.get_time_series_names(): + if name not in data_dict: + data_dict[f"timestamps/{name}"] = [] + data_dict[f"time_series/{name}"] = [] + timestamps, values = sample.get_time_series(name) + if timestamps is not None and values is not None: + data_dict[f"timestamps/{name}"].append(timestamps.reshape((1, -1))) + data_dict[f"time_series/{name}"].append(values.reshape((1, -1))) def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> None: """Process field data from a sample. @@ -324,18 +401,39 @@ def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> Non data_dict (dict[str, list]): Dictionary to store processed data """ for time in sample.get_all_mesh_times(): - for base in sample.get_base_names(time=time): - for zone in sample.get_zone_names(base_name=base, time=time): + for base_name in sample.get_base_names(time=time): + for zone_name in sample.get_zone_names(base_name=base_name, time=time): for field_name in sample.get_field_names( - zone_name=zone, base_name=base, time=time + zone_name=zone_name, base_name=base_name, time=time ): - if field_name not in data_dict: - data_dict[field_name] = [] + stat_key = f"{base_name}/{zone_name}/{field_name}" + if stat_key not in data_dict: + data_dict[stat_key] = [] field = sample.get_field( - field_name, zone_name=zone, base_name=base, time=time + field_name, + zone_name=zone_name, + base_name=base_name, + time=time, ) if field is not None: - data_dict[field_name].append(field) + # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] + if len( + data_dict[stat_key] + ) > 0 and not self._field_is_flattened.get(stat_key, False): + prev_shape = data_dict[stat_key][0].shape + if field.shape != prev_shape: + # set this stat as flattened + self._field_is_flattened[stat_key] = True + # flatten corresponding stat + if stat_key in self._stats: + self._stats[stat_key].flatten_array() + + if self._field_is_flattened.get(stat_key, False): + field = field.reshape((-1, 1)) + else: + field = field.reshape((1, -1)) + + data_dict[stat_key].append(field) def _update_statistics(self, new_data: dict[str, list]) -> None: """Update running statistics with new data. 
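# [editor note, not part of the patch] A standalone sketch of the
# shape-coherence check used in _process_field_data above: per-feature
# statistics only make sense while every sample's field has the same
# length, and on the first mismatch all accumulated arrays are raveled
# into a single column so statistics continue per point. The field sizes
# below are illustrative only.
import numpy as np

collected, flattened = [], False
for field in (np.ones(5), np.ones(5), np.ones(3)):  # third sample differs
    row = field.reshape(1, -1)
    if collected and not flattened and row.shape != collected[0].shape:
        flattened = True                  # sizes disagree: switch mode
        collected = [a.reshape(-1, 1) for a in collected]
    collected.append(row.reshape(-1, 1) if flattened else row)

data = np.concatenate(collected)
# per-point mode: 13 values in one column instead of rows of 5 features
assert data.shape == (13, 1)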
@@ -343,20 +441,31 @@ def _update_statistics(self, new_data: dict[str, list]) -> None: Args: new_data (dict[str, list]): Dictionary containing new data to process """ - for name, values in new_data.items(): - if len(values) > 0: + for name, list_of_arrays in new_data.items(): + if len(list_of_arrays) > 0: if name not in self._stats: self._stats[name] = OnlineStatistics() + for i in range(len(list_of_arrays)): + if ( + isinstance(list_of_arrays[i], np.ndarray) + and list_of_arrays[i].ndim == 1 + ): + list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) + + if self._field_is_flattened.get(name, False): + # flatten all arrays in list_of_arrays + n_samples = len(list_of_arrays) + for i in range(len(list_of_arrays)): + list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) + else: + n_samples = None + # Convert to numpy array and reshape if needed - data = np.asarray(values) - if data.ndim == 1: - data = data.reshape(-1, 1) - - try: - self._stats[name].add_samples(data) - except Exception as e: - logging.warning(f"Failed to process {name}: {str(e)}") + data = np.concatenate(list_of_arrays) + assert data.ndim == 2 + + self._stats[name].add_samples(data, n_samples=n_samples) # # old version of _update_statistics logic # for name in new_data: diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index 576c99c..9a894d6 100644 --- a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -10,6 +10,7 @@ import numpy as np import pytest +from plaid.containers.sample import Sample from plaid.utils.stats import OnlineStatistics, Stats # %% Fixtures @@ -50,6 +51,70 @@ def stats(): return Stats() +@pytest.fixture() +def sample_data_1(): + return np.random.randn(100, 1) + + +@pytest.fixture() +def sample_data_2(): + return np.random.randn(50, 1) + + +@pytest.fixture() +def sample_with_scalar(sample_data_1): + s = Sample() + s.add_scalar("foo", float(sample_data_1.mean())) + return s + + +@pytest.fixture() +def sample_with_field(sample_data_2): + s = Sample() + # 1. Initialize the CGNS tree + s.init_tree() + # 2. Create a base and a zone + s.init_base(topological_dim=3, physical_dim=3) + s.init_zone(zone_shape=np.array([sample_data_2.shape[0], 0, 0])) + # 3. Set node coordinates (required for a valid zone) + s.set_nodes(np.zeros((sample_data_2.shape[0], 3))) + # 4. 
Add a field named "bar" + s.add_field(name="bar", field=sample_data_2) + return s + + +@pytest.fixture() +def time_series_data(): + # 10 time steps, 1 feature + times = np.linspace(0, 1, 10) + values = np.random.randn(10) + return times, values + + +@pytest.fixture() +def time_series_data_of_different_size(): + # 5 time steps, 1 feature + times = np.linspace(0, 1, 5) + values = np.random.randn(5) + return times, values + + +@pytest.fixture() +def sample_with_time_series(time_series_data): + s = Sample() + times, values = time_series_data + s.add_time_series("ts1", time_sequence=times, values=values) + return s + + +@pytest.fixture() +def sample_with_time_series_of_different_size(time_series_data_of_different_size): + s = Sample() + times, values = time_series_data_of_different_size + s.add_time_series("ts1", time_sequence=times, values=values) + return s + + # %% Tests @@ -82,6 +147,38 @@ def test_get_stats(self, online_stats, np_samples_1): online_stats.add_samples(np_samples_1) online_stats.get_stats() + def test_invalid_input_type(self, online_stats): + with pytest.raises(TypeError): + online_stats.add_samples([1, 2, 3]) # List instead of ndarray + + def test_nan_inf_input(self, online_stats): + with pytest.raises(ValueError): + online_stats.add_samples(np.array([1, np.nan, 3])) + with pytest.raises(ValueError): + online_stats.add_samples(np.array([1, np.inf, 3])) + + def test_merge_stats(self, sample_data_1, sample_data_2): + stats1 = OnlineStatistics() + stats2 = OnlineStatistics() + stats1.add_samples(sample_data_1) + stats2.add_samples(sample_data_2) + n_samples_before = stats1.n_samples + n_samples_other = stats2.n_samples + mean_before = stats1.mean.copy() + other_mean = stats2.mean.copy() + # do the merging + stats1.merge_stats(stats2) + # check results + assert stats1.n_samples == n_samples_before + stats2.n_samples + print(f"{n_samples_before=}, {n_samples_other=}") + print(f"{mean_before=}, {other_mean=}") + expected_mean = ( + mean_before * n_samples_before + other_mean * n_samples_other + ) / (n_samples_before + n_samples_other) + print(f"{expected_mean=}") + print(f"{stats1.mean=}") + assert np.allclose(stats1.mean, expected_mean) + class Test_Stats: def test__init__(self, stats): @@ -96,3 +193,85 @@ def test_add_dataset(self, stats, dataset): def test_get_stats(self, stats, samples): stats.add_samples(samples) stats.get_stats() + + def test_invalid_input(self, stats): + with pytest.raises(TypeError): + stats.add_samples("invalid") + + def test_empty_samples(self, stats): + stats.add_samples([]) + assert len(stats.get_available_statistics()) == 0 + + def test_merge_stats(self, sample_with_scalar, sample_with_field): + # Create two Stats objects with different samples + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_scalar]) + stats2.add_samples([sample_with_field]) + # Merge stats2 into stats1 + stats1.merge_stats(stats2) + # Both keys should be present + keys = stats1.get_available_statistics() + assert "foo" in keys or "bar" in keys + # Check that statistics are present for merged keys + for key in keys: + s = stats1._stats[key] + assert s.n_samples > 0 + + def test_clear_statistics(self, stats, samples): + stats.add_samples(samples) + stats.clear_statistics() + assert len(stats.get_available_statistics()) == 0 + + def test_add_samples_with_time_series( + self, stats, sample_with_time_series, sample_with_time_series_of_different_size + ): + stats.add_samples([sample_with_time_series]) + stats.add_samples([sample_with_time_series]) + keys = 
stats.get_available_statistics() + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + stats_dict = stat.get_stats() + assert "min" in stats_dict + assert "max" in stats_dict + assert "n_samples" in stats_dict + assert "mean" in stats_dict + assert "var" in stats_dict + assert "std" in stats_dict + assert stats_dict["mean"].shape == ( + 1, + len(sample_with_time_series.get_time_series("ts1")[1]), + ) + stats.add_samples([sample_with_time_series_of_different_size]) + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats._stats["time_series/ts1"] + # 2 samples of size 10 + 1 sample of size 5 -> 25 values + assert stat.n_samples == 3 + assert stat.n_points == 25 + stats_dict = stat.get_stats() + assert "min" in stats_dict + assert "max" in stats_dict + assert "n_samples" in stats_dict + assert "mean" in stats_dict + assert "var" in stats_dict + assert "std" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + + def test_merge_stats_with_time_series(self, sample_with_time_series): + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series]) + print( + f"{stats1._stats['time_series/ts1'].n_samples=}, {stats2._stats['time_series/ts1'].n_samples=}" + ) + stats1.merge_stats(stats2) + keys = stats1.get_available_statistics() + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 From c123a7e73dbcf4914940f0b77dbb30be3a8396ae Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Sat, 28 Jun 2025 00:20:47 +0200 Subject: [PATCH 05/11] (stats) remove prints --- src/plaid/utils/stats.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index c8f2a41..1482f01 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -352,10 +352,7 @@ def merge_stats(self, other: Self) -> None: Args: other (Stats): Stats object to merge with """ - for name, stats in self._stats.items(): - print(f"=== self {name=} -> {stats.get_stats()=}") for name, stats in other._stats.items(): - print(f"=== other {name=} -> {stats.get_stats()=}") if name not in self._stats: self._stats[name] = copy.deepcopy(stats) else: From 83879292791bccbd7d387bcb4791bffd7185d67f Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 28 Jun 2025 09:34:12 +0200 Subject: [PATCH 06/11] feat(tests/utils/stats) add tests to improve coverage --- src/plaid/utils/stats.py | 9 ++++---- tests/utils/test_stats.py | 44 ++++++++++++++++++++------------------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index 1482f01..fb2fc57 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -199,6 +199,7 @@ def merge_stats(self, other: Self) -> None: self.flatten_array() other = copy.deepcopy(other) other.flatten_array() + assert self.min.shape == other.min.shape, f"Shape mismatch in OnlineStatistics merging" self.min = np.min(np.concatenate((self.min, other.min), axis=0), axis=0) self.max = np.max(np.concatenate((self.max, other.max), axis=0), axis=0) @@ -212,8 +213,8 @@ def merge_stats(self, other: Self) -> None: def flatten_array(self) -> None: """When a shape incoherence is detected, you should call this function.""" - self.min = np.min(self.min, keepdims=True) - self.max = np.max(self.max, keepdims=True) + 
self.min = np.min(self.min, keepdims=True).reshape(-1, 1) + self.max = np.max(self.max, keepdims=True).reshape(-1, 1) self.n_points = self.n_samples * self.n_features assert self.mean.shape == self.var.shape self.n_points, self.mean, self.var = aggregate_stats( @@ -422,7 +423,7 @@ def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> Non # set this stat as flattened self._field_is_flattened[stat_key] = True # flatten corresponding stat - if stat_key in self._stats: + if stat_key in self._stats: # TODO: ADD THIS IN TESTS self._stats[stat_key].flatten_array() if self._field_is_flattened.get(stat_key, False): @@ -447,7 +448,7 @@ def _update_statistics(self, new_data: dict[str, list]) -> None: if ( isinstance(list_of_arrays[i], np.ndarray) and list_of_arrays[i].ndim == 1 - ): + ): # TODO: ADD THIS IN TESTS list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) if self._field_is_flattened.get(name, False): diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index 9a894d6..d7e6db6 100644 --- a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -41,6 +41,11 @@ def np_samples_5(): return np.random.randn(400) +@pytest.fixture() +def np_samples_6(): + return np.random.randn(50, 1) + + @pytest.fixture() def online_stats(): return OnlineStatistics() @@ -52,34 +57,24 @@ def stats(): @pytest.fixture() -def sample_data_1(): - return np.random.randn(100, 1) - - -@pytest.fixture() -def sample_data_2(): - return np.random.randn(50, 1) - - -@pytest.fixture() -def sample_with_scalar(sample_data_1): +def sample_with_scalar(np_samples_3): s = Sample() - s.add_scalar("foo", float(sample_data_1.mean())) + s.add_scalar("foo", float(np_samples_3.mean())) return s @pytest.fixture() -def sample_with_field(sample_data_2): +def sample_with_field(np_samples_6): s = Sample() # 1. Initialize the CGNS tree s.init_tree() # 2. Create a base and a zone s.init_base(topological_dim=3, physical_dim=3) - s.init_zone(zone_shape=np.array([sample_data_2.shape[0], 0, 0])) + s.init_zone(zone_shape=np.array([np_samples_6.shape[0], 0, 0])) # 3. Set node coordinates (required for a valid zone) - s.set_nodes(np.zeros((sample_data_2.shape[0], 3))) + s.set_nodes(np.zeros((np_samples_6.shape[0], 3))) # 4. 
Add a field named "bar" - s.add_field(name="bar", field=sample_data_2) + s.add_field(name="bar", field=np_samples_6) return s @@ -134,6 +129,9 @@ def test_add_samples_3(self, online_stats, np_samples_3, np_samples_5): online_stats.min = np_samples_3 online_stats.add_samples(np_samples_5) + def test_add_samples_4(self, online_stats, np_samples_5): + online_stats.add_samples(np_samples_5) + def test_add_samples_already_present(self, online_stats, np_samples_1): online_stats.add_samples(np_samples_1) online_stats.add_samples(np_samples_1) @@ -157,18 +155,19 @@ def test_nan_inf_input(self, online_stats): with pytest.raises(ValueError): online_stats.add_samples(np.array([1, np.inf, 3])) - def test_merge_stats(self, sample_data_1, sample_data_2): + def test_merge_stats(self, np_samples_3, np_samples_4, np_samples_6): stats1 = OnlineStatistics() stats2 = OnlineStatistics() - stats1.add_samples(sample_data_1) - stats2.add_samples(sample_data_2) + stats1.add_samples(np_samples_3) + stats2.add_samples(np_samples_6) n_samples_before = stats1.n_samples n_samples_other = stats2.n_samples mean_before = stats1.mean.copy() other_mean = stats2.mean.copy() + stats3 = OnlineStatistics() + stats3.add_samples(np_samples_4) # do the merging stats1.merge_stats(stats2) - # check results assert stats1.n_samples == n_samples_before + stats2.n_samples print(f"{n_samples_before=}, {n_samples_other=}") print(f"{mean_before=}, {other_mean=}") @@ -178,7 +177,10 @@ def test_merge_stats(self, sample_data_1, sample_data_2): print(f"{expected_mean=}") print(f"{stats1.mean=}") assert np.allclose(stats1.mean, expected_mean) - + # other merging tests + with pytest.raises(TypeError): + stats1.merge_stats(0.) + stats1.merge_stats(stats3) class Test_Stats: def test__init__(self, stats): From e534477b72f88d41bf034dd2ef3bd1cad39e4356 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Sat, 28 Jun 2025 11:15:59 +0200 Subject: [PATCH 07/11] fix(stats) apply ruff formatting --- src/plaid/utils/stats.py | 10 +++++++--- tests/utils/test_stats.py | 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index fb2fc57..d66170b 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -199,7 +199,9 @@ def merge_stats(self, other: Self) -> None: self.flatten_array() other = copy.deepcopy(other) other.flatten_array() - assert self.min.shape == other.min.shape, f"Shape mismatch in OnlineStatistics merging" + assert self.min.shape == other.min.shape, ( + "Shape mismatch in OnlineStatistics merging" + ) self.min = np.min(np.concatenate((self.min, other.min), axis=0), axis=0) self.max = np.max(np.concatenate((self.max, other.max), axis=0), axis=0) @@ -423,7 +425,9 @@ def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> Non # set this stat as flattened self._field_is_flattened[stat_key] = True # flatten corresponding stat - if stat_key in self._stats: # TODO: ADD THIS IN TESTS + if ( + stat_key in self._stats + ): # TODO: ADD THIS IN TESTS self._stats[stat_key].flatten_array() if self._field_is_flattened.get(stat_key, False): @@ -448,7 +452,7 @@ def _update_statistics(self, new_data: dict[str, list]) -> None: if ( isinstance(list_of_arrays[i], np.ndarray) and list_of_arrays[i].ndim == 1 - ): # TODO: ADD THIS IN TESTS + ): # TODO: ADD THIS IN TESTS list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) if self._field_is_flattened.get(name, False): diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index d7e6db6..8004922 100644 --- 
a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -179,9 +179,10 @@ def test_merge_stats(self, np_samples_3, np_samples_4, np_samples_6): assert np.allclose(stats1.mean, expected_mean) # other merging tests with pytest.raises(TypeError): - stats1.merge_stats(0.) + stats1.merge_stats(0.0) stats1.merge_stats(stats3) + class Test_Stats: def test__init__(self, stats): pass From c89f2f3bc4e7d5575d01a5fcbc4affa77121a611 Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Mon, 30 Jun 2025 00:10:17 +0200 Subject: [PATCH 08/11] (stats) fix flattening of time_series --- src/plaid/utils/stats.py | 76 +++++++++--- tests/utils/test_stats.py | 247 +++++++++++++++++++++++++++++++++----- 2 files changed, 281 insertions(+), 42 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index d66170b..dc80039 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -259,7 +259,7 @@ class Stats: def __init__(self): """Initialize an empty Stats object.""" self._stats: dict[str, OnlineStatistics] = {} - self._field_is_flattened: dict[str, bool] = {} + self._feature_is_flattened: dict[str, bool] = {} def add_dataset(self, dset: Dataset) -> None: """Add a dataset to compute statistics for. @@ -290,6 +290,7 @@ def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: new_data: dict[str, list] = {} for sample in samples: + print(f"{sample=}") # Process scalars self._process_scalar_data(sample, new_data) @@ -385,13 +386,57 @@ def _process_time_series_data( data_dict (dict[str, list]): Dictionary to store processed data """ for name in sample.get_time_series_names(): - if name not in data_dict: - data_dict[f"timestamps/{name}"] = [] - data_dict[f"time_series/{name}"] = [] - timestamps, values = sample.get_time_series(name) - if timestamps is not None and values is not None: - data_dict[f"timestamps/{name}"].append(timestamps.reshape((1, -1))) - data_dict[f"time_series/{name}"].append(values.reshape((1, -1))) + print(f" - {name=}") + timestamps, time_series = sample.get_time_series(name) + timestamps = timestamps.reshape((1, -1)) + time_series = time_series.reshape((1, -1)) + + timestamps_name = f"timestamps/{name}" + time_series_name = f"time_series/{name}" + print( + f" - {timestamps_name} is flattened -> {self._feature_is_flattened[timestamps_name] if timestamps_name in self._feature_is_flattened else None}" + ) + print( + f" - {time_series_name} is flattened -> {self._feature_is_flattened[time_series_name] if time_series_name in self._feature_is_flattened else None}" + ) + if timestamps_name not in data_dict: + print(f" - {timestamps_name} was not in {data_dict.keys()=}") + assert time_series_name not in data_dict + data_dict[timestamps_name] = [] + data_dict[time_series_name] = [] + if timestamps is not None and time_series is not None: + print( + f" - add timestamps and time_series to data_dict, {data_dict.keys()=}" + ) + + # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] + if len( + data_dict[time_series_name] + ) > 0 and not self._feature_is_flattened.get(time_series_name, False): + prev_shape = data_dict[time_series_name][0].shape + if time_series.shape != prev_shape: + print( + f" - {timestamps.shape=} | {data_dict[timestamps_name][0].shape=}" + ) + print(f" - {time_series.shape=} | {prev_shape=}") + # set this stat as flattened + self._feature_is_flattened[timestamps_name] = True + self._feature_is_flattened[time_series_name] = True + # flatten corresponding stat + if time_series_name in 
self._stats: # TODO: ADD THIS IN TESTS + self._stats[time_series_name].flatten_array() + + if self._feature_is_flattened.get(time_series_name, False): + print(f" - {time_series_name} is flattened") + timestamps = timestamps.reshape((-1, 1)) + time_series = time_series.reshape((-1, 1)) + else: + print(f" - {time_series_name} is not flattened") + timestamps = timestamps.reshape((1, -1)) + time_series = time_series.reshape((1, -1)) + + data_dict[timestamps_name].append(timestamps) + data_dict[time_series_name].append(time_series) def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> None: """Process field data from a sample. @@ -414,26 +459,27 @@ def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> Non zone_name=zone_name, base_name=base_name, time=time, - ) + ).reshape((1, -1)) if field is not None: # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] if len( data_dict[stat_key] - ) > 0 and not self._field_is_flattened.get(stat_key, False): + ) > 0 and not self._feature_is_flattened.get( + stat_key, False + ): prev_shape = data_dict[stat_key][0].shape if field.shape != prev_shape: + print(f" - {field.shape=} | {prev_shape=}") # set this stat as flattened - self._field_is_flattened[stat_key] = True + self._feature_is_flattened[stat_key] = True # flatten corresponding stat if ( stat_key in self._stats ): # TODO: ADD THIS IN TESTS self._stats[stat_key].flatten_array() - if self._field_is_flattened.get(stat_key, False): + if self._feature_is_flattened.get(stat_key, False): field = field.reshape((-1, 1)) - else: - field = field.reshape((1, -1)) data_dict[stat_key].append(field) @@ -455,7 +501,7 @@ def _update_statistics(self, new_data: dict[str, list]) -> None: ): # TODO: ADD THIS IN TESTS list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) - if self._field_is_flattened.get(name, False): + if self._feature_is_flattened.get(name, False): # flatten all arrays in list_of_arrays n_samples = len(list_of_arrays) for i in range(len(list_of_arrays)): diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index 8004922..54714bd 100644 --- a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -78,6 +78,16 @@ def sample_with_field(np_samples_6): return s +@pytest.fixture() +def field_data(): + return np.random.randn(101) + + +@pytest.fixture() +def field_data_of_different_size(): + return np.random.randn(51) + + @pytest.fixture() def time_series_data(): # 10 time steps, 1 feature @@ -95,18 +105,26 @@ def time_series_data_of_different_size(): @pytest.fixture() -def sample_with_time_series(time_series_data): +def sample_with_time_series(time_series_data, field_data): s = Sample() times, values = time_series_data s.add_time_series("ts1", time_sequence=times, values=values) + s.init_base(1, 1) + s.init_zone(np.array([0, 0, 0])) + s.add_field(name="field1", field=field_data) return s @pytest.fixture() -def sample_with_time_series_of_different_size(time_series_data_of_different_size): +def sample_with_time_series_of_different_size( + time_series_data_of_different_size, field_data_of_different_size +): s = Sample() times, values = time_series_data_of_different_size s.add_time_series("ts1", time_sequence=times, values=values) + s.init_base(1, 1) + s.init_zone(np.array([0, 0, 0])) + s.add_field(name="field1", field=field_data_of_different_size) return s @@ -143,7 +161,20 @@ def test_add_samples_and_flatten(self, online_stats, np_samples_1, np_samples_2) def test_get_stats(self, online_stats, 
np_samples_1): online_stats.add_samples(np_samples_1) - online_stats.get_stats() + stats_dict = online_stats.get_stats() + # Check that all expected statistics keys are present + expected_keys = { + "mean", + "min", + "max", + "var", + "std", + "n_samples", + "n_points", + "n_features", + } + for key in expected_keys: + assert key in stats_dict, f"Missing key: {key}" def test_invalid_input_type(self, online_stats): with pytest.raises(TypeError): @@ -195,7 +226,35 @@ def test_add_dataset(self, stats, dataset): def test_get_stats(self, stats, samples): stats.add_samples(samples) - stats.get_stats() + stats_dict = stats.get_stats() + sample: Sample = samples[0] + feature_names = sample.get_scalar_names() + for base_name in sample.get_base_names(): + for zone_name in sample.get_zone_names(base_name=base_name): + for field_name in sample.get_field_names( + zone_name=zone_name, base_name=base_name + ): + feature_names.append(f"{base_name}/{zone_name}/{field_name}") + feature_names.extend(sample.get_time_series_names()) + for feat_name in feature_names: + assert feat_name in stats_dict, ( + f"Missing {feat_name=}, in {stats_dict.keys()}" + ) + # Check that all expected statistics keys are present + expected_keys = { + "mean", + "min", + "max", + "var", + "std", + "n_samples", + "n_points", + "n_features", + } + for key in expected_keys: + assert key in stats_dict[feat_name], ( + f" Missing {key=}, in {feat_name=}, in {stats_dict[feat_name].keys()}" + ) def test_invalid_input(self, stats): with pytest.raises(TypeError): @@ -226,45 +285,136 @@ def test_clear_statistics(self, stats, samples): stats.clear_statistics() assert len(stats.get_available_statistics()) == 0 - def test_add_samples_with_time_series( - self, stats, sample_with_time_series, sample_with_time_series_of_different_size + def test_add_samples_time_series_case_1(self, sample_with_time_series): + # 1st case: adding time series with same sizes with 2 calls to add_samples + stats1 = Stats() + stats1.add_samples([sample_with_time_series]) + stats1.add_samples([sample_with_time_series]) + keys = stats1.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 101) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + stats_dict = stat.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 10) + + def test_add_samples_time_series_case_2( + self, sample_with_time_series, sample_with_time_series_of_different_size ): - stats.add_samples([sample_with_time_series]) - stats.add_samples([sample_with_time_series]) - keys = stats.get_available_statistics() + # 2nd case: adding time series with different sizes with 2 calls to add_samples + stats2 = Stats() + stats2.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series_of_different_size]) + keys = stats2.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats2._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = 
stats2._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + + def test_add_samples_time_series_case_3(self, sample_with_time_series): + # 3rd case: adding time series with same sizes in a single call to add_samples, in empty stats + stats3 = Stats() + stats3.add_samples([sample_with_time_series, sample_with_time_series]) + keys = stats3.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats3._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 101) + assert "time_series/ts1" in keys assert "timestamps/ts1" in keys - stat = stats._stats["time_series/ts1"] + stat = stats3._stats["time_series/ts1"] + print(stat.get_stats()) assert stat.n_samples == 2 assert stat.n_points == 20 stats_dict = stat.get_stats() - assert "min" in stats_dict - assert "max" in stats_dict - assert "n_samples" in stats_dict assert "mean" in stats_dict - assert "var" in stats_dict - assert "std" in stats_dict - assert stats_dict["mean"].shape == ( - 1, - len(sample_with_time_series.get_time_series("ts1")[1]), + assert stats_dict["mean"].shape == (1, 10) + + def test_add_samples_time_series_case_4( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + # 4th case: adding time series with different sizes in a single call to add_samples, in empty stats + stats4 = Stats() + stats4.add_samples( + [sample_with_time_series, sample_with_time_series_of_different_size] ) - stats.add_samples([sample_with_time_series_of_different_size]) + keys = stats4.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats4._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + assert "time_series/ts1" in keys assert "timestamps/ts1" in keys - stat = stats._stats["time_series/ts1"] - # 2 samples of size 10 + 1 sample of size 5 -> 25 values + stat = stats4._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + + def test_add_samples_time_series_case_5( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + # 5th case: adding time series with different sizes in a single call to add_samples, in non-empty stats + stats5 = Stats() + stats5.add_samples([sample_with_time_series]) + stats5.add_samples( + [sample_with_time_series, sample_with_time_series_of_different_size] + ) + keys = stats5.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats5._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 3 + assert stat_field.n_points == 253 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats5._stats["time_series/ts1"] assert stat.n_samples == 3 assert stat.n_points == 25 stats_dict = stat.get_stats() - assert "min" in stats_dict - assert "max" in stats_dict - assert "n_samples" in stats_dict assert "mean" in stats_dict - assert "var" in stats_dict - assert "std" in 
stats_dict assert stats_dict["mean"].shape == (1, 1) - def test_merge_stats_with_time_series(self, sample_with_time_series): + def test_merge_stats_with_same_sizes(self, sample_with_time_series): stats1 = Stats() stats2 = Stats() stats1.add_samples([sample_with_time_series]) @@ -274,7 +424,50 @@ def test_merge_stats_with_time_series(self, sample_with_time_series): ) stats1.merge_stats(stats2) keys = stats1.get_available_statistics() + assert "Base_1_1/Zone/field1" in keys + + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 101) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + stats_dict = stat.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 10) + + def test_merge_stats_with_different_sizes( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series_of_different_size]) + print( + f"{stats1._stats['time_series/ts1'].n_samples=}, {stats2._stats['time_series/ts1'].n_samples=}" + ) + stats1.merge_stats(stats2) + keys = stats1.get_available_statistics() + assert "Base_1_1/Zone/field1" in keys + + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) + assert "time_series/ts1" in keys assert "timestamps/ts1" in keys stat = stats1._stats["time_series/ts1"] assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + assert "mean" in stats_dict + assert stats_dict["mean"].shape == (1, 1) From 7881cda575be95878952b3590e7f0ca18be9ce73 Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Mon, 30 Jun 2025 22:45:58 +0200 Subject: [PATCH 09/11] (stats) fix wrong shape of stats + improve test --- examples/utils/stats_example.py | 3 - src/plaid/utils/stats.py | 48 +++++--------- tests/utils/test_stats.py | 107 ++++++++++++++++---------------- 3 files changed, 69 insertions(+), 89 deletions(-) diff --git a/examples/utils/stats_example.py b/examples/utils/stats_example.py index 1155f17..1b032e6 100644 --- a/examples/utils/stats_example.py +++ b/examples/utils/stats_example.py @@ -173,9 +173,6 @@ def sprint(stats: dict): zone_name="test_zone", base_name="test_base" ) -for sample in samples: - print(sample) -print(f"{len(samples)=}") stats.add_samples(samples) # %% [markdown] diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index dc80039..eeac7a9 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -203,8 +203,12 @@ def merge_stats(self, other: Self) -> None: "Shape mismatch in OnlineStatistics merging" ) - self.min = np.min(np.concatenate((self.min, other.min), axis=0), axis=0) - self.max = np.max(np.concatenate((self.max, other.max), axis=0), axis=0) + self.min = np.min( + np.concatenate((self.min, other.min), axis=0), axis=0, keepdims=True + ) + self.max = np.max( + np.concatenate((self.max, other.max), axis=0), axis=0, keepdims=True + ) self.n_points += other.n_points self.n_samples, self.mean, self.var = aggregate_stats( np.array([self.n_samples, other.n_samples]), @@ -215,8 +219,8 
@@ def merge_stats(self, other: Self) -> None: def flatten_array(self) -> None: """When a shape incoherence is detected, you should call this function.""" - self.min = np.min(self.min, keepdims=True).reshape(-1, 1) - self.max = np.max(self.max, keepdims=True).reshape(-1, 1) + self.min = np.min(self.min, keepdims=True).reshape(1, 1) + self.max = np.max(self.max, keepdims=True).reshape(1, 1) self.n_points = self.n_samples * self.n_features assert self.mean.shape == self.var.shape self.n_points, self.mean, self.var = aggregate_stats( @@ -228,11 +232,12 @@ def flatten_array(self) -> None: self.n_features = 1 - def get_stats(self) -> dict[str, Union[int, np.ndarray[float]]]: + def get_stats(self) -> dict[str, Union[int, np.ndarray]]: """Get computed statistics. Returns: - dict[str,np.ndarray]: A dictionary containing computed statistics. + dict[str, Union[int, np.ndarray]]: A dictionary containing computed statistics. + The shapes of the arrays depend on the input data and may vary. """ return { "n_samples": self.n_samples, @@ -290,7 +295,6 @@ def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: new_data: dict[str, list] = {} for sample in samples: - print(f"{sample=}") # Process scalars self._process_scalar_data(sample, new_data) @@ -334,7 +338,8 @@ def get_stats( for stat_name, stat_value in ( self._stats[identifier].get_stats().items() ): - stats[identifier][stat_name] = np.squeeze(stat_value) + stats[identifier][stat_name] = stat_value + # stats[identifier][stat_name] = np.squeeze(stat_value) return stats @@ -386,52 +391,34 @@ def _process_time_series_data( data_dict (dict[str, list]): Dictionary to store processed data """ for name in sample.get_time_series_names(): - print(f" - {name=}") timestamps, time_series = sample.get_time_series(name) timestamps = timestamps.reshape((1, -1)) time_series = time_series.reshape((1, -1)) timestamps_name = f"timestamps/{name}" time_series_name = f"time_series/{name}" - print( - f" - {timestamps_name} is flattened -> {self._feature_is_flattened[timestamps_name] if timestamps_name in self._feature_is_flattened else None}" - ) - print( - f" - {time_series_name} is flattened -> {self._feature_is_flattened[time_series_name] if time_series_name in self._feature_is_flattened else None}" - ) if timestamps_name not in data_dict: - print(f" - {timestamps_name} was not in {data_dict.keys()=}") assert time_series_name not in data_dict data_dict[timestamps_name] = [] data_dict[time_series_name] = [] if timestamps is not None and time_series is not None: - print( - f" - add timestamps and time_series to data_dict, {data_dict.keys()=}" - ) - # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] if len( data_dict[time_series_name] ) > 0 and not self._feature_is_flattened.get(time_series_name, False): prev_shape = data_dict[time_series_name][0].shape if time_series.shape != prev_shape: - print( - f" - {timestamps.shape=} | {data_dict[timestamps_name][0].shape=}" - ) - print(f" - {time_series.shape=} | {prev_shape=}") # set this stat as flattened self._feature_is_flattened[timestamps_name] = True self._feature_is_flattened[time_series_name] = True # flatten corresponding stat - if time_series_name in self._stats: # TODO: ADD THIS IN TESTS + if time_series_name in self._stats: self._stats[time_series_name].flatten_array() if self._feature_is_flattened.get(time_series_name, False): - print(f" - {time_series_name} is flattened") timestamps = timestamps.reshape((-1, 1)) time_series = 
time_series.reshape((-1, 1)) else: - print(f" - {time_series_name} is not flattened") timestamps = timestamps.reshape((1, -1)) time_series = time_series.reshape((1, -1)) @@ -469,13 +456,10 @@ def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> Non ): prev_shape = data_dict[stat_key][0].shape if field.shape != prev_shape: - print(f" - {field.shape=} | {prev_shape=}") # set this stat as flattened self._feature_is_flattened[stat_key] = True # flatten corresponding stat - if ( - stat_key in self._stats - ): # TODO: ADD THIS IN TESTS + if stat_key in self._stats: self._stats[stat_key].flatten_array() if self._feature_is_flattened.get(stat_key, False): @@ -498,7 +482,7 @@ def _update_statistics(self, new_data: dict[str, list]) -> None: if ( isinstance(list_of_arrays[i], np.ndarray) and list_of_arrays[i].ndim == 1 - ): # TODO: ADD THIS IN TESTS + ): list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) if self._feature_is_flattened.get(name, False): diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index 54714bd..796263c 100644 --- a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -128,6 +128,37 @@ def sample_with_time_series_of_different_size( return s +# %% Functions + + +def check_stats_dict(stats_dict): + # Check that all expected statistics keys are present + expected_keys = [ + {"name": "mean", "type": np.ndarray, "ndim": 2}, + {"name": "min", "type": np.ndarray, "ndim": 2}, + {"name": "max", "type": np.ndarray, "ndim": 2}, + {"name": "var", "type": np.ndarray, "ndim": 2}, + {"name": "std", "type": np.ndarray, "ndim": 2}, + {"name": "n_samples", "type": (int, np.integer)}, + {"name": "n_points", "type": (int, np.integer)}, + {"name": "n_features", "type": (int, np.integer)}, + ] + for key_info in expected_keys: + key = key_info["name"] + assert key in stats_dict, f"Missing key: {key}" + if "type" in key_info: + assert isinstance(stats_dict[key], key_info["type"]), ( + f"Key '{key}' has wrong type: {type(stats_dict[key])}, expected {key_info['type']}" + ) + if "ndim" in key_info: + assert hasattr(stats_dict[key], "ndim"), ( + f"Key '{key}' does not have 'ndim' attribute" + ) + assert stats_dict[key].ndim == key_info["ndim"], ( + f"Key '{key}' has wrong ndim: {stats_dict[key].ndim}, expected {key_info['ndim']}" + ) + + # %% Tests @@ -163,18 +194,7 @@ def test_get_stats(self, online_stats, np_samples_1): online_stats.add_samples(np_samples_1) stats_dict = online_stats.get_stats() # Check that all expected statistics keys are present - expected_keys = { - "mean", - "min", - "max", - "var", - "std", - "n_samples", - "n_points", - "n_features", - } - for key in expected_keys: - assert key in stats_dict, f"Missing key: {key}" + check_stats_dict(stats_dict) def test_invalid_input_type(self, online_stats): with pytest.raises(TypeError): @@ -200,13 +220,9 @@ def test_merge_stats(self, np_samples_3, np_samples_4, np_samples_6): # do the merging stats1.merge_stats(stats2) assert stats1.n_samples == n_samples_before + stats2.n_samples - print(f"{n_samples_before=}, {n_samples_other=}") - print(f"{mean_before=}, {other_mean=}") expected_mean = ( mean_before * n_samples_before + other_mean * n_samples_other ) / (n_samples_before + n_samples_other) - print(f"{expected_mean=}") - print(f"{stats1.mean=}") assert np.allclose(stats1.mean, expected_mean) # other merging tests with pytest.raises(TypeError): @@ -227,34 +243,22 @@ def test_add_dataset(self, stats, dataset): def test_get_stats(self, stats, samples): stats.add_samples(samples) stats_dict = 
stats.get_stats() - sample: Sample = samples[0] + + sample = samples[0] feature_names = sample.get_scalar_names() + feature_names.extend(sample.get_time_series_names()) for base_name in sample.get_base_names(): for zone_name in sample.get_zone_names(base_name=base_name): for field_name in sample.get_field_names( zone_name=zone_name, base_name=base_name ): feature_names.append(f"{base_name}/{zone_name}/{field_name}") - feature_names.extend(sample.get_time_series_names()) + for feat_name in feature_names: assert feat_name in stats_dict, ( f"Missing {feat_name=}, in {stats_dict.keys()}" ) - # Check that all expected statistics keys are present - expected_keys = { - "mean", - "min", - "max", - "var", - "std", - "n_samples", - "n_points", - "n_features", - } - for key in expected_keys: - assert key in stats_dict[feat_name], ( - f" Missing {key=}, in {feat_name=}, in {stats_dict[feat_name].keys()}" - ) + check_stats_dict(stats_dict[feat_name]) def test_invalid_input(self, stats): with pytest.raises(TypeError): @@ -297,7 +301,7 @@ def test_add_samples_time_series_case_1(self, sample_with_time_series): assert stat_field.n_samples == 2 assert stat_field.n_points == 202 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 101) assert "time_series/ts1" in keys @@ -306,7 +310,7 @@ def test_add_samples_time_series_case_1(self, sample_with_time_series): assert stat.n_samples == 2 assert stat.n_points == 20 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 10) def test_add_samples_time_series_case_2( @@ -323,7 +327,7 @@ def test_add_samples_time_series_case_2( assert stat_field.n_samples == 2 assert stat_field.n_points == 152 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) assert "time_series/ts1" in keys @@ -332,7 +336,7 @@ def test_add_samples_time_series_case_2( assert stat.n_samples == 2 assert stat.n_points == 15 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) def test_add_samples_time_series_case_3(self, sample_with_time_series): @@ -346,17 +350,16 @@ def test_add_samples_time_series_case_3(self, sample_with_time_series): assert stat_field.n_samples == 2 assert stat_field.n_points == 202 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 101) assert "time_series/ts1" in keys assert "timestamps/ts1" in keys stat = stats3._stats["time_series/ts1"] - print(stat.get_stats()) assert stat.n_samples == 2 assert stat.n_points == 20 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 10) def test_add_samples_time_series_case_4( @@ -374,7 +377,7 @@ def test_add_samples_time_series_case_4( assert stat_field.n_samples == 2 assert stat_field.n_points == 152 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) assert "time_series/ts1" in keys @@ -383,7 +386,7 @@ def test_add_samples_time_series_case_4( assert stat.n_samples == 2 assert stat.n_points == 15 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) def test_add_samples_time_series_case_5( @@ -402,7 
+405,7 @@ def test_add_samples_time_series_case_5( assert stat_field.n_samples == 3 assert stat_field.n_points == 253 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) assert "time_series/ts1" in keys @@ -411,7 +414,7 @@ def test_add_samples_time_series_case_5( assert stat.n_samples == 3 assert stat.n_points == 25 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) def test_merge_stats_with_same_sizes(self, sample_with_time_series): @@ -419,9 +422,6 @@ def test_merge_stats_with_same_sizes(self, sample_with_time_series): stats2 = Stats() stats1.add_samples([sample_with_time_series]) stats2.add_samples([sample_with_time_series]) - print( - f"{stats1._stats['time_series/ts1'].n_samples=}, {stats2._stats['time_series/ts1'].n_samples=}" - ) stats1.merge_stats(stats2) keys = stats1.get_available_statistics() assert "Base_1_1/Zone/field1" in keys @@ -429,8 +429,9 @@ def test_merge_stats_with_same_sizes(self, sample_with_time_series): stat_field = stats1._stats["Base_1_1/Zone/field1"] assert stat_field.n_samples == 2 assert stat_field.n_points == 202 + assert stat_field.n_features == 101 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 101) assert "time_series/ts1" in keys @@ -438,8 +439,9 @@ def test_merge_stats_with_same_sizes(self, sample_with_time_series): stat = stats1._stats["time_series/ts1"] assert stat.n_samples == 2 assert stat.n_points == 20 + assert stat.n_features == 10 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 10) def test_merge_stats_with_different_sizes( @@ -449,9 +451,6 @@ def test_merge_stats_with_different_sizes( stats2 = Stats() stats1.add_samples([sample_with_time_series]) stats2.add_samples([sample_with_time_series_of_different_size]) - print( - f"{stats1._stats['time_series/ts1'].n_samples=}, {stats2._stats['time_series/ts1'].n_samples=}" - ) stats1.merge_stats(stats2) keys = stats1.get_available_statistics() assert "Base_1_1/Zone/field1" in keys @@ -460,7 +459,7 @@ def test_merge_stats_with_different_sizes( assert stat_field.n_samples == 2 assert stat_field.n_points == 152 stats_dict = stat_field.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) assert "time_series/ts1" in keys @@ -469,5 +468,5 @@ def test_merge_stats_with_different_sizes( assert stat.n_samples == 2 assert stat.n_points == 15 stats_dict = stat.get_stats() - assert "mean" in stats_dict + check_stats_dict(stats_dict) assert stats_dict["mean"].shape == (1, 1) From 298de716e8d8571e57a0620f85fcabdff12e3582 Mon Sep 17 00:00:00 2001 From: Xavier Roynard Date: Tue, 1 Jul 2025 09:36:01 +0200 Subject: [PATCH 10/11] (stats) fix internal array dim check --- src/plaid/utils/stats.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index eeac7a9..0ee243a 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -102,7 +102,6 @@ def add_samples(self, x: np.ndarray, n_samples: int = None) -> None: n_samples (int, optional): The number of samples in the input array. If not provided, it will be inferred from the shape of `x`. Use this argument when the input array has already been flattened because of shape inconsistencies. 
Raises:
-            ShapeError: Raised when there is an inconsistency in the shape of the input array.
             ValueError: Raised when input contains NaN or Inf values.
         """
         # Validate input
@@ -478,12 +477,12 @@ def _update_statistics(self, new_data: dict[str, list]) -> None:
             if name not in self._stats:
                 self._stats[name] = OnlineStatistics()
 
-            for i in range(len(list_of_arrays)):
-                if (
-                    isinstance(list_of_arrays[i], np.ndarray)
-                    and list_of_arrays[i].ndim == 1
-                ):
-                    list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1))
+            # internal check; should never fail if the self._process_* functions work correctly
+            for sample_id in range(len(list_of_arrays)):
+                assert isinstance(list_of_arrays[sample_id], np.ndarray)
+                assert list_of_arrays[sample_id].ndim == 2, (
+                    f"for feature <{name}> -> {sample_id=}: {list_of_arrays[sample_id].ndim=} should be 2"
+                )
 
             if self._feature_is_flattened.get(name, False):
                 # flatten all arrays in list_of_arrays

From 7fee15ba6eb7396b24301a5ea86fef5494435513 Mon Sep 17 00:00:00 2001
From: Xavier Roynard
Date: Tue, 1 Jul 2025 10:04:24 +0200
Subject: [PATCH 11/11] (stats) update docstring

---
 src/plaid/utils/stats.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py
index 0ee243a..4de2861 100644
--- a/src/plaid/utils/stats.py
+++ b/src/plaid/utils/stats.py
@@ -276,8 +276,11 @@ def add_dataset(self, dset: Dataset) -> None:
     def add_samples(self, samples: Union[List[Sample], Dataset]) -> None:
         """Add samples or a dataset to compute statistics for.
 
-        Processes both scalar and field data from the provided samples,
-        computing running statistics for each data identifier.
+        Compute stats for each feature present in the samples, among scalars, fields and time_series.
+        For fields and time_series, as long as the added samples have the same shape as the existing ones,
+        the stats are computed per coordinate (n_features=x.shape[-1]).
+        As soon as the shapes differ, the existing stats and the added fields/time_series are flattened
+        (n_features=1), and stats are then computed over all values of the field/time_series.
 
         Args:
             samples (Union[List[Sample], Dataset]): List of samples or dataset to process
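
Note on the documented flattening behaviour: it can be exercised directly on OnlineStatistics. The following is a minimal sketch, not one of the patches above; it assumes the API as introduced by this series (add_samples with the optional n_samples argument, flatten_array), and the random values are only illustrative since only the shapes and counters matter.

import numpy as np

from plaid.utils.stats import OnlineStatistics

stats = OnlineStatistics()

# Two series of the same length: statistics stay per-coordinate.
stats.add_samples(np.random.rand(1, 10))
stats.add_samples(np.random.rand(1, 10))
assert stats.n_samples == 2 and stats.n_points == 20 and stats.n_features == 10
assert stats.get_stats()["mean"].shape == (1, 10)

# A third series of a different length: flatten the accumulated stats first,
# then add the new values as a single column; n_samples is passed explicitly
# because the array has already been reshaped to (-1, 1).
stats.flatten_array()
stats.add_samples(np.random.rand(1, 5).reshape((-1, 1)), n_samples=1)
assert stats.n_samples == 3 and stats.n_points == 25 and stats.n_features == 1
assert stats.get_stats()["mean"].shape == (1, 1)

This is the same 10 + 10 + 5 scenario checked by test_add_samples_time_series_case_5 (n_samples == 3, n_points == 25, mean of shape (1, 1) after flattening); the Stats class performs the flatten-then-reshape step automatically when it detects a shape mismatch.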