
Handle shared layers in save_torch_state_dict + add save_torch_model #2373

Merged
3 commits merged into `main` on Jul 11, 2024

Conversation

@Wauplin (Contributor) commented Jul 4, 2024

Partially resolve #2065.
Follow-up PR after #2314.

In #2314, we introduced save_torch_state_dict. This new PR:

  • adds logic to deduplicate shared layers in safetensors. This is mostly taken from safetensors' torch helpers (see here). See slack thread (private) for discussions around this. See also https://huggingface.co/docs/safetensors/torch_shared_tensors for more details.
  • adds save_torch_model to directly save a torch nn.Module
  • renames the internal helpers to get_tf_storage_size / get_torch_storage_size and makes them public and documented
  • updates tests and documentation accordingly.

A last follow-up PR should add load_torch_state_dict / load_torch_model helpers as well, to correctly reload those files, including the shared layers.
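For context on the first bullet, tied weights are the typical source of shared tensors: two state-dict keys can point at the exact same storage, and a naive save would write that storage twice (or fail safetensors' checks). A minimal sketch, assuming `torch` is installed — the `TiedLM` toy model is made up for illustration and is not from this PR:

```python
import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Toy model whose embedding and output head share one weight tensor."""

    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.embed.weight  # tie the weights


model = TiedLM()
state_dict = model.state_dict()

# Both keys are present in the state dict, but they reference the same storage:
assert state_dict["embed.weight"].data_ptr() == state_dict["head.weight"].data_ptr()
```

The deduplication logic in this PR detects such aliases and saves the underlying tensor only once, recording the tied names so they can be restored at load time.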

I'm pinging transformers/accelerate/diffusers core maintainers for visibility as well. Feel free to comment if something should be done differently.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The main helper of the `serialization` module takes a state dictionary as input (i.e. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process, and saves everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].
The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see the [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.

If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it.
@sayakpaul (Member) commented Jul 5, 2024:

I see the point of mentioning this, but for the Torch community it's fairly standard practice to ship the model classes and their state dictionaries (i.e., the parameters) separately, unlike TensorFlow/Keras, for example.

)


def get_tensor_size(tensor: "tf.Tensor") -> int:
def get_tf_storage_size(tensor: "tf.Tensor") -> int:
Member:

Do we have all the equivalent torch methods for TensorFlow? Or is that not necessary?

@Wauplin (Contributor, Author) replied Jul 5, 2024:

Not yet, no. Let's build for torch first and expand to TF later if needed. For now, for TF we have the logic to split a state dict into shards, but nothing to save it to disk.
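For reference, the torch-side splitting logic mentioned here is already public as [`split_torch_state_dict_into_shards`]. A small sketch — the shard size and tensor shapes are arbitrary choices for illustration:

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Four 256 KiB fp32 tensors with a ~500 kB shard budget, so one file can't
# hold everything and the helper must produce several shards.
state_dict = {f"layer_{i}.weight": torch.zeros(256, 256) for i in range(4)}
split = split_torch_state_dict_into_shards(state_dict, max_shard_size="500KB")

print(split.is_sharded)          # True: the state dict does not fit in one shard
print(split.tensor_to_filename)  # maps each tensor name to its shard file
```

The `save_torch_state_dict` / `save_torch_model` helpers then take care of actually writing each shard and the index to disk.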

Comment on lines +249 to +251
"metadata": {**state_dict_split.metadata, **metadata},
"weight_map": state_dict_split.tensor_to_filename,
}
Member:

Should there be any sanity check on the additional metadata if not already done?

@Wauplin (Contributor, Author):

metadata is at the discretion of the frameworks that will use it (transformers/diffusers/accelerate). In practice, I don't think it'll be used much. In any case, we can't really do a sanity check, since we are supposed to accept anything that is JSON-serializable.
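To illustrate the point: the index written next to the shards is a plain JSON file that merges the user-provided metadata into the metadata computed by the split, so JSON-serializability is the only real constraint. A hypothetical index (all values made up):

```python
import json

# Shape mirrors the dict built in the diff above: split metadata merged with
# user metadata, plus the tensor-to-shard mapping.
index = {
    "metadata": {"total_size": 2097152, "anything": "JSON-serializable goes here"},
    "weight_map": {
        "embed.weight": "model-00001-of-00002.safetensors",
        "head.weight": "model-00002-of-00002.safetensors",
    },
}

# The only implicit check is that serialization succeeds:
serialized = json.dumps(index, indent=2)
assert json.loads(serialized) == index
```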

@sayakpaul (Member) left a review:

Understanding the full scope of the PR is still a bit beyond me, but I left some clarifying questions.

@LysandreJik (Member) left a review:

Much needed change! The API w/ contiguous looks ok to me.

Thanks for the changes! Let's see if it breaks in the wild, but from a quick check it looks good.

@Wauplin (Contributor, Author) commented Jul 11, 2024:

Thanks for the reviews! Let's ship it yes 😄

@Wauplin Wauplin merged commit dfe72d0 into main Jul 11, 2024
17 checks passed
@Wauplin Wauplin deleted the 2065-safer-safe-state-dict branch July 11, 2024 14:08
Wauplin added a commit that referenced this pull request Jul 11, 2024
* Use extended path on Windows when downloading to local dir

Change the path of the local dir to an extended path by prepending
"\\?\" to the absolute path, when the absolute path is longer
than 255 characters on Windows.

Also fixed a small typo.

* Move path handling to `get_local_download_paths()` for robustness

On Windows we check the length of `lock_path` and if it is longer than
255 characters we prepend the `\\?\` prefix to all paths if it does not
already exist.

We only need to check the length of `lock_path` because it is guaranteed
to be the longest path.
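The commit messages above can be sketched as a small helper — the name `_to_extended_path` and the standalone shape are hypothetical; in the actual change the logic lives inside `get_local_download_paths()`:

```python
import os

WINDOWS_MAX_PATH = 255  # classic MAX_PATH limit assumed by this sketch


def _to_extended_path(path: str) -> str:
    """Prepend the Windows extended-length prefix '\\\\?\\' when the absolute
    path exceeds the classic limit and the prefix is not already present."""
    abs_path = os.path.abspath(path)
    if (
        os.name == "nt"
        and len(abs_path) > WINDOWS_MAX_PATH
        and not abs_path.startswith("\\\\?\\")
    ):
        abs_path = "\\\\?\\" + abs_path
    return abs_path
```

On non-Windows systems the function is a no-op beyond `abspath`, which is why checking only the longest path (`lock_path`) is sufficient in the real code.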

* `safetensors[torch]` (#2371)

* Fix token=False not respected in file download (#2386)

* Fix token=False not respected in file download

* lint

* Handle shared layers in `save_torch_state_dict` + add `save_torch_model` (#2373)

* Handle shared layers in save_torch_state_dict + save_torch_model + some helpers

* fix pytest rerun

* more reruns

* Support `expand` parameter in `xxx_info` and `list_xxxs` (model/dataset/Space) (#2333)

* First draft to support `expand` parameter for models

* add expand support for dataset

* add expand support for Space

* Removed old path handling

* Reorder path check; add tests

* Skip test if not on Windows

The test now shows up as `skipped` if executed on a non-Windows machine

Co-authored-by: Lucain <lucainp@gmail.com>

* Fix indentation for test_local_folder.py

* Fix code style

---------

Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Lucain <lucainp@gmail.com>
Successfully merging this pull request may close these issues.

Implement save_state_dict and load_state_dict in serialization module
4 participants