Skip to content

Commit

Permalink
Feature: switch visibility with update_repo_settings #2537 (#2541)
Browse files Browse the repository at this point in the history
* Enhance `update_repo_settings` to manage repo visibility

* Enhance `update_repo_settings` to manage repo visibility

* Enhance `update_repo_settings` to manage repo visibility

* Enhance `update_repo_settings` to manage repo visibility

* Enhance `update_repo_settings` to manage repo visibility

* Apply suggestions from code review

---------

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
WizKnight and Wauplin authored Sep 26, 2024
1 parent 12eb785 commit f984cdc
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 31 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/guides/repository.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ Some settings are specific to Spaces (hardware, environment variables,...). To c
A repository can be public or private. A private repository is only visible to you or members of the organization in which the repository is located. Change a repository to private as shown in the following:

```py
>>> from huggingface_hub import update_repo_visibility
>>> update_repo_visibility(repo_id=repo_id, private=True)
>>> from huggingface_hub import update_repo_settings
>>> update_repo_settings(repo_id=repo_id, private=True)
```

### Setup gated access
Expand Down
42 changes: 33 additions & 9 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3524,6 +3524,7 @@ def delete_repo(
if not missing_ok:
raise

@_deprecate_method(version="0.29", message="Please use `update_repo_settings` instead.")
@validate_hf_hub_args
def update_repo_visibility(
self,
Expand All @@ -3535,6 +3536,8 @@ def update_repo_visibility(
) -> Dict[str, bool]:
"""Update the visibility setting of a repository.
Deprecated. Use `update_repo_settings` instead.
Args:
repo_id (`str`, *optional*):
A namespace (user or an organization) and a repo name separated by a `/`.
Expand Down Expand Up @@ -3581,29 +3584,34 @@ def update_repo_settings(
self,
repo_id: str,
*,
gated: Literal["auto", "manual", False] = False,
gated: Optional[Literal["auto", "manual", False]] = None,
private: Optional[bool] = None,
token: Union[str, bool, None] = None,
repo_type: Optional[str] = None,
) -> None:
"""
Update the gated settings of a repository.
To give more control over how repos are used, the Hub allows repo authors to enable **access requests** for their repos.
Update the settings of a repository, including gated access and visibility.
To give more control over how repos are used, the Hub allows repo authors to enable
access requests for their repos, and also to set the visibility of the repo to private.
Args:
repo_id (`str`):
A namespace (user or an organization) and a repo name separated by a /.
gated (`Literal["auto", "manual", False]`, *optional*):
The gated release status for the repository.
The gated status for the repository. If set to `None` (default), the `gated` setting of the repository won't be updated.
* "auto": The repository is gated, and access requests are automatically approved or denied based on predefined criteria.
* "manual": The repository is gated, and access requests require manual approval.
* False (default): The repository is not gated, and anyone can access it.
* False : The repository is not gated, and anyone can access it.
private (`bool`, *optional*):
Whether the model repo should be private.
token (`Union[str, bool, None]`, *optional*):
A valid user access token (string). Defaults to the locally saved token,
which is the recommended method for authentication (see
https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
To disable authentication, pass False.
repo_type (`str`, *optional*):
The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`.
The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`).
Defaults to `"model"`.
Raises:
Expand All @@ -3613,22 +3621,38 @@ def update_repo_settings(
If repo_type is not one of the values in constants.REPO_TYPES.
[`~utils.HfHubHTTPError`]:
If the request to the Hugging Face Hub API fails.
[`~utils.RepositoryNotFoundError`]
If the repository to download from cannot be found. This may be because it doesn't exist,
or because it is set to `private` and you do not have access.
"""
if gated not in ["auto", "manual", False]:
raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.")

if repo_type not in constants.REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
if repo_type is None:
repo_type = constants.REPO_TYPE_MODEL # default repo type

# Check if both gated and private are None
if gated is None and private is None:
raise ValueError("At least one of 'gated' or 'private' must be provided.")

# Build headers
headers = self._build_hf_headers(token=token)

# Prepare the JSON payload for the PUT request
payload: Dict = {}

if gated is not None:
if gated not in ["auto", "manual", False]:
raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.")
payload["gated"] = gated

if private is not None:
payload["private"] = private

r = get_session().put(
url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings",
headers=headers,
json={"gated": gated},
json=payload,
)
hf_raise_for_status(r)

Expand Down
1 change: 1 addition & 0 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def test_download_from_a_gated_repo_with_hf_hub_download(self, repo_url: RepoUrl
repo_id=repo_url.repo_id, filename=".gitattributes", token=OTHER_TOKEN, cache_dir=tmpdir
)

@expect_deprecation("update_repo_visibility")
@use_tmp_repo()
def test_download_regular_file_from_private_renamed_repo(self, repo_url: RepoUrl) -> None:
"""Regression test for #1999.
Expand Down
36 changes: 17 additions & 19 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import types
import unittest
import uuid
import warnings
from collections.abc import Iterable
from concurrent.futures import Future
from dataclasses import fields
Expand Down Expand Up @@ -93,6 +92,7 @@
DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
ENDPOINT_PRODUCTION,
SAMPLE_DATASET_IDENTIFIER,
expect_deprecation,
repo_name,
require_git_lfs,
rmtree_with_retry,
Expand Down Expand Up @@ -124,18 +124,6 @@ def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)


def test_repo_id_no_warning():
# tests that passing repo_id as positional arg doesn't raise any warnings
# for {create, delete}_repo and update_repo_visibility
api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)

with warnings.catch_warnings(record=True) as record:
repo_id = api.create_repo(repo_name()).repo_id
api.update_repo_visibility(repo_id, private=True)
api.delete_repo(repo_id)
assert not len(record)


class HfApiRepoFileExistsTest(HfApiCommonTest):
def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -210,6 +198,7 @@ def test_delete_repo_error_message(self):
def test_delete_repo_missing_ok(self) -> None:
self._api.delete_repo("repo-that-does-not-exist", missing_ok=True)

@expect_deprecation("update_repo_visibility")
def test_create_update_and_delete_repo(self):
repo_id = self._api.create_repo(repo_id=repo_name()).repo_id
res = self._api.update_repo_visibility(repo_id=repo_id, private=True)
Expand All @@ -218,6 +207,7 @@ def test_create_update_and_delete_repo(self):
assert not res["private"]
self._api.delete_repo(repo_id=repo_id)

@expect_deprecation("update_repo_visibility")
def test_create_update_and_delete_model_repo(self):
repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_MODEL).repo_id
res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_MODEL)
Expand All @@ -226,6 +216,7 @@ def test_create_update_and_delete_model_repo(self):
assert not res["private"]
self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_MODEL)

@expect_deprecation("update_repo_visibility")
def test_create_update_and_delete_dataset_repo(self):
repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_DATASET).repo_id
res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_DATASET)
Expand All @@ -234,6 +225,7 @@ def test_create_update_and_delete_dataset_repo(self):
assert not res["private"]
self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_DATASET)

@expect_deprecation("update_repo_visibility")
def test_create_update_and_delete_space_repo(self):
with pytest.raises(ValueError, match=r"No space_sdk provided.*"):
self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_SPACE, space_sdk=None)
Expand Down Expand Up @@ -286,19 +278,25 @@ def test_update_repo_settings(self, repo_url: RepoUrl):
repo_id = repo_url.repo_id

for gated_value in ["auto", "manual", False]:
self._api.update_repo_settings(repo_id=repo_id, gated=gated_value)
info = self._api.model_info(repo_id, expand="gated")
assert info.gated == gated_value
for private_value in [True, False]: # Test both private and public settings
self._api.update_repo_settings(repo_id=repo_id, gated=gated_value, private=private_value)
info = self._api.model_info(repo_id)
assert info.gated == gated_value
assert info.private == private_value # Verify the private setting

@use_tmp_repo(repo_type="dataset")
def test_update_dataset_repo_settings(self, repo_url: RepoUrl):
repo_id = repo_url.repo_id
repo_type = repo_url.repo_type

for gated_value in ["auto", "manual", False]:
self._api.update_repo_settings(repo_id=repo_id, repo_type=repo_type, gated=gated_value)
info = self._api.dataset_info(repo_id, expand="gated")
assert info.gated == gated_value
for private_value in [True, False]:
self._api.update_repo_settings(
repo_id=repo_id, repo_type=repo_type, gated=gated_value, private=private_value
)
info = self._api.dataset_info(repo_id)
assert info.gated == gated_value
assert info.private == private_value


class CommitApiTest(HfApiCommonTest):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from huggingface_hub.utils import SoftTemporaryDirectory

from .testing_constants import TOKEN
from .testing_utils import OfflineSimulationMode, offline, repo_name
from .testing_utils import OfflineSimulationMode, expect_deprecation, offline, repo_name


class SnapshotDownloadTests(unittest.TestCase):
Expand Down Expand Up @@ -95,6 +95,7 @@ def test_download_model(self):
# folder name contains the revision's commit sha.
self.assertTrue(self.first_commit_hash in storage_folder)

@expect_deprecation("update_repo_visibility")
def test_download_private_model(self):
self.api.update_repo_visibility(repo_id=self.repo_id, private=True)

Expand Down

0 comments on commit f984cdc

Please sign in to comment.