diff --git a/tests/core/test_anndata.py b/tests/core/test_anndata.py index ebb5a6df..5a0abd16 100644 --- a/tests/core/test_anndata.py +++ b/tests/core/test_anndata.py @@ -2,7 +2,7 @@ import hypothesis.strategies as st import pytest -from hypothesis import given +from hypothesis import given, settings import numpy as np import pandas as pd @@ -112,8 +112,23 @@ def test_different_obs_id_length( class TestCleanup(TestBase): - @given(adata=get_adata(max_obs=5, max_vars=5), inplace=st.booleans()) - def test_cleanup_all(self, adata: AnnData, inplace: bool): + @pytest.mark.parametrize("dataset", ["pancreas", "dentategyrus"]) + @pytest.mark.parametrize("n_obs", [50, 100]) + @pytest.mark.parametrize("layer", [None, "unspliced", "spliced"]) + @pytest.mark.parametrize("dense", [True, False]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_cleanup_all( + self, adata, dataset: str, n_obs: int, layer: bool, dense: bool, inplace: bool + ): + adata = adata(dataset=dataset, n_obs=n_obs, raw=False, preprocessed=True) + adata.layers["dummy_layer"] = csr_matrix(np.eye(adata.n_obs, adata.n_vars)) + adata.uns["dummy_entry"] = {"key": csr_matrix(np.eye(5, 7))} + + if dense: + if layer is None: + adata.X = adata.X.A + else: + adata.layers[layer] = adata.layers[layer].A returned_adata = cleanup(adata=adata, clean="all", inplace=inplace) if not inplace: @@ -122,12 +137,13 @@ def test_cleanup_all(self, adata: AnnData, inplace: bool): else: assert returned_adata is None - assert len(adata.layers) == 0 - assert len(adata.uns) == 0 + assert list(adata.layers.keys()) == ["Ms", "Mu", "spliced", "unspliced"] + assert list(adata.uns.keys()) == ["neighbors"] assert len(adata.obs.columns) == 0 assert len(adata.var.columns) == 0 @given(adata=get_adata(max_obs=5, max_vars=5), inplace=st.booleans()) + @settings(max_examples=10, deadline=1000) def test_cleanup_default_clean_w_random_adata(self, adata: AnnData, inplace: bool): n_obs_cols = len(adata.obs.columns) n_var_cols = len(adata.var.columns) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 06ed961a..f5bcd1f6 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -18,7 +18,7 @@ def dentategyrus_adata(tmpdir_factory): @pytest.mark.skipif( - sys.version_info[:2] != (3, 8) or sys.platform != "linux", + sys.version_info[:2] != (3, 10) or sys.platform != "linux", reason="Limit number of downloads to speed up testing.", ) class TestDataSets: