
PretrainedModule.save_pretrained(safe_serialization=False) does not work with PyTorch wrapper tensor subclasses #32364

Closed
mikaylagawarecki opened this issue Jul 31, 2024 · 4 comments
mikaylagawarecki commented Jul 31, 2024

System Info

  • transformers version: 4.43.3
  • Platform: Linux-5.12.0-0_fbk16_zion_7661_geb00762ce6d2-x86_64-with-glibc2.34
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.3
  • Accelerate version: 0.22.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0a0+gitc35f21e (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: no
  • GPU type: NVIDIA PG509-210

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

safe_serialization=False uses PyTorch's torch.save serialization, which should work with wrapper tensor subclasses. However, the Hugging Face internal utilities involved access tensor.storage().data_ptr(), which does not work for wrapper tensor subclasses.
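For background, the sharding utilities use the storage data pointer as a cheap identity for deduplication: tensors that alias the same storage (e.g. a view and its base) should be serialized once. A quick illustration of that identity with plain tensors (a sketch for context, not the huggingface_hub code itself):

```python
import torch

base = torch.zeros(8)
view = base[2:6]
# A view shares its base's storage, so the storage data pointer matches.
assert base.untyped_storage().data_ptr() == view.untyped_storage().data_ptr()

other = torch.zeros(8)
# An independently allocated tensor has a distinct storage pointer.
assert base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr()
```

It is exactly this `data_ptr()` access that has no meaningful answer for the outer tensor of a wrapper subclass.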

The specific use case is torchao quantized tensors, which are implemented as wrapper tensor subclasses. The following is a minimal repro with a basic wrapper tensor subclass, TwoTensor, but the issue applies to any tensor subclass from torchao that is used as a model parameter.

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load pre-trained model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Convert every parameter/buffer to a wrapper tensor subclass (TwoTensor) for demonstration purposes
model._apply(lambda t: TwoTensor(t, t))

# Save the model and tokenizer to a directory
output_dir = "./my-bert-model"
model.save_pretrained(output_dir, safe_serialization=False)
```
Click for stack trace

```
Traceback (most recent call last):
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 406, in storage_ptr
    return tensor.untyped_storage().data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/users/mg1998/pytorch/test_transformers.py", line 11, in <module>
    model.save_pretrained(output_dir, safe_serialization=False)
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2691, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 330, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/huggingface_hub/serialization/_base.py", line 108, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 359, in get_torch_storage_id
    unique_id = storage_ptr(tensor)
  File "/home/mg1998/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/huggingface_hub/serialization/_torch.py", line 410, in storage_ptr
    return tensor.storage().data_ptr()
  File "/data/users/mg1998/pytorch/torch/storage.py", line 1220, in data_ptr
    return self._data_ptr()
  File "/data/users/mg1998/pytorch/torch/storage.py", line 1224, in _data_ptr
    return self._untyped_storage.data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.
```

Expected behavior

It looks like the code in question creates a unique id for a tensor's storage via its data_ptr. The outer tensor of a wrapper subclass does not have a "real" storage, so the access to the storage data_ptr is expected to fail. However, wrapper tensor subclasses have "inner" tensors that do have "real" storages, so these could be used to derive an id.

cc @jerryzh168

jerryzh168 added a commit to jerryzh168/huggingface_hub that referenced this issue Aug 6, 2024
…on (non-safetensor)

Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but a wrapper tensor subclass does not have storage, so we unflatten the tensor and derive a storage_id by adding the storage_ids of all internal plain tensors. This is a bit hacky; open to more robust ideas.

Test Plan:
tested with script in huggingface/transformers#32364
jerryzh168 added a commit to jerryzh168/huggingface_hub that referenced this issue Aug 6, 2024
jerryzh168 added a commit to jerryzh168/huggingface_hub that referenced this issue Aug 8, 2024
jerryzh168 added a commit to jerryzh168/huggingface_hub that referenced this issue Aug 9, 2024
Wauplin added a commit to huggingface/huggingface_hub that referenced this issue Aug 30, 2024
* Making wrapper tensor subclass to work in huggingface_hub serialization (non-safetensor)

Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but a wrapper tensor subclass does not have storage, so we unflatten the tensor and derive a storage_id by adding the storage_ids of all internal plain tensors. This is a bit hacky; open to more robust ideas.

Test Plan:
tested with script in huggingface/transformers#32364

* add tests

* update signature to include new changes for tensor subclass

* add torch version checks and move around import

* more fixes

* tested with torch 2.0.0 and 2.5.0

* remove torch_version_at_least from _torch.py

* simplify code for checking if tensor subclass is available or not

* minor fix

* addressing comments and run tests with torch 2.4.0

* some linting

* add test_split_torch_state_dict_into_shards for tensor subclass state dict

* lint

* style

* quality

---------

Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Lucain Pouget <lucainp@gmail.com>

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@LysandreJik
Member

Hello, thanks for your issue! Following the merge of huggingface/huggingface_hub#2440, should this issue be closed?

@mikaylagawarecki
Author

Yes, I think that was the fix!

@LysandreJik
Member

Awesome! Closing this then.
