diff --git a/auto_trainer/requirements.txt b/auto_trainer/requirements.txt index b14a0293c..e8e8c1b64 100644 --- a/auto_trainer/requirements.txt +++ b/auto_trainer/requirements.txt @@ -1,4 +1,4 @@ pandas -scikit-learn<1.4.0 +scikit-learn>=1.5.0 xgboost<2.0.0 plotly diff --git a/auto_trainer/test_auto_trainer.py b/auto_trainer/test_auto_trainer.py index 9a1ff554c..94822455b 100644 --- a/auto_trainer/test_auto_trainer.py +++ b/auto_trainer/test_auto_trainer.py @@ -72,9 +72,7 @@ def _get_dataset(problem_type: str, filepath: str = ".", n_classes: int = 2): def _assert_train_handler(train_run): - assert train_run and all( - key in train_run.outputs for key in ["model", "test_set"] - ), "outputs should include more data" + assert train_run and (train_run.status.artifact_uris or train_run.status.artifacts) , "outputs should include more data" @pytest.mark.parametrize("model", MODELS) diff --git a/describe/requirements.txt b/describe/requirements.txt index a96b6ff1b..6066ce7e9 100644 --- a/describe/requirements.txt +++ b/describe/requirements.txt @@ -1,4 +1,4 @@ -scikit-learn~=1.0.2 +scikit-learn==1.5.0 plotly~=5.16.1 pytest~=7.0.1 matplotlib~=3.5.1 diff --git a/gen_class_data/requirements.txt b/gen_class_data/requirements.txt index d7dbe376b..f781e922b 100644 --- a/gen_class_data/requirements.txt +++ b/gen_class_data/requirements.txt @@ -1,2 +1,2 @@ pandas -scikit-learn==1.0.2 \ No newline at end of file +scikit-learn==1.5.0 \ No newline at end of file diff --git a/gen_class_data/test_gen_class_data.py b/gen_class_data/test_gen_class_data.py index e06eeb16b..52e517b98 100644 --- a/gen_class_data/test_gen_class_data.py +++ b/gen_class_data/test_gen_class_data.py @@ -36,4 +36,4 @@ def test_gen_class_data(): local=True, artifact_path="./artifacts", ) - assert os.path.isfile(run.status.artifacts[0]['spec']['target_path']), 'dataset is not available' + assert "classifier-data" in run.status.artifact_uris, 'dataset was not logged'