Skip to content

🎉 Sklearn pipelines wrapper #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,711 changes: 2,711 additions & 0 deletions docs/source/notebooks/pca_gp_plaid_pipeline.ipynb

Large diffs are not rendered by default.

7,899 changes: 7,899 additions & 0 deletions docs/source/notebooks/pca_gp_sklearn_pipeline.ipynb

Large diffs are not rendered by default.

378 changes: 365 additions & 13 deletions src/plaid/containers/dataset.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/plaid/containers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def read_index_range(pyTree: list, dim: list[int]):


class Sample(BaseModel):
"""Represents a single sample. It contains data and information related to a single observation or measurement within a dataset."""
"""Represent a single sample. It contains data and information related to a single observation or measurement within a dataset."""

def __init__(
self,
Expand Down
7 changes: 7 additions & 0 deletions src/plaid/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Wrapper functions for the PLAID library."""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
496 changes: 496 additions & 0 deletions src/plaid/wrappers/sklearn.py

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions tests/containers/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,47 @@ def test_get_scalars_to_tabular_same_scalars_name(
dataset.get_scalars_to_tabular(sample_ids=[0, 0])
dataset.get_scalars_to_tabular(scalar_names=["test", "test"])

# -------------------------------------------------------------------------#
def test_add_tabular_fields(self, dataset, tabular, field_names, nb_samples):
    """Check the dataset length equals ``nb_samples`` after adding named fields."""
    dataset.add_tabular_fields(tabular, field_names)
    sample_count = len(dataset)
    assert sample_count == nb_samples

def test_add_tabular_fields_no_names(self, dataset, tabular, nb_samples):
    """Check the dataset length equals ``nb_samples`` when no names are passed."""
    # Only the table is supplied; field names are left to the method's default.
    dataset.add_tabular_fields(tabular)
    sample_count = len(dataset)
    assert sample_count == nb_samples

def test_add_tabular_fields_bad_ndim(self, dataset, tabular, field_names):
    """Check that a flattened (1-D) table is rejected with ``ShapeError``."""
    # reshape(-1) collapses the table to a single dimension.
    flattened = tabular.reshape(-1)
    with pytest.raises(ShapeError):
        dataset.add_tabular_fields(flattened, field_names)

def test_add_tabular_fields_bad_shape(self, dataset, tabular, field_names):
    """Check that a table with one extra column is rejected with ``ShapeError``."""
    # Append a column of zeros so the table is one column wider than the fixture.
    extra_column = np.zeros((len(tabular), 1))
    widened = np.concatenate((tabular, extra_column), axis=1)
    with pytest.raises(ShapeError):
        dataset.add_tabular_fields(widened, field_names)

def test_get_fields_to_tabular(self, dataset, tabular, field_names):
    """Check field extraction before and after adding a table to the dataset."""
    # Before anything is added, the getter returns an empty dict.
    assert len(dataset.get_fields_to_tabular()) == 0
    assert dataset.get_fields_to_tabular() == {}

    dataset.add_tabular_fields(tabular, field_names)

    # Array form: one row per table row, one column per field name.
    as_array = dataset.get_fields_to_tabular(as_nparray=True)
    expected_shape = (len(tabular), len(field_names))
    assert as_array.shape == expected_shape

    # Dict form: each name maps to the matching column of the input table.
    by_name = dataset.get_fields_to_tabular()
    for column, name in enumerate(field_names):
        assert np.all(by_name[name] == tabular[:, column])

def test_get_fields_to_tabular_same_fields_name(self, dataset, tabular, field_names):
    """Call the getter with duplicated sample ids and duplicated field names."""
    dataset.add_tabular_fields(tabular, field_names)

    as_array = dataset.get_fields_to_tabular(as_nparray=True)
    expected_shape = (len(tabular), len(field_names))
    assert as_array.shape == expected_shape

    # No assertion on the results: these calls only need to complete.
    repeated_ids = [0, 0]
    repeated_names = ["test", "test"]
    dataset.get_fields_to_tabular(sample_ids=repeated_ids)
    dataset.get_fields_to_tabular(field_names=repeated_names)

# -------------------------------------------------------------------------#
def test_add_info(self, dataset):
dataset.add_info("legal", "owner", "PLAID")
Expand Down Expand Up @@ -415,6 +456,9 @@ def test_merge_dataset_with_bad_type(self, dataset_with_samples):
with pytest.raises(ValueError):
dataset_with_samples.merge_dataset(3)

def test_merge_samples(self, dataset_with_samples, other_dataset_with_samples):
    """Smoke test: ``merge_samples`` runs on the second fixture without raising."""
    incoming = other_dataset_with_samples
    dataset_with_samples.merge_samples(incoming)

# -------------------------------------------------------------------------#

def test_save(self, dataset_with_samples, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion tests/problem_definition/problem_infos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ input_fields:
- test_field
output_fields:
- field
- test_field
- predict_field
- test_field
input_timeseries:
- predict_timeseries
- test_timeseries
Expand Down
Loading