From 486896862341e29b2b55cd03e851774b8da69088 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 3 May 2024 09:00:21 +0200 Subject: [PATCH] Support filtering datasets by tags --- src/huggingface_hub/hf_api.py | 28 +++++++++++++++++----------- tests/test_hf_api.py | 5 +++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e526a81efa..c573c05967 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -1540,13 +1540,13 @@ def list_models( >>> api = HfApi() - >>> # List all models + # List all models >>> api.list_models() - >>> # List only the text classification models + # List only the text classification models >>> api.list_models(filter="text-classification") - >>> # List only models from the AllenNLP library + # List only models from the AllenNLP library >>> api.list_models(filter="allennlp") ``` @@ -1557,10 +1557,10 @@ def list_models( >>> api = HfApi() - >>> # List all models with "bert" in their name + # List all models with "bert" in their name >>> api.list_models(search="bert") - >>> # List all models with "bert" in their name made by google + # List all models with "bert" in their name made by google >>> api.list_models(search="bert", author="google") ``` """ @@ -1698,6 +1698,7 @@ def list_datasets( language: Optional[Union[str, List[str]]] = None, multilinguality: Optional[Union[str, List[str]]] = None, size_categories: Optional[Union[str, List[str]]] = None, + tags: Optional[Union[str, List[str]]] = None, task_categories: Optional[Union[str, List[str]]] = None, task_ids: Optional[Union[str, List[str]]] = None, search: Optional[str] = None, @@ -1736,6 +1737,8 @@ def list_datasets( A string or list of strings that can be used to identify datasets on the Hub by the size of the dataset such as `100K>> api = HfApi() - >>> # List all datasets + # List all datasets >>> api.list_datasets() - >>> # List only the text classification datasets + # List only the text classification datasets >>> api.list_datasets(filter="task_categories:text-classification") - >>> # List only the datasets in russian for language modeling + # List only the datasets in russian for language modeling >>> api.list_datasets( ... filter=("language:ru", "task_ids:language-modeling") ... ) - >>> api.list_datasets(filter=filt) + # List FiftyOne datasets (identified by the tag "fiftyone" in dataset card) + >>> api.list_datasets(tags="fiftyone") ``` Example usage with the `search` argument: @@ -1798,10 +1802,10 @@ def list_datasets( >>> api = HfApi() - >>> # List all datasets with "text" in their name + # List all datasets with "text" in their name >>> api.list_datasets(search="text") - >>> # List all datasets with "text" in their name made by google + # List all datasets with "text" in their name made by google >>> api.list_datasets(search="text", author="google") ``` """ @@ -1839,6 +1843,8 @@ def list_datasets( data = f"{attr}:{data}" filter_list.append(data) + if tags is not None: + filter_list.extend([tags] if isinstance(tags, str) else tags) if search: params.update({"search": search}) if sort is not None: diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 9fa21bb273..b64a266939 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -1713,6 +1713,11 @@ def test_filter_datasets_with_card_data(self): datasets = list(self._api.list_datasets(limit=500)) self.assertTrue(all([getattr(dataset, "card_data", None) is None for dataset in datasets])) + def test_filter_datasets_by_tag(self): + datasets = list(self._api.list_datasets(tags="fiftyone", limit=5)) + for dataset in datasets: + assert "fiftyone" in dataset.tags + def test_dataset_info(self): dataset = self._api.dataset_info(repo_id=DUMMY_DATASET_ID) self.assertTrue(isinstance(dataset.card_data, DatasetCardData) and len(dataset.card_data) > 0)