+        "<!-- scikit-learn estimator HTML diagram (interactive repr with full parameter tables) omitted; see the text/plain representation of the fitted model below -->"
+ ],
+ "text/plain": [
+ "TransformedTargetRegressor(check_inverse=False,\n",
+ " regressor=Pipeline(steps=[('preprocessor',\n",
+ " ColumnTransformer(remainder='passthrough',\n",
+ " transformers=[('pca',\n",
+ " PCA(n_components=8),\n",
+ " [0,\n",
+ " 1,\n",
+ " 2,\n",
+ " 3,\n",
+ " 4,\n",
+ " 5,\n",
+ " 6,\n",
+ " 7])])),\n",
+ " ('scaler',\n",
+ " StandardScaler()),\n",
+ " ('regressor',\n",
+ " GaussianProcessRegressor(n_restarts_optimizer=3))]),\n",
+ " transformer=Pipeline(steps=[('scaler',\n",
+ " MinMaxScaler()),\n",
+ " ('pca',\n",
+ " PCA(n_components=9))]))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.fit(X,y)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "plaid_dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py
index 8641f69..fe3a553 100644
--- a/src/plaid/containers/dataset.py
+++ b/src/plaid/containers/dataset.py
@@ -17,6 +17,7 @@
Self = TypeVar("Self")
+import copy
import logging
import os
import shutil
@@ -139,7 +140,7 @@ def get_samples(
return {id: self._samples[id] for id in ids}
def add_sample(self, sample: Sample, id: int = None) -> int:
- """Add a new :class:`Sample ` to the :class:`Dataset .`.
+ """Add a new :class:`Sample ` to the :class:`Dataset `.
Args:
sample (Sample): The sample to add.
@@ -351,6 +352,28 @@ def get_scalar_names(self, ids: list[int] = None) -> list[str]:
scalars_names.sort()
return scalars_names
+ # -------------------------------------------------------------------------#
+ def get_time_series_names(self, ids: list[int] = None) -> list[str]:
+ """Return union of time series names in all samples with id in ids.
+
+ Args:
+ ids (list[int], optional): Select time_series depending on sample id. If None, take all samples. Defaults to None.
+
+ Returns:
+ list[str]: List of all time_series names
+ """
+ if ids is not None and len(set(ids)) != len(ids):
+ logger.warning("Provided ids are not unique")
+
+ time_series_names = []
+ for sample in self.get_samples(ids, as_list=True):
+ s_names = sample.get_time_series_names()
+ for s_name in s_names:
+ if s_name not in time_series_names:
+ time_series_names.append(s_name)
+ time_series_names.sort()
+ return time_series_names
+
# -------------------------------------------------------------------------#
def get_field_names(
self, ids: list[int] = None, zone_name: str = None, base_name: str = None
@@ -449,7 +472,9 @@ def get_scalars_to_tabular(
named_tabular = {}
for s_name in scalar_names:
- res = np.empty(nb_samples)
+ first_scalar = self[sample_ids[0]].get_scalar(s_name)
+ s_dtype = first_scalar.dtype if first_scalar is not None else None
+ res = np.empty(nb_samples, dtype=s_dtype)
res.fill(None)
for i_, id in enumerate(sample_ids):
val = self[id].get_scalar(s_name)
@@ -458,9 +483,287 @@ def get_scalars_to_tabular(
named_tabular[s_name] = res
if as_nparray:
- named_tabular = np.array(list(named_tabular.values())).T
+ named_tabular = np.array(
+ [named_tabular[s_name] for s_name in scalar_names]
+ ).T
+ return named_tabular
+
+ # -------------------------------------------------------------------------#
+ def add_tabular_time_series(
+ self, tabular: np.ndarray, names: list[str] = None
+ ) -> None:
+        """Add tabular time_series data to the dataset.
+
+ Args:
+            tabular (np.ndarray): A 4D NumPy array containing tabular time_series data, with shape (n_samples, max_n_timestamps, n_features, 2). The last dimension should contain (timestamp, value) pairs. Timestamps should be in strictly increasing order, and padded with decreasing values so that all features have the same length.
+ names (list[str], optional): A list of column names for the tabular data. Defaults to None.
+
+ Raises:
+            ShapeError: Raised if the input tabular array does not have the correct number of dimensions (4D).
+ ShapeError: Raised if the number of columns in the tabular data does not match the number of names provided.
+
+ Note:
+ If no names are provided, it will automatically create names based on the pattern 'X{number}'
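+
+        Example:
+            A minimal illustrative sketch (hypothetical values; two samples and two time
+            series named 'X0' and 'X1', the second one being shorter and padded):
+
+            >>> import numpy as np
+            >>> from plaid.containers.dataset import Dataset
+            >>> tabular = np.zeros((2, 3, 2, 2))
+            >>> # last axis holds (timestamp, value) pairs
+            >>> tabular[0, :, 0] = [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)]
+            >>> tabular[0, :, 1] = [(0.0, 5.0), (1.0, 6.0), (0.0, np.nan)]  # padded tail
+            >>> tabular[1, :, 0] = [(0.0, 1.5), (1.0, 2.5), (2.0, 3.5)]
+            >>> tabular[1, :, 1] = [(0.0, 4.0), (1.0, 5.0), (0.0, np.nan)]
+            >>> dataset = Dataset()
+            >>> dataset.add_tabular_time_series(tabular, names=["X0", "X1"])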
+ """
+ nb_samples = len(tabular)
+
+ if tabular.ndim != 4:
+ raise ShapeError(f"{tabular.ndim=}!=4, should be == 4")
+ if names is None:
+ names = [f"X{i}" for i in range(tabular.shape[2])]
+ if tabular.shape[2] != len(names):
+ raise ShapeError(
+ f"tabular's 3rd dimension should have same size as there are names, but {tabular.shape[2]=} and {len(names)=}"
+ )
+
+ # ---# For efficiency, first add values to storage
+ name_to_ids = {}
+ for i_col, name in enumerate(names):
+ name_to_ids[name] = tabular[:, :, i_col]
+
+ # ---# Then add data in sample
+ for i_samp in range(nb_samples):
+ sample = Sample()
+ for i_name, name in enumerate(names):
+ timestamps = name_to_ids[name][i_samp, :, 0]
+ values = name_to_ids[name][i_samp, :, 1]
+ # detect first decreasing timestamp
+ diff_ts = timestamps[1:] - timestamps[:-1]
+ decreasing_ts_ids = (diff_ts <= 0).nonzero()[0]
+ if len(decreasing_ts_ids) > 0:
+ first_decreasing_ts_id = decreasing_ts_ids[0]
+ else:
+ first_decreasing_ts_id = len(timestamps) - 1
+ # add only values and timestamps up to the first decreasing timestamp
+ sample.add_time_series(
+ name,
+ timestamps[: first_decreasing_ts_id + 1],
+ values[: first_decreasing_ts_id + 1],
+ )
+ self.add_sample(sample)
+
+ def get_time_series_to_tabular(
+ self,
+ time_series_names: list[str] = None,
+ sample_ids: list[int] = None,
+ as_nparray=False,
+ ) -> Union[dict[str, np.ndarray], np.ndarray]:
+ """Return a dict containing time_series values as tabulars/arrays.
+
+ Args:
+            time_series_names (list[str], optional): time_series to work on. If None, all time_series will be returned. Defaults to None.
+ sample_ids (list[int], optional): Filter by sample id. If None, take all samples. Defaults to None.
+ as_nparray (bool, optional): If True, return the data as a single numpy ndarray. If False, return a dictionary mapping time_series names to their respective tabular values. Defaults to False.
+
+ Returns:
+ np.ndarray: if as_nparray is True.
+ dict[str,np.ndarray]: if as_nparray is False, time_series name -> tabular values.
+
+ Note:
+            Unlike the `get_fields_to_tabular` method, this method does not require all time_series in all samples to have the same length.
+            It will pad shorter time_series in the following way:
+ - timestamps with the last known timestamp - 1, to trigger decreasing behavior
+ - values with NaNs
+ It is the users' responsibility to retrieve the data in a way that is compatible with this padding strategy.
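+
+        Example:
+            A minimal sketch (hypothetical time series name 'X0'; assumes every selected
+            sample holds that time series):
+
+            >>> tabs = dataset.get_time_series_to_tabular(["X0"])
+            >>> # tabs["X0"] has shape (n_samples, max_n_timestamps, 2); shorter series are
+            >>> # padded with (last timestamp - 1, NaN) entries as described above.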
+ """
+ if time_series_names is None:
+ time_series_names = self.get_time_series_names(sample_ids)
+ elif len(set(time_series_names)) != len(time_series_names):
+ logger.warning("Provided time_series names are not unique")
+
+ if sample_ids is None:
+ sample_ids = self.get_sample_ids()
+ elif len(set(sample_ids)) != len(sample_ids):
+ logger.warning("Provided sample ids are not unique")
+ nb_samples = len(sample_ids)
+
+ named_tabular = {}
+ for s_name in time_series_names:
+ _, first_time_series = self[sample_ids[0]].get_time_series(s_name)
+ t_dtype = first_time_series.dtype if first_time_series is not None else None
+ ts_list, val_list = [], []
+ for i_, id in enumerate(sample_ids):
+ ts, val = self[id].get_time_series(s_name)
+ ts_list.append(ts)
+ val_list.append(val)
+ max_n_timestamps = max(len(ts) for ts in ts_list)
+ res = np.empty((nb_samples, max_n_timestamps, 2), dtype=t_dtype)
+ res.fill(None)
+ for i_, (ts, val) in enumerate(zip(ts_list, val_list)):
+                t_len = len(ts) if ts is not None else 0
+ if ts is not None:
+ res[i_, :t_len, 0] = ts
+ if t_len < max_n_timestamps:
+ res[i_, t_len:, 0] = ts[-1] - 1
+ if val is not None:
+ res[i_, :t_len, 1] = val
+ named_tabular[s_name] = res
+
+ if as_nparray:
+ res_list = [named_tabular[s_name] for s_name in time_series_names]
+ max_len = max(arr.shape[1] for arr in res_list)
+            # pad shorter time_series along the timestamp axis, following the padding
+            # strategy described in the docstring (timestamps: last timestamp - 1, values: NaN)
+            for i, arr in enumerate(res_list):
+                if arr.shape[1] < max_len:
+                    padded = np.empty((arr.shape[0], max_len, 2), dtype=arr.dtype)
+                    padded[:, : arr.shape[1]] = arr
+                    padded[:, arr.shape[1] :, 0] = arr[:, -1:, 0] - 1
+                    padded[:, arr.shape[1] :, 1] = np.nan
+                    res_list[i] = padded
+            # stack along a new feature axis -> (n_samples, max_n_timestamps, n_features, 2)
+            named_tabular = np.stack(res_list, axis=2)
+ return named_tabular
+
+ # -------------------------------------------------------------------------#
+ def add_tabular_fields(self, tabular: np.ndarray, names: list[str] = None) -> None:
+ """Add tabular field data to the dataset.
+
+ Args:
+ tabular (np.ndarray): A 2D NumPy array containing tabular field data.
+ names (list[str], optional): A list of column names for the tabular data. Defaults to None.
+
+ Raises:
+ ShapeError: Raised if the input tabular array does not have the correct shape (2D).
+ ShapeError: Raised if the number of columns in the tabular data does not match the number of names provided.
+
+ Note:
+ If no names are provided, it will automatically create names based on the pattern 'X{number}'
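+
+        Example:
+            A minimal sketch (hypothetical field names; three samples, two fields):
+
+            >>> import numpy as np
+            >>> from plaid.containers.dataset import Dataset
+            >>> dataset = Dataset()
+            >>> dataset.add_tabular_fields(np.arange(6.0).reshape((3, 2)), names=["f0", "f1"])
+            >>> len(dataset)
+            3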
+ """
+ nb_samples = len(tabular)
+
+ if tabular.ndim != 2:
+ raise ShapeError(f"{tabular.ndim=}!=2, should be == 2")
+ if names is None:
+ names = [f"X{i}" for i in range(tabular.shape[1])]
+ if tabular.shape[1] != len(names):
+ raise ShapeError(
+ f"tabular should have as many columns as there are names, but {tabular.shape[1]=} and {len(names)=}"
+ )
+
+ # ---# For efficiency, first add values to storage
+ name_to_ids = {}
+ for col, name in zip(tabular.T, names):
+ name_to_ids[name] = col
+
+ # ---# Then add data in sample
+ for i_samp in range(nb_samples):
+ sample = Sample()
+ for name in names:
+ sample.add_field(name, name_to_ids[name][i_samp])
+ self.add_sample(sample)
+
+ def get_fields_to_tabular(
+ self,
+ field_names: list[str] = None,
+ sample_ids: list[int] = None,
+ as_nparray=False,
+ ) -> Union[dict[str, np.ndarray], np.ndarray]:
+ """Return a dict containing field values as tabulars/arrays.
+
+ Args:
+            field_names (list[str], optional): fields to work on. If None, all fields will be returned. Defaults to None.
+ sample_ids (list[int], optional): Filter by sample id. If None, take all samples. Defaults to None.
+ as_nparray (bool, optional): If True, return the data as a single numpy ndarray. If False, return a dictionary mapping field names to their respective tabular values. Defaults to False.
+
+ Returns:
+ np.ndarray: if as_nparray is True.
+ dict[str,np.ndarray]: if as_nparray is False, field name -> tabular values.
+
+ Note:
+            This method won't work if the fields do not have the same size in all samples specified by `sample_ids`.
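+
+        Example:
+            A minimal sketch (hypothetical field name 'f0'):
+
+            >>> tabs = dataset.get_fields_to_tabular(["f0"])
+            >>> # tabs["f0"] has shape (n_samples, n_points, n_components)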
+ """
+ if field_names is None:
+ field_names = self.get_field_names(sample_ids)
+ elif len(set(field_names)) != len(field_names):
+ logger.warning("Provided field names are not unique")
+
+ if sample_ids is None:
+ sample_ids = self.get_sample_ids()
+ elif len(set(sample_ids)) != len(sample_ids):
+ logger.warning("Provided sample ids are not unique")
+ nb_samples = len(sample_ids)
+
+ named_tabular = {}
+ for f_name in field_names:
+ first_field = self[sample_ids[0]].get_field(f_name)
+ if first_field is not None:
+ f_dtype = first_field.dtype
+ nb_points = first_field.shape[0]
+ if len(first_field.shape) == 1:
+ field_size = 1
+ elif len(first_field.shape) == 2:
+ field_size = first_field.shape[1]
+ else:
+ raise ShapeError(
+ f"Expects field as a 2-dim array, but field {f_name} from sample {sample_ids[0]} has shape: {first_field.shape}"
+ )
+            else:
+                raise ValueError(
+                    f"Field {f_name} of sample {sample_ids[0]} is None, cannot infer the tabular shape"
+                )
+ res = np.empty((nb_samples, nb_points, field_size), dtype=f_dtype)
+ # print(f"{nb_points=}")
+ # print(f"{field_size=}")
+ # print(f"{res.shape=}")
+ res.fill(None)
+ for i_, id in enumerate(sample_ids):
+ val = self[id].get_field(f_name)
+ # print(f"{val.shape=}")
+ if val is not None:
+                    if not (val.shape[0] == nb_points):
+                        raise ShapeError(
+                            f"Field {f_name} of sample {id} has {val.shape[0]} points, expected {nb_points}"
+                        )
+                    if len(val.shape) == 2 and not (val.shape[1] == field_size):
+                        raise ShapeError(
+                            f"Field {f_name} of sample {id} has {val.shape[1]} components, expected {field_size}"
+                        )
+ res[i_] = val.reshape((nb_points, field_size))
+ named_tabular[f_name] = res
+
+ if as_nparray:
+ all_tabs = [named_tabular[f_name] for f_name in field_names]
+            if all(t.shape[2] == all_tabs[0].shape[2] for t in all_tabs):
+ named_tabular = np.stack(all_tabs, axis=2)
+ else:
+ named_tabular = np.concatenate(all_tabs, axis=2)
return named_tabular
+ # -------------------------------------------------------------------------#
+ def extract_dataset(
+ self,
+ scalars: list[str] = [],
+ fields: list[str] = [],
+ time_series: list[str] = [],
+ ) -> Self:
+ """Extract a subset of the dataset containing only the specified scalars, fields, and time series.
+
+ Args:
+ scalars (list[str], optional): List of scalar names to include. Defaults to [].
+ fields (list[str], optional): List of field names to include. Defaults to [].
+ time_series (list[str], optional): List of time series names to include. Defaults to [].
+
+ Returns:
+ Self: A new dataset containing only the specified scalars, fields, and time series.
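+
+        Example:
+            A minimal sketch (hypothetical feature names):
+
+            >>> sub = dataset.extract_dataset(scalars=["p1"], time_series=["ts0"])
+            >>> # `sub` keeps the same sample ids, but each sample only carries the scalar
+            >>> # 'p1' and the time series 'ts0'.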
+ """
+ dataset = Dataset()
+
+ for id, sample in self.get_samples().items():
+ new_sample = Sample()
+
+ for scalar_name in scalars:
+ new_sample.add_scalar(scalar_name, sample.get_scalar(scalar_name))
+ for time_series_name in time_series:
+                timestamps, values = sample.get_time_series(time_series_name)
+                new_sample.add_time_series(time_series_name, timestamps, values)
+ # TODO: extract only specified fields --> WON’T WORK: there is no Base/Zone specified
+ # TODO: use field names of type '//' with optional zone/base names
+ for field_name in fields:
+ new_sample.add_field(field_name, sample.get_field(field_name))
+
+ dataset.add_sample(new_sample, id)
+
+ return dataset
+
# -------------------------------------------------------------------------#
def add_info(self, cat_key: str, info_key: str, info: str) -> None:
"""Add information to the :class:`Dataset `, overwriting existing information if there's a conflict.
@@ -605,7 +908,7 @@ def get_infos(self) -> dict[str, dict[str, str]]:
return self._infos
def print_infos(self) -> None:
- """Prints information in a readable format (pretty print)."""
+ """Print information in a readable format (pretty print)."""
infos_cats = list(self._infos.keys())
tf = "*********************** \x1b[34;1mdataset infos\x1b[0m **********************\n"
for cat in infos_cats:
@@ -624,7 +927,7 @@ def print_infos(self) -> None:
# -------------------------------------------------------------------------#
def merge_dataset(self, dataset: Self) -> list[int]:
- """Merges another Dataset into this one.
+ """Merge another Dataset into this one.
Args:
dataset (Dataset): The data set to be merged into this one (self).
@@ -641,9 +944,42 @@ def merge_dataset(self, dataset: Self) -> list[int]:
raise ValueError("dataset must be an instance of Dataset")
return self.add_samples(dataset.get_samples(as_list=True))
+ def merge_samples(self, dataset: Self) -> list[int]:
+ """Merge Samples of another dataset into samples of this one.
+
+ Args:
+            dataset (Self): The dataset whose samples will be merged into those of this one (self).
+
+ Returns:
+ list[int]: ids of added :class:`Samples ` from input :class:`Dataset ` that were not already present in this dataset (self).
+
+ Raises:
+ ValueError: If the provided dataset value is not an instance of Dataset
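+
+        Example:
+            A minimal sketch (hypothetical scalar name; `predictions` holds samples with the
+            same ids as `dataset` but only a 'y_pred' scalar):
+
+            >>> new_ids = dataset.merge_samples(predictions)
+            >>> # existing samples of `dataset` now also carry 'y_pred'; `new_ids` lists the
+            >>> # ids that were not already present and were added as new samples.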
+ """
+ if not isinstance(dataset, Dataset):
+ raise ValueError("dataset must be an instance of Dataset")
+ trg_samples = self.get_samples()
+ new_ids = []
+ for samp_id, samp in dataset.get_samples().items():
+ if samp_id in trg_samples:
+ for scalar_name in samp.get_scalar_names():
+ trg_samples[samp_id].add_scalar(
+ scalar_name, samp.get_scalar(scalar_name)
+ )
+ for time_series_name in samp.get_time_series_names():
+                    timestamps, values = samp.get_time_series(time_series_name)
+                    trg_samples[samp_id].add_time_series(
+                        time_series_name, timestamps, values
+                    )
+ trg_samples[samp_id].add_tree(samp.get_tree())
+ else:
+ # TODO: should we copy the sample before adding it ?
+ self.add_sample(copy.deepcopy(samp), id=samp_id)
+ new_ids.append(samp_id)
+ return new_ids
+
# -------------------------------------------------------------------------#
def save(self, fname: Union[str, Path]) -> None:
- """Saves the data set to a TAR (Tape Archive) file.
+ """Save the data set to a TAR (Tape Archive) file.
It creates a temporary intermediate directory to store temporary files during the loading process.
@@ -1002,11 +1338,11 @@ def __len__(self) -> int:
"""
return len(self._samples)
- def __getitem__(self, id: int) -> Sample:
+ def __getitem__(self, id: Union[int, slice]) -> Union[Sample, Self]:
"""Retrieve a specific sample by its ID int this dataset.
Args:
- id (int): The ID of the sample to retrieve.
+ id (Union[int,slice]): The ID of the sample to retrieve.
Raises:
IndexError: If the provided ID is out of bounds or does not exist in the dataset.
@@ -1024,12 +1360,28 @@ def __getitem__(self, id: int) -> Sample:
Seealso:
This function can also be called using `__call__()`.
"""
- if id in self._samples:
- return self._samples[id]
+ if isinstance(id, (int, np.integer)):
+ if id in self._samples:
+ return self._samples[id]
+ else:
+ raise IndexError(
+ f"sample with {id=} not set -> use 'Dataset.add_sample' or 'Dataset.add_samples'"
+ )
else:
- raise IndexError(
- f"sample with {id=} not set -> use 'Dataset.add_sample' or 'Dataset.add_samples'"
- )
+ if isinstance(id, slice):
+ # TODO: check slice.stop is positive, if negative use len(dataset)+slice.stop
+                ids = np.arange(id.start or 0, id.stop, id.step or 1)
+ else:
+ raise TypeError(
+ f"Unsupported index type: {type(id)}, should be int or slice"
+ )
+ samples = []
+ for id in ids:
+ if id in self._samples:
+ samples.append(self._samples[id])
+ dset = Dataset()
+ dset.add_samples(samples)
+ return dset
__call__ = __getitem__
diff --git a/src/plaid/containers/sample.py b/src/plaid/containers/sample.py
index 838c413..8654762 100644
--- a/src/plaid/containers/sample.py
+++ b/src/plaid/containers/sample.py
@@ -131,7 +131,7 @@ def read_index_range(pyTree: list, dim: list[int]):
class Sample(BaseModel):
- """Represents a single sample. It contains data and information related to a single observation or measurement within a dataset."""
+ """Represent a single sample. It contains data and information related to a single observation or measurement within a dataset."""
def __init__(
self,
diff --git a/src/plaid/wrappers/__init__.py b/src/plaid/wrappers/__init__.py
new file mode 100644
index 0000000..dbcab68
--- /dev/null
+++ b/src/plaid/wrappers/__init__.py
@@ -0,0 +1,7 @@
+"""Wrapper functions for the PLAID library."""
+
+# -*- coding: utf-8 -*-
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+#
diff --git a/src/plaid/wrappers/sklearn.py b/src/plaid/wrappers/sklearn.py
new file mode 100644
index 0000000..6c54c07
--- /dev/null
+++ b/src/plaid/wrappers/sklearn.py
@@ -0,0 +1,496 @@
+# -*- coding: utf-8 -*-
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+#
+#
+"""This module provides wrappers for scikit-learn estimators and transformers so they can be used seamlessly in scikit-learn Pipelines
+with PLAID objects. The wrapped blocks (e.g. PCA, GaussianProcessRegressor, StandardScaler, etc.) take a `plaid.containers.Dataset` as input,
+and return a `plaid.containers.Dataset` as output. This allows you to build
+scikit-learn Pipelines where all blocks operate on PLAID objects, enabling end-to-end workflows with domain-specific data structures.
+
+Example usage:
+
+ from sklearn.pipeline import Pipeline
+ from plaid.wrappers.sklearn import WrappedSklearnTransform, WrappedSklearnRegressor
+ from sklearn.decomposition import PCA
+ from sklearn.gaussian_process import GaussianProcessRegressor
+ from plaid.containers.dataset import Dataset
+
+ # Define your PLAID dataset
+ dataset = Dataset(...)
+
+ # Build a pipeline with wrapped sklearn blocks
+ pipe = Pipeline([
+ ("pca", WrappedSklearnTransform(PCA(n_components=2))),
+ ("reg", WrappedSklearnRegressor(GaussianProcessRegressor()))
+ ])
+
+ # Fit the pipeline (all steps receive and return Dataset objects)
+ pipe.fit(dataset)
+ # Predict
+ y_pred = pipe.predict(dataset)
+
+All wrapped blocks must accept and return PLAID Dataset objects.
+
+Some inspiration comes from [TensorDict](https://pytorch.org/tensordict/stable/reference/generated/tensordict.nn.TensorDictModule.html#tensordict.nn.TensorDictModule).
+
+This module defines the following classes:
+`PlaidWrapper`: Base class for scikit-learn estimators and transformers to operate on PLAID objects.
+├── `WrappedSklearnTransform`: Wrapper for scikit-learn Transformer blocks.
+└── `WrappedSklearnPredictor`: Wrapper for scikit-learn Predictor blocks.
+ ├── `WrappedSklearnClassifier`: Wrapper for scikit-learn Classifier blocks.
+ └── `WrappedSklearnRegressor`: Wrapper for scikit-learn Regressor blocks.
+"""
+
+# %% Imports
+import sys
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else: # pragma: no cover
+ from typing import TypeVar
+
+ Self = TypeVar("Self")
+
+import logging
+from copy import copy
+from typing import Union
+
+import numpy as np
+from sklearn.base import (
+ BaseEstimator,
+ BiclusterMixin,
+ ClassifierMixin,
+ ClusterMixin,
+ DensityMixin,
+ MetaEstimatorMixin,
+ MultiOutputMixin,
+ OutlierMixin,
+ RegressorMixin,
+ TransformerMixin,
+)
+
+from plaid.containers.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+ format="[%(asctime)s:%(levelname)s:%(filename)s:%(funcName)s(%(lineno)d)]:%(message)s",
+ level=logging.INFO,
+)
+
+# %% Classes
+
+SklearnBlock = Union[
+ BaseEstimator,
+ TransformerMixin,
+ RegressorMixin,
+ ClassifierMixin,
+ ClusterMixin,
+ BiclusterMixin,
+ DensityMixin,
+ OutlierMixin,
+ MultiOutputMixin,
+]
+"""Union type for all scikit-learn blocks that can be used in a Pipeline."""
+
+
+class PlaidWrapper(BaseEstimator, MetaEstimatorMixin):
+ """Base wrapper for scikit-learn estimators and transformers to operate on PLAID objects.
+
+ This class is not intended to be used directly, but as a base for wrappers that allow scikit-learn blocks
+ (such as PCA, StandardScaler, GaussianProcessRegressor, etc.) to be used in sklearn Pipelines with PLAID objects.
+ All methods accept and return `plaid.containers.Dataset` objects.
+ """
+
+ def __init__(
+ self,
+ sklearn_block: SklearnBlock,
+ fit_only_ones: bool = True,
+ in_features: Union[list[str], str] = [],
+ out_features: Union[list[str], str] = [],
+ ):
+ """Wrap a scikit-learn estimator or transformer.
+
+ Args:
+ sklearn_block (SklearnBlock): Any scikit-learn transform or predictor (e.g. PCA, StandardScaler, GPRegressor).
+ fit_only_ones (bool, optional): If True, the model will only be fitted once. Defaults to True.
+ in_features (Union[list[str],str], optional):
+ Names of scalars and/or fields to take as input.
+ Scalar (resp. field) names should be given as 'scalar::' (resp. 'field::').
+ Use 'all' to use all available scalars and fields, or 'scalar::all'/'field::all' for all scalars/fields.
+ Defaults to [].
+            out_features (Union[list[str],str], optional): Names of scalars and/or fields to take as output, using the same convention as for `in_features`.
+                Additionally, if 'same', 'scalar::same' or 'field::same' is given, the output names are the same as the input names. Defaults to [].
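+
+        Example:
+            An illustrative sketch (hypothetical scalar names 'p1' and 'p2'):
+
+            >>> from sklearn.preprocessing import StandardScaler
+            >>> block = WrappedSklearnTransform(
+            ...     StandardScaler(),
+            ...     in_features=["scalar::p1", "scalar::p2"],
+            ...     out_features="same",  # write the scaled values back under the same names
+            ... )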
+ """
+ # TODO: check https://scikit-learn.org/stable/developers/develop.html#instantiation
+ self.sklearn_block = sklearn_block
+ self.fit_only_ones = fit_only_ones
+ self.in_features = copy(in_features)
+ self.out_features = copy(out_features)
+
+ def fit(self, dataset: Dataset, **kwargs):
+ """Fit the wrapped scikit-learn model on a PLAID dataset.
+
+ Args:
+ dataset (Dataset): The dataset to fit the model on.
+
+ Returns:
+ self: Returns self for chaining.
+ """
+ if self.fit_only_ones and self.__sklearn_is_fitted__():
+ return self
+
+ self._determine_input_output_names(dataset)
+
+ X, y = self._extract_X_y_from_plaid(dataset)
+ self.sklearn_block.fit(X, y, **kwargs)
+
+ self._is_fitted = True
+ return self
+
+ def _determine_input_output_names(self, dataset: Dataset):
+ """Determine the input/output names based on the in_features and out_features."""
+ self.in_features = (
+ self.in_features
+ if isinstance(self.in_features, list)
+ else [self.in_features]
+ )
+ self.out_features = (
+ self.out_features
+ if isinstance(self.out_features, list)
+ else [self.out_features]
+ )
+ self._determine_scalar_names(dataset)
+ self._determine_time_series_names(dataset)
+ self._determine_field_names(dataset)
+
+ def _determine_scalar_names(self, dataset: Dataset):
+ """Determine the input/output scalar names based on the in_features and out_features."""
+ # Input scalars
+ if ("all" in self.in_features) or ("scalar::all" in self.in_features):
+ self.input_scalars = dataset.get_scalar_names()
+ else:
+ self.input_scalars = [
+ s[8:] for s in self.in_features if s[:8] == "scalar::"
+ ]
+
+ # Output scalars
+ if ("all" in self.out_features) or ("scalar::all" in self.out_features):
+ assert "same" not in self.out_features
+ assert "scalar::same" not in self.out_features
+ self.output_scalars = dataset.get_scalar_names()
+ elif ("same" in self.out_features) or ("scalar::same" in self.out_features):
+ self.output_scalars = self.input_scalars
+ else:
+ self.output_scalars = [
+ s[8:] for s in self.out_features if s[:8] == "scalar::"
+ ]
+
+ def _determine_time_series_names(self, dataset: Dataset):
+ """Determine the input/output time_series names based on the in_features and out_features."""
+ # Input time_series
+ if ("all" in self.in_features) or ("time_series::all" in self.in_features):
+ self.input_time_series = dataset.get_time_series_names()
+ else:
+ self.input_time_series = [
+                s[13:] for s in self.in_features if s[:13] == "time_series::"
+ ]
+
+ # Output time_series
+ if ("all" in self.out_features) or ("time_series::all" in self.out_features):
+ assert "same" not in self.out_features
+ assert "time_series::same" not in self.out_features
+ self.output_time_series = dataset.get_time_series_names()
+ elif ("same" in self.out_features) or (
+ "time_series::same" in self.out_features
+ ):
+ self.output_time_series = self.input_time_series
+ else:
+ self.output_time_series = [
+                s[13:] for s in self.out_features if s[:13] == "time_series::"
+ ]
+
+ def _determine_field_names(self, dataset: Dataset):
+ """Determine the input/output field names based on the in_features and out_features."""
+ # default_time = dataset[dataset.get_sample_ids()[0]].get_time_assignment()
+ # default_base = dataset[dataset.get_sample_ids()[0]].get_base_assignment(time=default_time)
+ # has_no_default_base = (default_base is None)
+ # if has_no_default_base:
+ # default_zone = dataset[dataset.get_sample_ids()[0]].get_zone_assignment(time=default_time)
+ # else:
+ # default_zone = dataset[dataset.get_sample_ids()[0]].get_zone_assignment(base_name=default_base, time=default_time)
+ # has_no_default_zone = (default_zone is None)
+
+ # Input fields
+ if ("all" in self.in_features) or ("field::all" in self.in_features):
+ self.input_fields = dataset.get_field_names()
+ else:
+ self.input_fields = [s[7:] for s in self.in_features if s[:7] == "field::"]
+
+ # Output fields
+ if ("all" in self.out_features) or ("field::all" in self.out_features):
+ assert "same" not in self.out_features
+ assert "field::same" not in self.out_features
+ self.output_fields = dataset.get_field_names()
+ elif ("same" in self.out_features) or ("field::same" in self.out_features):
+ self.output_fields = self.input_fields
+ else:
+ self.output_fields = [
+ s[7:] for s in self.out_features if s[:7] == "field::"
+ ]
+
+ def _extract_X_y_from_plaid(
+ self, dataset: Dataset
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """Extract features (X) and labels (y) from a PLAID dataset according to the input/output keys.
+
+ Args:
+ dataset (Dataset): The dataset to extract data from.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: The extracted features and labels as numpy arrays.
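+
+        Note:
+            Columns are assembled in a fixed order: input scalars first, then input
+            time series, then input fields (and likewise for y with the output names).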
+ """
+ ### Inputs
+ X = [
+ dataset.get_scalars_to_tabular([input_scalar_name], as_nparray=True)
+ for input_scalar_name in (
+ dataset.get_scalar_names()
+ if self.input_scalars == "all"
+ else self.input_scalars
+ )
+ ]
+ X.extend(
+ [
+ dataset.get_time_series_to_tabular(
+ [input_time_series_name], as_nparray=True
+ )
+ for input_time_series_name in self.input_time_series
+ ]
+ )
+ X.extend(
+ [
+ dataset.get_fields_to_tabular([input_field_name], as_nparray=True)
+ for input_field_name in self.input_fields # TODO: handle names with '/'
+ ]
+ )
+ # Check shapes
+ for i_v, v in enumerate(X):
+ # Reshape any 3D arrays to 2D, contracting the last two dimensions
+ if len(v.shape) >= 3:
+ X[i_v] = v.reshape((len(v), -1))
+ # Reshape any 1D arrays to 2D, appending a singleton dimension
+ if len(v.shape) == 1:
+ X[i_v] = v.reshape((-1, 1))
+ # print(f"=== In <_extract_X_y_from_plaid> of {self.sklearn_block=}")
+ # print(f"{self.input_scalars=}")
+ # print(f"{self.input_fields=}")
+ # print(f"{self.output_scalars=}")
+ # print(f"{self.output_fields=}")
+ # print(f"{type(X)=}")
+ # print(f"{len(X)=}")
+ # Concatenate the input arrays into a 2D numpy array
+ X = np.concatenate(X, axis=-1)
+
+ ### Outputs
+ y = [
+ dataset.get_scalars_to_tabular([output_scalar_name], as_nparray=True)
+ for output_scalar_name in (
+ dataset.get_scalar_names()
+ if self.output_scalars == "all"
+ else self.output_scalars
+ )
+ ]
+ y.extend(
+ [
+ dataset.get_time_series_to_tabular(
+ [output_time_series_name], as_nparray=True
+ )
+ for output_time_series_name in self.output_time_series
+ ]
+ )
+ y.extend(
+ [
+ dataset.get_fields_to_tabular([output_field_name], as_nparray=True)
+ for output_field_name in self.output_fields # TODO: handle names with '/'
+ ]
+ )
+ # Check shapes
+ for i_v, v in enumerate(y):
+ # Reshape any 3D arrays to 2D, contracting the last two dimensions
+ if len(v.shape) >= 3:
+ y[i_v] = v.reshape((len(v), -1))
+ # Reshape any 1D arrays to 2D, appending a singleton dimension
+ if len(v.shape) == 1:
+ y[i_v] = v.reshape((-1, 1))
+ # print(f"{self.input_scalars=}")
+ # print(f"{self.input_fields=}")
+ # print(f"{self.output_scalars=}")
+ # print(f"{self.output_fields=}")
+ # print(f"{type(y)=}")
+ # print(f"{len(y)=}")
+ # Concatenate the output arrays into a 2D numpy array
+ if len(y) > 0:
+ y = np.concatenate(y, axis=-1)
+ else:
+ y = None
+
+ return X, y
+
+ def _convert_y_to_plaid(self, y: np.ndarray, dataset: Dataset) -> Dataset:
+ """Convert the model's output (numpy array) to a PLAID Dataset, updating the original dataset.
+
+ Args:
+ y (np.ndarray): The model's output.
+ dataset (Dataset): The original dataset.
+
+ Returns:
+ Dataset: The updated PLAID dataset with new scalars/fields.
+ """
+ # TODO: use https://scikit-learn.org/stable/glossary.html#term-get_feature_names_out avoid overwriting features
+        logger.debug("=== In <_convert_y_to_plaid>")
+        if hasattr(self.sklearn_block, "feature_names_in_"):
+            logger.debug(f"- {self.sklearn_block.feature_names_in_=}")
+        else:
+            logger.debug("- self.sklearn_block.feature_names_in_ not found")
+        logger.debug(f"- {self.sklearn_block.get_feature_names_out()=}")
+        logger.debug(f"- {self.output_scalars=}")
+        logger.debug(f"- {self.output_time_series=}")
+        logger.debug(f"- {self.output_fields=}")
+        logger.debug(f"- {dataset.get_scalar_names()=}")
+        logger.debug(f"- {dataset.get_time_series_names()=}")
+        logger.debug(f"- {dataset.get_field_names()=}")
+        logger.debug(f"- {y.shape=}")
+
+ new_dset = Dataset()
+ # TODO: define tests to determine if we write new features to scalars, fields, or time series
+        if y.ndim == 2 and y.shape[1] == len(self.sklearn_block.get_feature_names_out()):
+ new_dset.add_tabular_scalars(y, self.sklearn_block.get_feature_names_out())
+ elif len(self.output_scalars) > 0:
+ new_dset.add_tabular_scalars(y, self.output_scalars)
+
+ # if len(self.output_scalars) > 0:
+ # new_dset.add_tabular_scalars(
+ # y[:, : len(self.output_scalars)], self.output_scalars
+ # )
+ # if len(self.output_time_series) > 0:
+ # new_dset.add_tabular_time_series(
+ # y[:, len(self.output_scalars) : len(self.output_scalars) + len(self.output_time_series)],
+ # self.output_time_series
+ # )
+ # if len(self.output_fields) > 0:
+ # new_dset.add_tabular_fields(
+ # y[:, len(self.output_scalars) + len(self.output_time_series) :],
+ # self.output_fields
+ # )
+
+ dataset.merge_samples(new_dset)
+        logger.debug(f"- {dataset.get_scalar_names()=}")
+        logger.debug(f"- {dataset.get_time_series_names()=}")
+        logger.debug(f"- {dataset.get_field_names()=}")
+ return dataset
+
+ def __sklearn_is_fitted__(self):
+ """Check if the wrapped scikit-learn model is fitted.
+
+ Returns:
+ bool: True if the model is fitted, False otherwise.
+ """
+ return hasattr(self, "_is_fitted") and self._is_fitted
+
+ def __repr__(self):
+ """String representation of the wrapper, showing the underlying sklearn block."""
+ return f"{self.__class__.__name__}({self.sklearn_block.__repr__()})"
+
+ def __str__(self):
+ """String representation of the wrapper, showing the underlying sklearn block."""
+ return f"{self.__class__.__name__}({self.sklearn_block.__str__()})"
+
+
+class WrappedSklearnTransform(PlaidWrapper, TransformerMixin):
+ """Wrapper for scikit-learn Transformer blocks to operate on PLAID objects in a Pipeline.
+
+ This class allows you to use any sklearn Transformer (e.g. PCA, StandardScaler) in a Pipeline where all steps
+ accept and return PLAID Dataset objects. The transform and inverse_transform methods take a Dataset and return a new Dataset.
+ """
+
+ def transform(self, dataset: Dataset):
+ """Transform the dataset using the wrapped sklearn transformer.
+
+ Args:
+ dataset (Dataset): The dataset to transform.
+
+ Returns:
+ Dataset: The transformed PLAID dataset.
+ """
+ X, _ = self._extract_X_y_from_plaid(dataset)
+ X_transformed = self.sklearn_block.transform(X)
+ return self._convert_y_to_plaid(X_transformed, dataset)
+
+ def inverse_transform(self, dataset: Dataset):
+ """Inverse transform the dataset using the wrapped sklearn transformer.
+
+ Args:
+ dataset (Dataset): The dataset to inverse transform.
+
+ Returns:
+ Dataset: The inverse transformed PLAID dataset.
+ """
+ # TODO: debug
+ X, _ = self._extract_X_y_from_plaid(dataset)
+ X_transformed = self.sklearn_block.inverse_transform(X)
+ return self._convert_y_to_plaid(X_transformed, dataset)
+
+ ## Already defined by TransformerMixin
+ # def fit_transform(self, dataset:Dataset):...
+
+
+class WrappedSklearnPredictor(PlaidWrapper, MetaEstimatorMixin):
+ """Wrapper for scikit-learn Predictor blocks to operate on PLAID objects in a Pipeline.
+
+ This class allows you to use any sklearn predictor (e.g. GaussianProcessRegressor, RandomForestRegressor, etc.) in a Pipeline
+ where all steps accept and return PLAID Dataset objects. The predict and fit_predict methods take a Dataset and return a new Dataset.
+ """
+
+ def predict(self, dataset: Dataset):
+ """Predict the output for the given dataset using the wrapped sklearn predictor.
+
+ Args:
+ dataset (Dataset): The dataset to predict.
+
+ Returns:
+ Dataset: The predicted PLAID dataset.
+ """
+ X, _ = self._extract_X_y_from_plaid(dataset)
+ y_pred = self.sklearn_block.predict(X)
+ return self._convert_y_to_plaid(y_pred, dataset)
+
+ def fit_predict(self, dataset: Dataset):
+ """Fit the model to the dataset and predict the output using the wrapped sklearn predictor.
+
+ Args:
+ dataset (Dataset): The dataset to fit the model on.
+
+ Returns:
+ Dataset: The predicted PLAID dataset.
+ """
+ self.fit(dataset)
+ return self.predict(dataset)
+
+
+class WrappedSklearnClassifier(WrappedSklearnPredictor, ClassifierMixin):
+ """Wrapper for scikit-learn Classifier blocks to operate on PLAID objects in a Pipeline.
+
+ Inherits from WrappedSklearnPredictor and ClassifierMixin.
+ """
+
+ pass
+
+
+class WrappedSklearnRegressor(WrappedSklearnPredictor, RegressorMixin):
+ """Wrapper for scikit-learn Regressor blocks to operate on PLAID objects in a Pipeline.
+
+ Inherits from WrappedSklearnPredictor and RegressorMixin.
+ """
+
+ pass
diff --git a/tests/containers/test_dataset.py b/tests/containers/test_dataset.py
index d13c4fb..f0637f8 100644
--- a/tests/containers/test_dataset.py
+++ b/tests/containers/test_dataset.py
@@ -361,6 +361,47 @@ def test_get_scalars_to_tabular_same_scalars_name(
dataset.get_scalars_to_tabular(sample_ids=[0, 0])
dataset.get_scalars_to_tabular(scalar_names=["test", "test"])
+ # -------------------------------------------------------------------------#
+ def test_add_tabular_fields(self, dataset, tabular, field_names, nb_samples):
+ dataset.add_tabular_fields(tabular, field_names)
+ assert len(dataset) == nb_samples
+
+ def test_add_tabular_fields_no_names(self, dataset, tabular, nb_samples):
+ dataset.add_tabular_fields(tabular)
+ assert len(dataset) == nb_samples
+
+ def test_add_tabular_fields_bad_ndim(self, dataset, tabular, field_names):
+ with pytest.raises(ShapeError):
+ dataset.add_tabular_fields(tabular.reshape((-1)), field_names)
+
+ def test_add_tabular_fields_bad_shape(self, dataset, tabular, field_names):
+ tabular = np.concatenate((tabular, np.zeros((len(tabular), 1))), axis=1)
+ with pytest.raises(ShapeError):
+ dataset.add_tabular_fields(tabular, field_names)
+
+ def test_get_fields_to_tabular(self, dataset, tabular, field_names):
+ assert len(dataset.get_fields_to_tabular()) == 0
+ assert dataset.get_fields_to_tabular() == {}
+ dataset.add_tabular_fields(tabular, field_names)
+ assert dataset.get_fields_to_tabular(as_nparray=True).shape == (
+ len(tabular),
+ len(field_names),
+ )
+ dict_tabular = dataset.get_fields_to_tabular()
+ for i_s, sname in enumerate(field_names):
+ assert np.all(dict_tabular[sname] == tabular[:, i_s])
+
+ def test_get_fields_to_tabular_same_fields_name(
+ self, dataset, tabular, field_names
+ ):
+ dataset.add_tabular_fields(tabular, field_names)
+ assert dataset.get_fields_to_tabular(as_nparray=True).shape == (
+ len(tabular),
+ len(field_names),
+ )
+ dataset.get_fields_to_tabular(sample_ids=[0, 0])
+ dataset.get_fields_to_tabular(field_names=["test", "test"])
+
# -------------------------------------------------------------------------#
def test_add_info(self, dataset):
dataset.add_info("legal", "owner", "PLAID")
@@ -415,6 +456,9 @@ def test_merge_dataset_with_bad_type(self, dataset_with_samples):
with pytest.raises(ValueError):
dataset_with_samples.merge_dataset(3)
+ def test_merge_samples(self, dataset_with_samples, other_dataset_with_samples):
+ dataset_with_samples.merge_samples(other_dataset_with_samples)
+
# -------------------------------------------------------------------------#
def test_save(self, dataset_with_samples, tmp_path):
diff --git a/tests/problem_definition/problem_infos.yaml b/tests/problem_definition/problem_infos.yaml
index bb9be46..fc1845c 100644
--- a/tests/problem_definition/problem_infos.yaml
+++ b/tests/problem_definition/problem_infos.yaml
@@ -13,8 +13,8 @@ input_fields:
- test_field
output_fields:
- field
-- test_field
- predict_field
+- test_field
input_timeseries:
- predict_timeseries
- test_timeseries