Serialization: support saving torch state dict to disk #2314

Merged
Wauplin merged 12 commits into main from 2065-save-state-dict
Jun 7, 2024


Contributor

@Wauplin Wauplin commented Jun 4, 2024

Implement save_torch_state_dict to save a torch state dictionary to disk (first part of #2065). It uses split_torch_state_dict_into_shards under the hood (#1938).

The state dict is saved either to a single file (if smaller than 5GB) or to shards with the corresponding index.json. By default, shards are saved as safetensors, but safe_serialization=False can be passed to save as pickle. A warning is logged when saving as pickle, and hopefully we will be able to drop support for it once transformers/diffusers/accelerate/... have completely phased out .bin saving. cc @LysandreJik I'd like to get your opinion on this. I'm fine with not adding support for .bin files at all, but I worry it would slow down adoption in our libraries.

For the implementation, I took inspiration from huggingface/diffusers#7830 + accelerate/transformers. What it does:

  1. Split state dict into shards (logic already exists)
  2. Clean existing directory (remove previous shard/index files)
  3. Write shards to disk
  4. Write index to disk (optional)
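
The four steps above can be sketched with the standard library only. This is a minimal illustration, not the code from this PR: byte strings stand in for tensors, and every name in it (save_state_dict_sketch, split_into_shards, the weights{suffix}.dat pattern, the index layout) is a hypothetical stand-in chosen for the example.

```python
import json
import re
from pathlib import Path

# Hypothetical names for illustration only; byte strings stand in for tensors.
FILENAME_PATTERN = "weights{suffix}.dat"
INDEX_NAME = "weights.dat.index.json"
# Matches files written by a previous run: single file, shards, or index.
STALE_FILE = re.compile(r"^weights(-\d{5}-of-\d{5})?\.dat(\.index\.json)?$")


def split_into_shards(state_dict, max_shard_size):
    """Greedy split: start a new shard when the next entry would overflow."""
    shards, current, current_size = [], {}, 0
    for name, blob in state_dict.items():
        if current and current_size + len(blob) > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = blob
        current_size += len(blob)
    if current:
        shards.append(current)
    return shards


def save_state_dict_sketch(state_dict, save_directory, max_shard_size):
    """Split, clean, write shards, then write an index if sharded."""
    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    # 1. Split the state dict into shards.
    shards = split_into_shards(state_dict, max_shard_size)
    total = len(shards)

    # 2. Clean the directory: remove shard/index files from a previous save.
    for path in save_directory.iterdir():
        if path.is_file() and STALE_FILE.match(path.name):
            path.unlink()

    # 3. Write each shard and remember which file holds which entry.
    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        suffix = f"-{i:05d}-of-{total:05d}" if total > 1 else ""
        filename = FILENAME_PATTERN.format(suffix=suffix)
        with open(save_directory / filename, "wb") as f:
            for blob in shard.values():
                f.write(blob)
        for name in shard:
            weight_map[name] = filename

    # 4. Write the index only when there is more than one shard.
    if total > 1:
        total_size = sum(len(blob) for blob in state_dict.values())
        index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
        (save_directory / INDEX_NAME).write_text(json.dumps(index, indent=2))

    return weight_map
```

Cleaning before writing matters when the shard count changes between saves: without step 2, a folder that previously held two shards and an index would keep those files next to a newly written single file.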

Example:

>>> from huggingface_hub import save_torch_state_dict
>>> model = ... # A PyTorch model

# Save state dict to "path/to/folder"
# The model is split into shards of up to 5GB each and saved as safetensors.
>>> state_dict = model.state_dict()
>>> save_torch_state_dict(state_dict, "path/to/folder")

cc @amyeroberts / @ArthurZucker for transformers, @sayakpaul for diffusers, @SunMarc @muellerzr for accelerate
Happy to get feedback, since this is a critical piece. The goal is to standardize things so they are consistent across libraries, so please let me know if you want to add or remove something!

(documentation has also been updated)


Note: I also removed split_numpy_state_dict_into_shards, which is a breaking change, but I don't expect anything to break in the wild. Better to just remove it to avoid future maintenance (I shouldn't have added it in the first place).

(failing CI is unrelated)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Wauplin Wauplin changed the title [Draft] Save Pytorch state dict Save Pytorch state dict Jun 5, 2024
@Wauplin Wauplin marked this pull request as ready for review June 5, 2024 15:20
@Wauplin Wauplin changed the title Save Pytorch state dict Serialization: support saving torch state dict to disk Jun 5, 2024
Member

@sayakpaul sayakpaul left a comment

Very nicely done.

@Wauplin Wauplin requested a review from SunMarc June 6, 2024 08:24
Member

@SunMarc SunMarc left a comment

Thanks for adding this! LGTM! Just a small nit.

Comment on lines 167 to 179
if safe_serialization:
    filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN

    try:
        from safetensors.torch import save_file as save_file_fn
    except ImportError as e:
        raise ImportError(
            "Please install `safetensors` to use safe serialization. "
            "You can install it with `pip install safetensors`."
        ) from e

else:
    filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
Member

filename_pattern is overwritten even when the user passes a filename_pattern that is not None.

Contributor Author

Good catch! Fixed it in 5c4cac7.
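
One way to address the nit, shown as a hedged sketch rather than the actual content of commit 5c4cac7: fall back to a format-specific default only when the caller did not pass a pattern. The helper name resolve_filename_pattern is hypothetical, and the two default pattern strings are assumptions about the values stored in the library's constants module.

```python
# Assumed default patterns; the real values live in the library's constants.
SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors"
PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin"


def resolve_filename_pattern(filename_pattern=None, safe_serialization=True):
    """Return the caller's pattern if given, else the default for the format."""
    # Hypothetical helper: a user-supplied pattern always wins.
    if filename_pattern is not None:
        return filename_pattern
    if safe_serialization:
        return SAFETENSORS_WEIGHTS_FILE_PATTERN
    return PYTORCH_WEIGHTS_FILE_PATTERN


# A custom pattern is kept as-is, whichever serialization format is chosen.
custom = resolve_filename_pattern("model.variant{suffix}.safetensors")
shard_name = custom.format(suffix="-00002-of-00002")
```

With this shape, the safetensors import check can stay inside the safe_serialization branch while the pattern selection is decided separately.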

Contributor

@amyeroberts amyeroberts left a comment

Very nice ❤️

docs/source/en/package_reference/serialization.md (outdated, resolved)
assert (tmp_path / "model.variant-00002-of-00002.safetensors").is_file()


def test_save_torch_state_dict_delete_existing_files(
Contributor

Nice :)

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@Wauplin
Contributor Author

Wauplin commented Jun 6, 2024

Thanks everyone for the reviews! ❤️

@Wauplin Wauplin merged commit 122a057 into main Jun 7, 2024
15 of 16 checks passed
@Wauplin Wauplin deleted the 2065-save-state-dict branch June 7, 2024 13:38
@Wauplin Wauplin mentioned this pull request Jul 17, 2024
1 task