
Making wrapper tensor subclass to work in serialization #2440

Merged
Wauplin merged 18 commits into huggingface:main from jerryzh168:non-safetensor-ser on Aug 30, 2024

Conversation

jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Aug 6, 2024

Summary:
huggingface_hub's serialization relies on the storage_ptr of a tensor to implement its sharding logic, but a wrapper tensor subclass does not have storage of its own, so we unflatten the tensor and derive a storage_id by returning a tuple constructed from the storage_ids of all internal plain (or further subclassed) tensors.

Note: this PR only supports non-safetensors serialization for tensor subclasses.
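The flatten-and-recurse idea can be sketched in plain Python. This is a hedged illustration only: `PlainTensor` and `WrapperTensor` are hypothetical stand-ins for `torch.Tensor` and a traceable wrapper subclass, and `storage_ptr` stands in for `tensor.untyped_storage().data_ptr()`; the real PR works on actual tensors via `__tensor_flatten__()`.

```python
# Hedged sketch of the nested storage-id logic described above.
# All class and attribute names here are illustrative stand-ins.

class PlainTensor:
    def __init__(self, storage_ptr):
        # stand-in for tensor.untyped_storage().data_ptr()
        self.storage_ptr = storage_ptr

class WrapperTensor:
    def __init__(self, **inner):
        self._inner = inner  # name -> inner tensor (plain or wrapped)

    def tensor_flatten(self):
        # stand-in for __tensor_flatten__(): inner attribute names + context
        return list(self._inner), None

def get_storage_id(t):
    """Plain tensor -> its storage pointer; wrapper -> nested tuple of ids."""
    if isinstance(t, WrapperTensor):
        attrs, _ = t.tensor_flatten()
        return tuple(get_storage_id(t._inner[a]) for a in attrs)
    return t.storage_ptr

shared = PlainTensor(0x1000)
w1 = WrapperTensor(qdata=shared, scale=PlainTensor(0x2000))
w2 = WrapperTensor(qdata=shared, scale=PlainTensor(0x2000))
# wrappers over the same underlying storages map to equal ids,
# so the sharding logic can group them together
print(get_storage_id(w1) == get_storage_id(w2))  # True
```

Because the id of a wrapper is just a tuple of its inner ids, two wrappers compare equal exactly when all their underlying storages do, which is what the sharding logic needs.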

Test Plan:
tested with script in huggingface/transformers#32364

Reviewers:

Subscribers:

Tasks:

Tags:

@jerryzh168 jerryzh168 changed the title Making wrapper tensor subclass to work in huggingface_hub serializati… Making wrapper tensor subclass to work in serialization Aug 6, 2024
@jerryzh168 jerryzh168 force-pushed the non-safetensor-ser branch 2 times, most recently from 40413f5 to d6c4256 Compare August 8, 2024 17:01
@jerryzh168
Contributor Author

@SunMarc can you find someone to review this?

@SunMarc SunMarc requested a review from Wauplin August 9, 2024 01:42
@SunMarc
Member

SunMarc commented Aug 9, 2024

@SunMarc can you find someone to review this?

Yes! Thanks for your work! I've asked @Wauplin to review. It would be nice to have a few tests if possible!

…on (non-safetensor)

Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement sharding logic, but
wrapper tensor subclasses do not have storage, so we unflatten the tensor and derive a storage_id by
combining the storage_ids of all internal plain tensors. This is a bit hacky; open to more robust ideas.

Test Plan:
tested with script in huggingface/transformers#32364

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit to jerryzh168/safetensors that referenced this pull request Aug 9, 2024
Summary:
Similar to huggingface/huggingface_hub#2440, we want to allow
safetensors to handle wrapper tensor subclasses. We mainly added:

1. tensor storage size: computed by flattening the wrapper tensor subclass and adding up the storage
 sizes of all sub-tensors recursively
2. storage_ptr: computed by constructing a tuple from the "storage_ptr" of the flattened tensors; this can
 be a nested tuple of tuples of ints, e.g. ((1, 2), 3, (4, (5, 6)))
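The recursive size computation in point 1 can be sketched as follows. This is a hedged stand-in: `nbytes` plays the role of `tensor.untyped_storage().nbytes()` and `inner` the role of the tensors returned by `__tensor_flatten__()`; the names are illustrative, not the safetensors API.

```python
# Hedged sketch of the recursive storage-size computation described above.

def get_storage_size(t):
    inner = getattr(t, "inner", None)
    if inner:  # wrapper subclass: sum sizes of all flattened inner tensors
        return sum(get_storage_size(x) for x in inner)
    return t.nbytes  # plain tensor: its own storage size

class T:  # minimal stand-in for a (possibly wrapped) tensor
    def __init__(self, nbytes=0, inner=()):
        self.nbytes, self.inner = nbytes, inner

# e.g. a quantized wrapper holding packed data + scales + a nested wrapper
w = T(inner=(T(nbytes=512), T(nbytes=32), T(inner=(T(nbytes=16),))))
print(get_storage_size(w))  # 560
```

The recursion mirrors the nested storage_ptr tuple in point 2: both walk the same flattened structure, one summing sizes and the other collecting pointers.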

Test Plan:
Added a test in test_pt_model.py, will also test manually

Reviewers:

Subscribers:

Tasks:

Tags:
@Wauplin
Contributor

Wauplin commented Aug 12, 2024

Hi @jerryzh168, thanks for opening this PR! And thanks @SunMarc for pulling me into this convo.

The general logic of this PR looks ok to me (even though I think I'm missing some broader context). The tests look good as well. However, as a huggingface_hub maintainer, what I'm most afraid of is relying too heavily on pytorch internals. Typically, `from torch.utils._python_dispatch import is_traceable_wrapper_subclass` and `tensor.__tensor_flatten__()` rely on private/internal methods that I'm not familiar with. It will become harder to maintain in the future, especially if we want to keep huggingface_hub compatible with "every" pytorch version (e.g. what if an internal helper is moved?).

In the end, the only thing I want from _get_unique_id or get_torch_storage_id is a unique id for a pytorch tensor. Would it be possible to expose a generic helper that handles all the cases in pytorch directly? I feel that huggingface_hub is not the right place to keep accumulating logic for computing unique ids. Note that this logic is currently duplicated in transformers and accelerate as well, but we are in the process of centralizing it in huggingface_hub, so at least it's a first step towards uniformization :)

@jerryzh168
Contributor Author

jerryzh168 commented Aug 14, 2024

@Wauplin I can add this to pytorch, but it will only be available in nightlies or torch 2.5+, so don't we still need to add this in huggingface_hub and other places for now? What is the version requirement for huggingface_hub/safetensors etc.? Do they plan to work with all torch versions?

jerryzh168 added a commit to pytorch/pytorch that referenced this pull request Aug 15, 2024
Summary:
Currently [huggingface_hub](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py), [safetensors](https://github.com/huggingface/safetensors/blob/main/bindings/python/py_src/safetensors/torch.py#L11), and [accelerate](https://github.com/huggingface/accelerate/blob/a452327e8e04b20779882dc491e00de602d554cb/src/accelerate/utils/modeling.py#L175) all have their own implementations of
`get_storage_id`, `get_storage_size`, and `storage_ptr`, which make assumptions about internal implementation details of torch.Tensor and its storage, and do not work for wrapper tensor subclasses.

Motivated by huggingface/huggingface_hub#2440 (comment), maybe it makes more sense to add these as utils in pytorch so they can be maintained by us instead.

This PR adds:
`get_storage_id`: returns a unique identifier for the tensor storage; for tensor subclasses, it returns a nested tuple of the unique ids of the underlying plain tensors
`get_storage_size`: returns the size in bytes of the underlying storage; for tensor subclasses, it returns the sum of the sizes of all underlying plain tensors

Test Plan:
python test/test_utils.py TestStorageUtils

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit to pytorch/pytorch that referenced this pull request Aug 15, 2024

Pull Request resolved: #133524

Contributor

@Wauplin Wauplin left a comment

Hi @jerryzh168, thanks for the changes! The implementation looks good to me at first glance. I've added a comment regarding the version parsing thing. In the meantime, I'll run some tests locally on my side.

In a follow-up PR, we'll add CI tests for torch 2.0 and 2.5 (for instance) to be sure both versions are compatible. I can take care of that part.

src/huggingface_hub/serialization/_torch.py (review thread resolved)
@jerryzh168
Contributor Author

@Wauplin please take a look again, thanks!

Contributor

@Wauplin Wauplin left a comment

Thanks for the changes @jerryzh168! I left some comments below

src/huggingface_hub/serialization/_torch.py (2 review threads resolved)

Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
"""
return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
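The triple returned above acts as the grouping key for sharding. The sketch below is a hedged stand-in (the `FakeTensor` class is illustrative, not the real `_torch.py` types) showing why the storage size is included: for meta tensors the unique id may not distinguish tensors on its own, so two tensors collide only if device, id, and size all match.

```python
# Hedged sketch: (device, unique id, storage size) as a grouping key.
# Names are illustrative stand-ins for the helpers quoted above.

class FakeTensor:
    def __init__(self, device, unique_id, nbytes):
        self.device, self.unique_id, self.nbytes = device, unique_id, nbytes

def get_storage_id_triple(t):
    return (t.device, t.unique_id, t.nbytes)

a = FakeTensor("meta", 1, 64)
b = FakeTensor("meta", 1, 128)  # same device + id, different size
print(get_storage_id_triple(a) == get_storage_id_triple(b))  # False
```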
Contributor

If I understand correctly, two "meta" tensors can have the exact same _get_unique_id(tensor) and the exact same tensor.device but still be different, correct? If so, how can we be sure their storage size distinguishes them? Can it happen that they randomly have the same storage size?

Contributor Author

Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I think we'd need to reimplement the higher-level sharding logic in pytorch in the end. I added a PoC in the slack; let me make a quick intro there.

Member

Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I don't think so, since we never had to serialize meta tensors. The only use case that could benefit from this is in accelerate (finding tied parameters from the meta model). Right now, this is how we do it for meta tensors: https://github.com/huggingface/accelerate/blob/726140cad2f2361d79da7786a7b96d0bee591c48/src/accelerate/utils/modeling.py#L677
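The accelerate approach linked above can be sketched roughly as follows. This is a hedged illustration, not accelerate's actual code: on a meta model, storage pointers are unusable, so tied parameters are found by comparing Python object identity of the parameter objects themselves; `named` below is a hypothetical plain-Python stand-in for `model.named_parameters()`.

```python
# Hedged sketch: grouping tied parameters by object identity,
# which works even for meta tensors that have no real storage.

from collections import defaultdict

def find_tied_parameters(named_params):
    groups = defaultdict(list)
    for name, param in named_params:
        groups[id(param)].append(name)  # same object => tied weights
    return [names for names in groups.values() if len(names) > 1]

embed = object()  # stand-in for a meta-device weight tensor
named = [
    ("embed_tokens.weight", embed),
    ("lm_head.weight", embed),       # tied to the embedding
    ("layers.0.weight", object()),
]
print(find_tied_parameters(named))  # [['embed_tokens.weight', 'lm_head.weight']]
```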

tests/test_serialization.py (4 review threads resolved)
Comment on lines 137 to 139
# TODO: need to fix safetensor support for tensor subclasses before we can add this
# to test
# shared_layer = TwoTensor(torch.tensor([4]), torch.tensor([4]))
Contributor Author

@Wauplin this seems to fail because safetensors does not support wrapper tensor subclasses yet; we can enable this once we add similar support in safetensors

Contributor

Actually, we don't have to test save_torch_state_dict using safetensors. It is possible to simply test with split_torch_state_dict_into_shards (which needs to be imported), since what we really care about is the grouping of the tensors, not necessarily the serialization (for now at least). Could you update the fixture and test in that direction?

Also, torch_state_dict_shared_layers is already taken as a name for a fixture, so code quality is complaining in the CI.

@Wauplin
Contributor

Wauplin commented Aug 26, 2024

I pushed a commit to fix some linting + a merge from main, so that we are now testing the pipeline on both pytorch~=1.11 and the latest pytorch version. This way, I assume we will be able to spot breaking changes in the future while still testing on versions up to 24 months old.

The last remaining thing IMO is completing the test (see #2440 (comment)). Otherwise, everything looks good!

@jerryzh168
Contributor Author

@Wauplin @SunMarc I added the test for tensor subclasses to split_torch_state_dict_into_shards; please take a look again, thanks!

Contributor

@Wauplin Wauplin left a comment

Thanks @jerryzh168! Everything looks good to me now! I fixed a few formatting issues, so we're now ready to merge. Thanks again, and thanks @SunMarc as well for the input :)

Looking forward to seeing deeper integration into pytorch directly!

@Wauplin Wauplin merged commit f12ba86 into huggingface:main Aug 30, 2024
16 checks passed
jerryzh168 added a commit to jerryzh168/transformers that referenced this pull request Sep 13, 2024
…nfig quantized model

Summary:
After huggingface/huggingface_hub#2440 added non-safetensors serialization and deserialization
in huggingface_hub, we can now add the support in transformers.

Note that we don't plan to add safetensors serialization, due to the different goals of wrapper tensor subclasses and safetensors;
see the README for more details.

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added commits to jerryzh168/transformers that referenced this pull request on Sep 13, Sep 19, and Sep 25, 2024
ArthurZucker pushed a commit to huggingface/transformers that referenced this pull request Sep 30, 2024
…33456)

* Enable non-safetensor serialization and deserialization for TorchAoConfig quantized model

Summary:
After huggingface/huggingface_hub#2440 added non-safetensors serialization and deserialization
in huggingface_hub, we can now add the support in transformers.

Note that we don't plan to add safetensors serialization, due to the different goals of wrapper tensor subclasses and safetensors;
see the README for more details.

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:

* formatting

* formatting

* minor fix

* formatting

* address comments

* comments

* minor fix

* update doc

* refactor compressed tensor quantizer