Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support filtering datasets by tags #2266

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,13 +1540,13 @@ def list_models(

>>> api = HfApi()

>>> # List all models
# List all models
>>> api.list_models()

>>> # List only the text classification models
# List only the text classification models
>>> api.list_models(filter="text-classification")

>>> # List only models from the AllenNLP library
# List only models from the AllenNLP library
>>> api.list_models(filter="allennlp")
```

Expand All @@ -1557,10 +1557,10 @@ def list_models(

>>> api = HfApi()

>>> # List all models with "bert" in their name
# List all models with "bert" in their name
>>> api.list_models(search="bert")

>>> # List all models with "bert" in their name made by google
# List all models with "bert" in their name made by google
>>> api.list_models(search="bert", author="google")
```
"""
Expand Down Expand Up @@ -1698,6 +1698,7 @@ def list_datasets(
language: Optional[Union[str, List[str]]] = None,
multilinguality: Optional[Union[str, List[str]]] = None,
size_categories: Optional[Union[str, List[str]]] = None,
tags: Optional[Union[str, List[str]]] = None,
task_categories: Optional[Union[str, List[str]]] = None,
task_ids: Optional[Union[str, List[str]]] = None,
search: Optional[str] = None,
Expand Down Expand Up @@ -1736,6 +1737,8 @@ def list_datasets(
A string or list of strings that can be used to identify datasets on
the Hub by the size of the dataset such as `100K<n<1M` or
`1M<n<10M`.
tags (`str` or `List`, *optional*):
A string tag or a list of tags to filter datasets on the Hub.
task_categories (`str` or `List`, *optional*):
A string or list of strings that can be used to identify datasets on
the Hub by the designed task, such as `audio_classification` or
Expand Down Expand Up @@ -1775,20 +1778,21 @@ def list_datasets(

>>> api = HfApi()

>>> # List all datasets
# List all datasets
>>> api.list_datasets()


>>> # List only the text classification datasets
# List only the text classification datasets
>>> api.list_datasets(filter="task_categories:text-classification")


>>> # List only the datasets in russian for language modeling
# List only the datasets in russian for language modeling
>>> api.list_datasets(
... filter=("language:ru", "task_ids:language-modeling")
... )

>>> api.list_datasets(filter=filt)
# List FiftyOne datasets (identified by the tag "fiftyone" in dataset card)
>>> api.list_datasets(tags="fiftyone")
```

Example usage with the `search` argument:
Expand All @@ -1798,10 +1802,10 @@ def list_datasets(

>>> api = HfApi()

>>> # List all datasets with "text" in their name
# List all datasets with "text" in their name
>>> api.list_datasets(search="text")

>>> # List all datasets with "text" in their name made by google
# List all datasets with "text" in their name made by google
>>> api.list_datasets(search="text", author="google")
```
"""
Expand Down Expand Up @@ -1839,6 +1843,8 @@ def list_datasets(
data = f"{attr}:{data}"
filter_list.append(data)

if tags is not None:
filter_list.extend([tags] if isinstance(tags, str) else tags)
if search:
params.update({"search": search})
if sort is not None:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,11 @@ def test_filter_datasets_with_card_data(self):
datasets = list(self._api.list_datasets(limit=500))
self.assertTrue(all([getattr(dataset, "card_data", None) is None for dataset in datasets]))

def test_filter_datasets_by_tag(self):
datasets = list(self._api.list_datasets(tags="fiftyone", limit=5))
for dataset in datasets:
assert "fiftyone" in dataset.tags

def test_dataset_info(self):
dataset = self._api.dataset_info(repo_id=DUMMY_DATASET_ID)
self.assertTrue(isinstance(dataset.card_data, DatasetCardData) and len(dataset.card_data) > 0)
Expand Down
Loading