diff --git a/examples/utils/stats_example.py b/examples/utils/stats_example.py index a270102..1b032e6 100644 --- a/examples/utils/stats_example.py +++ b/examples/utils/stats_example.py @@ -33,8 +33,8 @@ # %% def sprint(stats: dict): print("Stats:") - for k in stats: - print(" - {} -> {}".format(k, stats[k])) + for k,v in stats.items(): + print(f" - {k:>9} -> {v}") # %% [markdown] @@ -58,7 +58,8 @@ def sprint(stats: dict): # %% # First batch of samples first_batch_samples = 3.0 * np.random.randn(100, 3) + 10.0 -print(f"{first_batch_samples.shape = }") +print() +print(f"=== {first_batch_samples.shape = }") stats_computer.add_samples(first_batch_samples) stats = stats_computer.get_stats() @@ -67,7 +68,8 @@ def sprint(stats: dict): # %% second_batch_samples = 10.0 * np.random.randn(1000, 3) - 1.0 -print(f"{second_batch_samples.shape = }") +print() +print(f"=== {second_batch_samples.shape = }") stats_computer.add_samples(second_batch_samples) stats = stats_computer.get_stats() @@ -79,7 +81,8 @@ def sprint(stats: dict): # %% total_samples = np.concatenate((first_batch_samples, second_batch_samples), axis=0) -print(f"{total_samples.shape = }") +print() +print(f"=== {total_samples.shape = }") new_stats_computer = OnlineStatistics() new_stats_computer.add_samples(total_samples) @@ -110,14 +113,23 @@ def sprint(stats: dict): nb_samples = 11 samples = [Sample() for _ in range(nb_samples)] -spatial_shape_max = 20 +spatial_shape_max = 5 # for sample in samples: sample.add_scalar("test_scalar", np.random.randn()) - sample.init_base(2, 3, "test_base") + sample.add_scalar("test_ND_scalar", np.random.randn(3)) + sample.init_base(2, 3,) zone_shape = np.array([0, 0, 0]) - sample.init_zone(zone_shape, zone_name="test_zone") + sample.init_zone(zone_shape) sample.add_field("test_field", np.random.randn(spatial_shape_max)) + sample.init_zone(zone_shape, zone_name="test_zone_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2') + sample.init_base(2, 3, "test_base_2") + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape, zone_name="test_zone_1", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), base_name="test_base_2") + sample.init_zone(zone_shape, zone_name="test_zone_2", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2', base_name="test_base_2") stats.add_samples(samples) @@ -132,19 +144,33 @@ def sprint(stats: dict): # ### Feed Stats with more Samples # %% -nb_samples = 11 -spatial_shape_max = 20 +nb_samples = 13 +# spatial_shape_max = 20 samples = [Sample() for _ in range(nb_samples)] for sample in samples: sample.add_scalar("test_scalar", np.random.randn()) + sample.init_base(2, 3,) + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape) + sample.add_field("test_field", np.random.randn(spatial_shape_max)) + sample.init_zone(zone_shape, zone_name="test_zone_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2') + sample.init_base(2, 3, "test_base_2") + zone_shape = np.array([0, 0, 0]) + sample.init_zone(zone_shape, zone_name="test_zone_1", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), base_name="test_base_2") + sample.init_zone(zone_shape, zone_name="test_zone_2", base_name="test_base_2") + sample.add_field("test_field", np.random.randn(spatial_shape_max), zone_name='test_zone_2', base_name="test_base_2") + sample.init_base(2, 3, "test_base") 
zone_shape = np.array([0, 0, 0]) - sample.init_zone(zone_shape, zone_name="test_zone") - sample.add_field("test_field_same_size", np.random.randn(7)) + sample.init_zone(zone_shape, zone_name="test_zone", base_name="test_base") + sample.add_field("test_field_same_size", np.random.randn(spatial_shape_max), zone_name="test_zone", base_name="test_base") sample.add_field( "test_field", np.random.randn(np.random.randint(spatial_shape_max // 2, spatial_shape_max)), + zone_name="test_zone", base_name="test_base" ) stats.add_samples(samples) diff --git a/src/plaid/utils/stats.py b/src/plaid/utils/stats.py index 9685328..4de2861 100644 --- a/src/plaid/utils/stats.py +++ b/src/plaid/utils/stats.py @@ -9,14 +9,19 @@ # %% Imports +import copy import logging -from typing import Union +from typing import List, Union + +try: # pragma: no cover + from typing import Self +except ImportError: # pragma: no cover + from typing import Any as Self import numpy as np from plaid.containers.dataset import Dataset from plaid.containers.sample import Sample -from plaid.utils.base import ShapeError logger = logging.getLogger(__name__) logging.basicConfig( @@ -34,12 +39,12 @@ def aggregate_stats( This function calculates aggregated statistics, such as the total number of samples, mean, and variance, by taking into account the statistics computed for each batch of data. - cf: https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques) + cf: [Variance from (cardinal,mean,variance) of several statistical series](https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques)#Formules) Args: - sizes (np.ndarray): An array containing the sizes (number of samples) of each batch. - means (np.ndarray): An array containing the means of each batch. - vars (np.ndarray): An array containing the variances of each batch. + sizes (np.ndarray): An array containing the sizes (number of samples) of each batch. Expect shape (n_batches,1). + means (np.ndarray): An array containing the means of each batch. Expect shape (n_batches, n_features). + vars (np.ndarray): An array containing the variances of each batch. Expect shape (n_batches, n_features). Returns: tuple[np.ndarray,np.ndarray,np.ndarray]: A tuple containing the aggregated statistics in the following order: @@ -47,10 +52,15 @@ def aggregate_stats( - Weighted mean calculated from the batch means. - Weighted variance calculated from the batch variances, considering the means. """ - total_n_samples = np.sum(sizes, keepdims=True) - total_mean = np.sum(sizes * means, keepdims=True) / total_n_samples + assert sizes.ndim == 1 + assert means.ndim == 2 + assert len(sizes) == len(means) + assert means.shape == vars.shape + sizes = sizes.reshape((-1, 1)) + total_n_samples = np.sum(sizes) + total_mean = np.sum(sizes * means, axis=0, keepdims=True) / total_n_samples total_var = ( - np.sum(sizes * (vars + (total_mean - means) ** 2), keepdims=True) + np.sum(sizes * (vars + (total_mean - means) ** 2), axis=0, keepdims=True) / total_n_samples ) return total_n_samples, total_mean, total_var @@ -60,44 +70,76 @@ def aggregate_stats( class OnlineStatistics(object): - """OnlineStatistics is a class for computing online statistics (e.g., min, max, mean, variance, and standard deviation) of numpy arrays.""" + """OnlineStatistics is a class for computing online statistics of numpy arrays. + + This class computes running statistics (min, max, mean, variance, std) for streaming data + without storing all samples in memory. 
+ + Example: + >>> stats = OnlineStatistics() + >>> stats.add_samples(np.array([[1, 2], [3, 4]])) + >>> stats.add_samples(np.array([[5, 6]])) + >>> print(stats.get_stats()['mean']) + [[3. 4.]] + """ def __init__(self) -> None: """Initialize an empty OnlineStatistics object.""" self.n_samples: int = 0 + self.n_features: int = None + self.n_points: int = None self.min: np.ndarray = None self.max: np.ndarray = None self.mean: np.ndarray = None self.var: np.ndarray = None self.std: np.ndarray = None - def add_samples(self, x: np.ndarray) -> None: + def add_samples(self, x: np.ndarray, n_samples: int = None) -> None: """Add samples to compute statistics for. Args: - x (np.ndarray): The input numpy array containing samples data. + x (np.ndarray): The input numpy array containing samples data. Expect 2D arrays with shape (n_samples, n_features). + n_samples (int, optional): The number of samples in the input array. If not provided, it will be inferred from the shape of `x`. Use this argument when the input array has already been flattened because of shape inconsistencies. Raises: - ShapeError: Raised when there is an inconsistency in the shape of the input array. + ValueError: Raised when input contains NaN or Inf values. """ + # Validate input + if not isinstance(x, np.ndarray): + raise TypeError("Input must be a numpy array") + + if np.any(~np.isfinite(x)): + raise ValueError("Input contains NaN or Inf values") + + # Handle 1D arrays if x.ndim == 1: if self.min is not None: - if self.min.size == 1: - # n_samples x 1 + if self.min.shape[1] == 1: x = x.reshape((-1, 1)) else: - # 1 x n_features x = x.reshape((1, -1)) - else: # pragma: no cover - raise ShapeError( - "can't determine if input array with ndim=1, is 1 x n_features or n_samples x 1" - ) + else: + x = x.reshape((-1, 1)) # Default to column vector + + # Handle n-dimensional arrays elif x.ndim > 2: - # suppose last dim is features dim, all previous dims are space - # dims and are aggregated + # if we have array of shape (n_samples, n_points, n_features) + # it will be reshaped to (n_samples * n_points, n_features) x = x.reshape((-1, x.shape[-1])) - added_n_samples = len(x) + if self.n_features is None: + self.n_features = x.shape[1] + + if x.shape[1] != self.n_features: + # it means that stats where previously on a per-point mode, + # but it is no longer possible as the new added samples have a different shape + # so we need to shift the stats to a per-sample mode, and then flatten the stats array + self.flatten_array() + n_samples = x.shape[0] + x = x.reshape((-1, 1)) + + added_n_samples = len(x) if n_samples is None else n_samples + added_n_points = x.size added_min = np.min(x, axis=0, keepdims=True) added_max = np.max(x, axis=0, keepdims=True) added_mean = np.mean(x, axis=0, keepdims=True) @@ -111,53 +153,95 @@ def add_samples(self, x: np.ndarray) -> None: or (self.var is None) ): self.n_samples = added_n_samples + self.n_points = added_n_points self.min = added_min self.max = added_max self.mean = added_mean self.var = added_var else: - self.min = np.min(np.concatenate((self.min, added_min), axis=0), axis=0) - self.max = np.max(np.concatenate((self.max, added_max), axis=0), axis=0) - # new_n_samples = self.n_samples + added_n_samples - # new_mean = ( - # self.n_samples * self.mean + added_n_samples * added_mean - # ) / new_n_samples - self.n_samples, self.mean, self.var = aggregate_stats( - np.concatenate( - [ - self.n_samples + np.zeros(self.mean.shape, dtype=int), - added_n_samples + np.zeros(added_mean.shape, dtype=int), - ] - ), - 
np.concatenate([self.mean, added_mean]), - np.concatenate([self.var, added_var]), + self.min = np.min( + np.concatenate((self.min, added_min), axis=0), axis=0, keepdims=True ) + self.max = np.max( + np.concatenate((self.max, added_max), axis=0), axis=0, keepdims=True + ) + if self.n_features > 1: + # feature not flattened, we are on a per-sample mode + self.n_points += added_n_points + self.n_samples, self.mean, self.var = aggregate_stats( + np.array([self.n_samples, added_n_samples]), + np.concatenate([self.mean, added_mean]), + np.concatenate([self.var, added_var]), + ) + else: + # feature flattened, we are on a per-point mode + self.n_samples += added_n_samples + self.n_points, self.mean, self.var = aggregate_stats( + np.array([self.n_points, added_n_points]), + np.concatenate([self.mean, added_mean]), + np.concatenate([self.var, added_var]), + ) + + self.std = np.sqrt(self.var) - # # cf: https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques) - # self.var = (self.n_samples * (self.var + (new_mean - self.mean)**2) + added_n_samples*(added_var + (new_mean - added_mean)**2)) / new_n_samples - # self.n_samples = new_n_samples - # self.mean = new_mean + def merge_stats(self, other: Self) -> None: + """Merge statistics from another instance. + Args: + other (Self): The other instance to merge statistics from. + """ + if not isinstance(other, self.__class__): + raise TypeError("Can only merge with another instance of the same class") + + if self.n_features != other.n_features: + # flatten both + self.flatten_array() + other = copy.deepcopy(other) + other.flatten_array() + assert self.min.shape == other.min.shape, ( + "Shape mismatch in OnlineStatistics merging" + ) + + self.min = np.min( + np.concatenate((self.min, other.min), axis=0), axis=0, keepdims=True + ) + self.max = np.max( + np.concatenate((self.max, other.max), axis=0), axis=0, keepdims=True + ) + self.n_points += other.n_points + self.n_samples, self.mean, self.var = aggregate_stats( + np.array([self.n_samples, other.n_samples]), + np.concatenate([self.mean, other.mean]), + np.concatenate([self.var, other.var]), + ) self.std = np.sqrt(self.var) def flatten_array(self) -> None: """When a shape incoherence is detected, you should call this function.""" - self.min = np.min(self.min, keepdims=True) - self.max = np.max(self.max, keepdims=True) + self.min = np.min(self.min, keepdims=True).reshape(1, 1) + self.max = np.max(self.max, keepdims=True).reshape(1, 1) + self.n_points = self.n_samples * self.n_features assert self.mean.shape == self.var.shape - self.n_samples, self.mean, self.var = aggregate_stats( - np.zeros(self.mean.shape, dtype=int) + self.n_samples, self.mean, self.var + self.n_points, self.mean, self.var = aggregate_stats( + np.array([self.n_samples] * self.n_features), + self.mean.reshape(-1, 1), + self.var.reshape(-1, 1), ) self.std = np.sqrt(self.var) - def get_stats(self) -> dict[str, np.ndarray]: + self.n_features = 1 + + def get_stats(self) -> dict[str, Union[int, np.ndarray]]: """Get computed statistics. Returns: - dict[str,np.ndarray]: A dictionary containing computed statistics. + dict[str, Union[int, np.ndarray]]: A dictionary containing computed statistics. + The shapes of the arrays depend on the input data and may vary. 
""" return { "n_samples": self.n_samples, + "n_points": self.n_points, + "n_features": self.n_features, "min": self.min, "max": self.max, "mean": self.mean, @@ -166,12 +250,20 @@ def get_stats(self) -> dict[str, np.ndarray]: } -class Stats(object): - """Stats is a class for aggregating and computing statistics for datasets.""" +class Stats: + """Class for aggregating and computing statistics across datasets. + + The Stats class processes both scalar and field data from samples or datasets, + computing running statistics like min, max, mean, variance and standard deviation. + + Attributes: + _stats (dict[str, OnlineStatistics]): Dictionary mapping data identifiers to their statistics + """ def __init__(self): """Initialize an empty Stats object.""" - self._stats = {} + self._stats: dict[str, OnlineStatistics] = {} + self._feature_is_flattened: dict[str, bool] = {} def add_dataset(self, dset: Dataset) -> None: """Add a dataset to compute statistics for. @@ -181,26 +273,38 @@ def add_dataset(self, dset: Dataset) -> None: """ self.add_samples(dset) - def add_samples(self, samples: Union[list[Sample], Dataset]) -> None: - """Add samples (or a dataset) to compute statistics for. + def add_samples(self, samples: Union[List[Sample], Dataset]) -> None: + """Add samples or a dataset to compute statistics for. + + Compute stats for each features present in the samples among scalars, fields and time_series. + For fields and time_series, as long as the added samples have the same shape as the existing ones, + the stats will be computed per-coordinates (n_features=x.shape[-1]). + But as soon as the shapes differ, the stats and added fields/time_series will be flattened (n_features=1), + then stats will be computed over all values of the field/time_series. Args: - samples (Union[list[Sample],Dataset]): The list of samples or a dataset to add. 
+ samples (Union[List[Sample], Dataset]): List of samples or dataset to process + + Raises: + TypeError: If samples is not a List[Sample] or Dataset + ValueError: If a sample contains invalid data """ - # ---# Aggregate - new_data = {} + # Input validation + if not isinstance(samples, (list, Dataset)): + raise TypeError("samples must be a List[Sample] or Dataset") + + # Process each sample + new_data: dict[str, list] = {} + for sample in samples: - # ---# Scalars - for s_name in sample.get_scalar_names(): - if s_name not in new_data: - new_data[s_name] = [] - new_data[s_name].append(sample.get_scalar(s_name)) + # Process scalars + self._process_scalar_data(sample, new_data) - # ---# Fields - # TODO + # Process fields + self._process_field_data(sample, new_data) - # ---# Categorical - # TODO + # ---# Time-Series + self._process_time_series_data(sample, new_data) # ---# SpatialSupport (Meshes) # TODO @@ -208,43 +312,218 @@ def add_samples(self, samples: Union[list[Sample], Dataset]) -> None: # ---# TemporalSupport # TODO - # ---# Process - for name in new_data: - # new_shapes = [value.shape for value in new_data[name] if value.shape!=new_data[name][0].shape] - # has_same_shape = (len(new_shapes)==0) - has_same_shape = True - - if has_same_shape: - new_data[name] = np.array(new_data[name]) - else: # pragma: no cover ### remove "no cover" when "has_same_shape = True" is no longer used - if name in self._stats: - self._stats[name].flatten_array() - new_data[name] = np.concatenate( - [np.ravel(value) for value in new_data[name]] - ) - - if new_data[name].ndim == 1: - new_data[name] = new_data[name].reshape((-1, 1)) + # ---# Categorical + # TODO - if name not in self._stats: - self._stats[name] = OnlineStatistics() + # Update statistics + self._update_statistics(new_data) - self._stats[name].add_samples(new_data[name]) + def get_stats( + self, identifiers: list[str] = None + ) -> dict[str, dict[str, np.ndarray]]: + """Get computed statistics for specified data identifiers. - def get_stats(self) -> dict[str, dict[str, np.ndarray]]: - """Get computed statistics for different data identifiers. + Args: + identifiers (list[str], optional): List of data identifiers to retrieve. + If None, returns statistics for all identifiers. Returns: - dict[str,dict[str,np.ndarray]]: A dictionary containing computed statistics for different data identifiers. + dict[str, dict[str, np.ndarray]]: Dictionary mapping identifiers to their statistics """ + if identifiers is None: + identifiers = self.get_available_statistics() + stats = {} - for identifier in self._stats: - stats[identifier] = {} - for stat_name, stat_value in self._stats[identifier].get_stats().items(): - stats[identifier][stat_name] = np.squeeze(stat_value) + for identifier in identifiers: + if identifier in self._stats: + stats[identifier] = {} + for stat_name, stat_value in ( + self._stats[identifier].get_stats().items() + ): + stats[identifier][stat_name] = stat_value + # stats[identifier][stat_name] = np.squeeze(stat_value) return stats + def get_available_statistics(self) -> list[str]: + """Get list of data identifiers with computed statistics. + + Returns: + list[str]: List of data identifiers + """ + return sorted(self._stats.keys()) + + def clear_statistics(self) -> None: + """Clear all computed statistics.""" + self._stats.clear() + + def merge_stats(self, other: Self) -> None: + """Merge statistics from another Stats object. 
+ + Args: + other (Stats): Stats object to merge with + """ + for name, stats in other._stats.items(): + if name not in self._stats: + self._stats[name] = copy.deepcopy(stats) + else: + self._stats[name].merge_stats(stats) + + def _process_scalar_data(self, sample: Sample, data_dict: dict[str, list]) -> None: + """Process scalar data from a sample. + + Args: + sample (Sample): Sample containing scalar data + data_dict (dict[str, list]): Dictionary to store processed data + """ + for name in sample.get_scalar_names(): + if name not in data_dict: + data_dict[name] = [] + value = sample.get_scalar(name) + if value is not None: + data_dict[name].append(np.array(value).reshape((1, -1))) + + def _process_time_series_data( + self, sample: Sample, data_dict: dict[str, list] + ) -> None: + """Process time series data from a sample. + + Args: + sample (Sample): Sample containing time series data + data_dict (dict[str, list]): Dictionary to store processed data + """ + for name in sample.get_time_series_names(): + timestamps, time_series = sample.get_time_series(name) + timestamps = timestamps.reshape((1, -1)) + time_series = time_series.reshape((1, -1)) + + timestamps_name = f"timestamps/{name}" + time_series_name = f"time_series/{name}" + if timestamps_name not in data_dict: + assert time_series_name not in data_dict + data_dict[timestamps_name] = [] + data_dict[time_series_name] = [] + if timestamps is not None and time_series is not None: + # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] + if len( + data_dict[time_series_name] + ) > 0 and not self._feature_is_flattened.get(time_series_name, False): + prev_shape = data_dict[time_series_name][0].shape + if time_series.shape != prev_shape: + # set this stat as flattened + self._feature_is_flattened[timestamps_name] = True + self._feature_is_flattened[time_series_name] = True + # flatten corresponding stat + if time_series_name in self._stats: + self._stats[time_series_name].flatten_array() + + if self._feature_is_flattened.get(time_series_name, False): + timestamps = timestamps.reshape((-1, 1)) + time_series = time_series.reshape((-1, 1)) + else: + timestamps = timestamps.reshape((1, -1)) + time_series = time_series.reshape((1, -1)) + + data_dict[timestamps_name].append(timestamps) + data_dict[time_series_name].append(time_series) + + def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> None: + """Process field data from a sample. 
+ + Args: + sample (Sample): Sample containing field data + data_dict (dict[str, list]): Dictionary to store processed data + """ + for time in sample.get_all_mesh_times(): + for base_name in sample.get_base_names(time=time): + for zone_name in sample.get_zone_names(base_name=base_name, time=time): + for field_name in sample.get_field_names( + zone_name=zone_name, base_name=base_name, time=time + ): + stat_key = f"{base_name}/{zone_name}/{field_name}" + if stat_key not in data_dict: + data_dict[stat_key] = [] + field = sample.get_field( + field_name, + zone_name=zone_name, + base_name=base_name, + time=time, + ).reshape((1, -1)) + if field is not None: + # check if all previous arrays are the same shape as the new one that will be added to data_dict[stat_key] + if len( + data_dict[stat_key] + ) > 0 and not self._feature_is_flattened.get( + stat_key, False + ): + prev_shape = data_dict[stat_key][0].shape + if field.shape != prev_shape: + # set this stat as flattened + self._feature_is_flattened[stat_key] = True + # flatten corresponding stat + if stat_key in self._stats: + self._stats[stat_key].flatten_array() + + if self._feature_is_flattened.get(stat_key, False): + field = field.reshape((-1, 1)) + + data_dict[stat_key].append(field) + + def _update_statistics(self, new_data: dict[str, list]) -> None: + """Update running statistics with new data. + + Args: + new_data (dict[str, list]): Dictionary containing new data to process + """ + for name, list_of_arrays in new_data.items(): + if len(list_of_arrays) > 0: + if name not in self._stats: + self._stats[name] = OnlineStatistics() + + # internal check, should never happen if self._process_* functions work correctly + for sample_id in range(len(list_of_arrays)): + assert isinstance(list_of_arrays[sample_id], np.ndarray) + assert list_of_arrays[sample_id].ndim == 2, ( + f"for feature <{name}> -> {sample_id=}: {list_of_arrays[sample_id].ndim=} should be 2" + ) + + if self._feature_is_flattened.get(name, False): + # flatten all arrays in list_of_arrays + n_samples = len(list_of_arrays) + for i in range(len(list_of_arrays)): + list_of_arrays[i] = list_of_arrays[i].reshape((-1, 1)) + else: + n_samples = None + + # Convert to numpy array and reshape if needed + data = np.concatenate(list_of_arrays) + assert data.ndim == 2 + + self._stats[name].add_samples(data, n_samples=n_samples) + + # # old version of _update_statistics logic + # for name in new_data: + # # new_shapes = [value.shape for value in new_data[name] if value.shape!=new_data[name][0].shape] + # # has_same_shape = (len(new_shapes)==0) + # has_same_shape = True + + # if has_same_shape: + # new_data[name] = np.array(new_data[name]) + # else: # pragma: no cover ### remove "no cover" when "has_same_shape = True" is no longer used + # if name in self._stats: + # self._stats[name].flatten_array() + # new_data[name] = np.concatenate( + # [np.ravel(value) for value in new_data[name]] + # ) + + # if new_data[name].ndim == 1: + # new_data[name] = new_data[name].reshape((-1, 1)) + + # if name not in self._stats: + # self._stats[name] = OnlineStatistics() + + # self._stats[name].add_samples(new_data[name]) + # TODO : FAIRE DEUX FONCTIONS : # - compute_stats(samples) -> stats # - aggregate_stats(list[stats]) diff --git a/tests/utils/test_stats.py b/tests/utils/test_stats.py index 576c99c..796263c 100644 --- a/tests/utils/test_stats.py +++ b/tests/utils/test_stats.py @@ -10,6 +10,7 @@ import numpy as np import pytest +from plaid.containers.sample import Sample from plaid.utils.stats import 
OnlineStatistics, Stats # %% Fixtures @@ -40,6 +41,11 @@ def np_samples_5(): return np.random.randn(400) +@pytest.fixture() +def np_samples_6(): + return np.random.randn(50, 1) + + @pytest.fixture() def online_stats(): return OnlineStatistics() @@ -50,6 +56,109 @@ def stats(): return Stats() +@pytest.fixture() +def sample_with_scalar(np_samples_3): + s = Sample() + s.add_scalar("foo", float(np_samples_3.mean())) + return s + + +@pytest.fixture() +def sample_with_field(np_samples_6): + s = Sample() + # 1. Initialize the CGNS tree + s.init_tree() + # 2. Create a base and a zone + s.init_base(topological_dim=3, physical_dim=3) + s.init_zone(zone_shape=np.array([np_samples_6.shape[0], 0, 0])) + # 3. Set node coordinates (required for a valid zone) + s.set_nodes(np.zeros((np_samples_6.shape[0], 3))) + # 4. Add a field named "bar" + s.add_field(name="bar", field=np_samples_6) + return s + + +@pytest.fixture() +def field_data(): + return np.random.randn(101) + + +@pytest.fixture() +def field_data_of_different_size(): + return np.random.randn(51) + + +@pytest.fixture() +def time_series_data(): + # 10 time steps, 1 feature + times = np.linspace(0, 1, 10) + values = np.random.randn(10) + return times, values + + +@pytest.fixture() +def time_series_data_of_different_size(): + # 5 time steps, 1 feature + times = np.linspace(0, 1, 5) + values = np.random.randn(5) + return times, values + + +@pytest.fixture() +def sample_with_time_series(time_series_data, field_data): + s = Sample() + times, values = time_series_data + s.add_time_series("ts1", time_sequence=times, values=values) + s.init_base(1, 1) + s.init_zone(np.array([0, 0, 0])) + s.add_field(name="field1", field=field_data) + return s + + +@pytest.fixture() +def sample_with_time_series_of_different_size( + time_series_data_of_different_size, field_data_of_different_size +): + s = Sample() + times, values = time_series_data_of_different_size + s.add_time_series("ts1", time_sequence=times, values=values) + s.init_base(1, 1) + s.init_zone(np.array([0, 0, 0])) + s.add_field(name="field1", field=field_data_of_different_size) + return s + + +# %% Functions + + +def check_stats_dict(stats_dict): + # Check that all expected statistics keys are present + expected_keys = [ + {"name": "mean", "type": np.ndarray, "ndim": 2}, + {"name": "min", "type": np.ndarray, "ndim": 2}, + {"name": "max", "type": np.ndarray, "ndim": 2}, + {"name": "var", "type": np.ndarray, "ndim": 2}, + {"name": "std", "type": np.ndarray, "ndim": 2}, + {"name": "n_samples", "type": (int, np.integer)}, + {"name": "n_points", "type": (int, np.integer)}, + {"name": "n_features", "type": (int, np.integer)}, + ] + for key_info in expected_keys: + key = key_info["name"] + assert key in stats_dict, f"Missing key: {key}" + if "type" in key_info: + assert isinstance(stats_dict[key], key_info["type"]), ( + f"Key '{key}' has wrong type: {type(stats_dict[key])}, expected {key_info['type']}" + ) + if "ndim" in key_info: + assert hasattr(stats_dict[key], "ndim"), ( + f"Key '{key}' does not have 'ndim' attribute" + ) + assert stats_dict[key].ndim == key_info["ndim"], ( + f"Key '{key}' has wrong ndim: {stats_dict[key].ndim}, expected {key_info['ndim']}" + ) + + # %% Tests @@ -69,6 +178,9 @@ def test_add_samples_3(self, online_stats, np_samples_3, np_samples_5): online_stats.min = np_samples_3 online_stats.add_samples(np_samples_5) + def test_add_samples_4(self, online_stats, np_samples_5): + online_stats.add_samples(np_samples_5) + def test_add_samples_already_present(self, online_stats, np_samples_1): 
online_stats.add_samples(np_samples_1) online_stats.add_samples(np_samples_1) @@ -80,7 +192,42 @@ def test_add_samples_and_flatten(self, online_stats, np_samples_1, np_samples_2) def test_get_stats(self, online_stats, np_samples_1): online_stats.add_samples(np_samples_1) - online_stats.get_stats() + stats_dict = online_stats.get_stats() + # Check that all expected statistics keys are present + check_stats_dict(stats_dict) + + def test_invalid_input_type(self, online_stats): + with pytest.raises(TypeError): + online_stats.add_samples([1, 2, 3]) # List instead of ndarray + + def test_nan_inf_input(self, online_stats): + with pytest.raises(ValueError): + online_stats.add_samples(np.array([1, np.nan, 3])) + with pytest.raises(ValueError): + online_stats.add_samples(np.array([1, np.inf, 3])) + + def test_merge_stats(self, np_samples_3, np_samples_4, np_samples_6): + stats1 = OnlineStatistics() + stats2 = OnlineStatistics() + stats1.add_samples(np_samples_3) + stats2.add_samples(np_samples_6) + n_samples_before = stats1.n_samples + n_samples_other = stats2.n_samples + mean_before = stats1.mean.copy() + other_mean = stats2.mean.copy() + stats3 = OnlineStatistics() + stats3.add_samples(np_samples_4) + # do the merging + stats1.merge_stats(stats2) + assert stats1.n_samples == n_samples_before + stats2.n_samples + expected_mean = ( + mean_before * n_samples_before + other_mean * n_samples_other + ) / (n_samples_before + n_samples_other) + assert np.allclose(stats1.mean, expected_mean) + # other merging tests + with pytest.raises(TypeError): + stats1.merge_stats(0.0) + stats1.merge_stats(stats3) class Test_Stats: @@ -95,4 +242,231 @@ def test_add_dataset(self, stats, dataset): def test_get_stats(self, stats, samples): stats.add_samples(samples) - stats.get_stats() + stats_dict = stats.get_stats() + + sample = samples[0] + feature_names = sample.get_scalar_names() + feature_names.extend(sample.get_time_series_names()) + for base_name in sample.get_base_names(): + for zone_name in sample.get_zone_names(base_name=base_name): + for field_name in sample.get_field_names( + zone_name=zone_name, base_name=base_name + ): + feature_names.append(f"{base_name}/{zone_name}/{field_name}") + + for feat_name in feature_names: + assert feat_name in stats_dict, ( + f"Missing {feat_name=}, in {stats_dict.keys()}" + ) + check_stats_dict(stats_dict[feat_name]) + + def test_invalid_input(self, stats): + with pytest.raises(TypeError): + stats.add_samples("invalid") + + def test_empty_samples(self, stats): + stats.add_samples([]) + assert len(stats.get_available_statistics()) == 0 + + def test_merge_stats(self, sample_with_scalar, sample_with_field): + # Create two Stats objects with different samples + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_scalar]) + stats2.add_samples([sample_with_field]) + # Merge stats2 into stats1 + stats1.merge_stats(stats2) + # Both keys should be present + keys = stats1.get_available_statistics() + assert "foo" in keys or "bar" in keys + # Check that statistics are present for merged keys + for key in keys: + s = stats1._stats[key] + assert s.n_samples > 0 + + def test_clear_statistics(self, stats, samples): + stats.add_samples(samples) + stats.clear_statistics() + assert len(stats.get_available_statistics()) == 0 + + def test_add_samples_time_series_case_1(self, sample_with_time_series): + # 1st case: adding time series with same sizes with 2 calls to add_samples + stats1 = Stats() + stats1.add_samples([sample_with_time_series]) + 
stats1.add_samples([sample_with_time_series]) + keys = stats1.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 101) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 10) + + def test_add_samples_time_series_case_2( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + # 2nd case: adding time series with different sizes with 2 calls to add_samples + stats2 = Stats() + stats2.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series_of_different_size]) + keys = stats2.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats2._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats2._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + def test_add_samples_time_series_case_3(self, sample_with_time_series): + # 3rd case: adding time series with same sizes in a single call to add_samples, in empty stats + stats3 = Stats() + stats3.add_samples([sample_with_time_series, sample_with_time_series]) + keys = stats3.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats3._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 101) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats3._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 10) + + def test_add_samples_time_series_case_4( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + # 4th case: adding time series with different sizes in a single call to add_samples, in empty stats + stats4 = Stats() + stats4.add_samples( + [sample_with_time_series, sample_with_time_series_of_different_size] + ) + keys = stats4.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats4._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats4._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + def test_add_samples_time_series_case_5( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + # 5th case: adding time 
series with different sizes in a single call to add_samples, in non-empty stats + stats5 = Stats() + stats5.add_samples([sample_with_time_series]) + stats5.add_samples( + [sample_with_time_series, sample_with_time_series_of_different_size] + ) + keys = stats5.get_available_statistics() + + assert "Base_1_1/Zone/field1" in keys + stat_field = stats5._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 3 + assert stat_field.n_points == 253 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats5._stats["time_series/ts1"] + assert stat.n_samples == 3 + assert stat.n_points == 25 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + def test_merge_stats_with_same_sizes(self, sample_with_time_series): + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series]) + stats1.merge_stats(stats2) + keys = stats1.get_available_statistics() + assert "Base_1_1/Zone/field1" in keys + + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 202 + assert stat_field.n_features == 101 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 101) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 20 + assert stat.n_features == 10 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 10) + + def test_merge_stats_with_different_sizes( + self, sample_with_time_series, sample_with_time_series_of_different_size + ): + stats1 = Stats() + stats2 = Stats() + stats1.add_samples([sample_with_time_series]) + stats2.add_samples([sample_with_time_series_of_different_size]) + stats1.merge_stats(stats2) + keys = stats1.get_available_statistics() + assert "Base_1_1/Zone/field1" in keys + + stat_field = stats1._stats["Base_1_1/Zone/field1"] + assert stat_field.n_samples == 2 + assert stat_field.n_points == 152 + stats_dict = stat_field.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1) + + assert "time_series/ts1" in keys + assert "timestamps/ts1" in keys + stat = stats1._stats["time_series/ts1"] + assert stat.n_samples == 2 + assert stat.n_points == 15 + stats_dict = stat.get_stats() + check_stats_dict(stats_dict) + assert stats_dict["mean"].shape == (1, 1)
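As a quick sanity check of the behaviour documented above, the sketch below exercises aggregate_stats, OnlineStatistics.add_samples and OnlineStatistics.merge_stats on two same-width batches and compares them against plain NumPy on the concatenated data. It is a minimal illustration under stated assumptions: only the public names from src/plaid/utils/stats.py are used, and the batch shapes are arbitrary.

# %%
import numpy as np

from plaid.utils.stats import OnlineStatistics, aggregate_stats

# Two batches with the same number of features: statistics are kept per feature (not flattened)
batch_1 = np.random.randn(100, 3)
batch_2 = np.random.randn(40, 3)
reference = np.concatenate((batch_1, batch_2), axis=0)

# 1) Incremental accumulation should agree with plain NumPy on the concatenated data
online = OnlineStatistics()
online.add_samples(batch_1)
online.add_samples(batch_2)
stats = online.get_stats()
assert stats["mean"].shape == (1, 3)
assert np.allclose(stats["mean"], reference.mean(axis=0, keepdims=True))
assert np.allclose(stats["var"], reference.var(axis=0, keepdims=True))

# 2) Merging two independent accumulators should give the same result
left, right = OnlineStatistics(), OnlineStatistics()
left.add_samples(batch_1)
right.add_samples(batch_2)
left.merge_stats(right)
assert left.n_samples == reference.shape[0]
assert np.allclose(left.mean, reference.mean(axis=0, keepdims=True))

# 3) aggregate_stats pools (size, mean, var) triplets with the weighted-variance formula
sizes = np.array([len(batch_1), len(batch_2)])
means = np.stack([batch_1.mean(axis=0), batch_2.mean(axis=0)])
variances = np.stack([batch_1.var(axis=0), batch_2.var(axis=0)])
_, mean_tot, var_tot = aggregate_stats(sizes, means, variances)
assert np.allclose(mean_tot, reference.mean(axis=0, keepdims=True))
assert np.allclose(var_tot, reference.var(axis=0, keepdims=True))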
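The per-feature versus flattened behaviour described in the Stats.add_samples docstring can be illustrated with a toy sample factory, sketched below. It reuses the same Sample calls as examples/utils/stats_example.py (add_scalar, init_base, init_zone, add_field); the field sizes are arbitrary, and the exact keys reported by get_stats depend on the default base and zone names chosen by Sample.

# %%
import numpy as np

from plaid.containers.sample import Sample
from plaid.utils.stats import Stats


def make_sample(field_size: int) -> Sample:
    """Build a sample with one scalar and one field of the requested size."""
    sample = Sample()
    sample.add_scalar("test_scalar", np.random.randn())
    sample.init_base(2, 3)
    sample.init_zone(np.array([0, 0, 0]))
    sample.add_field("test_field", np.random.randn(field_size))
    return sample


stats = Stats()
stats.add_samples([make_sample(20), make_sample(20)])
# Same field size so far: the field entry keeps one statistic per point (mean shape (1, 20))
print(stats.get_available_statistics())
print(stats.get_stats())

stats.add_samples([make_sample(7)])
# A field of a different size flattens that entry: statistics are then computed
# over all values of the field (mean shape (1, 1))
print(stats.get_stats())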