PretrainedModule.save_pretrained(safe_serialization=False) does not work with PyTorch wrapper tensor subclasses #32364
* Making wrapper tensor subclass work in huggingface_hub serialization (non-safetensor)

  Summary: huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but a wrapper tensor subclass does not have storage, so we unflatten the tensor and derive a storage_id by combining the storage_ids of the internal plain tensors. This is a bit hacky; open to more robust ideas.

  Test Plan: tested with the script in huggingface/transformers#32364

* add tests
* update signature to include new changes for tensor subclass
* add torch version checks and move around import
* more fixes
* tested with torch 2.0.0 and 2.5.0
* remove torch_version_at_least from _torch.py
* simplify code for checking if tensor subclass is available or not
* minor fix
* addressing comments and run tests with torch 2.4.0
* some linting
* add test_split_torch_state_dict_into_shards for tensor subclass state dict
* lint
* style
* quality

Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Lucain Pouget <lucainp@gmail.com>
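A minimal sketch of the storage-id approach the commit describes, assuming a traceable wrapper subclass that implements __tensor_flatten__. The helper name get_storage_id and the tuple-based combination are illustrative, not huggingface_hub's actual implementation:

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def get_storage_id(tensor: torch.Tensor):
    """Derive a hashable storage id, recursing into wrapper subclasses."""
    if is_traceable_wrapper_subclass(tensor):
        # __tensor_flatten__ returns the attribute names of the inner tensors
        # (plus a context object that is not needed here).
        inner_attrs, _ = tensor.__tensor_flatten__()
        return tuple(get_storage_id(getattr(tensor, attr)) for attr in inner_attrs)
    # Plain tensor: the storage data pointer identifies shared storage.
    return tensor.untyped_storage().data_ptr()
```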
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Hello, thanks for your issue! Following the merge of huggingface/huggingface_hub#2440, should this issue be closed?

Yes, I think that was the fix!
Awesome! Closing this then. |
System Info

transformers version: 4.43.3

Who can help?

No response

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
safe_serialization=False uses PyTorch torch.save serialization, which should work with wrapper tensor subclasses. However, the Hugging Face internal utils for this access tensor.storage().data_ptr(), which does not work for wrapper tensor subclasses.

The specific use case is torchao quantized tensors that are implemented as wrapper tensor subclasses. The following is a minimal repro with a basic wrapper tensor subclass, TwoTensor, but this should apply for any tensor subclass from torchao that is used as the model parameters.
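The original repro script was not preserved here; the following is a hedged reconstruction, assuming TwoTensor from PyTorch's test utilities (torch.testing._internal.two_tensor) and a small BERT config to keep the model cheap to build. The model and the parameter being swapped are illustrative choices:

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
from transformers import AutoConfig, AutoModel

# Build a tiny model so the repro is cheap; the architecture is arbitrary.
config = AutoConfig.for_model(
    "bert", hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64,
)
model = AutoModel.from_config(config)

# Replace one weight with a wrapper tensor subclass, the way torchao swaps
# in quantized tensors as model parameters.
w = model.embeddings.word_embeddings.weight.detach()
model.embeddings.word_embeddings.weight = torch.nn.Parameter(
    TwoTensor(w.clone(), w.clone()), requires_grad=False
)

# Expected to raise: the sharding utils call tensor.storage().data_ptr(),
# and the outer wrapper tensor has no real storage.
model.save_pretrained("./tmp_model", safe_serialization=False)
```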
Expected behavior

It looks like the code in question creates a unique id for a tensor's storage via its data_ptr. The outer tensor of the wrapper subclass does not have a "real" storage, so we expect the access to the storage data_ptr to fail. However, wrapper tensor subclasses have "inner" tensors that do have "real" storages, so these could be an option for getting an id.
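A small illustration of that suggestion, assuming PyTorch's TwoTensor test subclass; the inner attribute names come from its __tensor_flatten__, and the exact exception type raised for the outer tensor is an assumption:

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))

# What the serialization utils do today: fails for wrapper subclasses,
# because the outer tensor has no storage.
try:
    t.storage().data_ptr()
except Exception as e:
    print(f"outer tensor: {type(e).__name__}: {e}")

# The inner plain tensors do have real storages that could identify the data.
inner_attrs, _ = t.__tensor_flatten__()   # e.g. ["a", "b"] for TwoTensor
for attr in inner_attrs:
    inner = getattr(t, attr)
    print(attr, inner.untyped_storage().data_ptr())
```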
cc @jerryzh168