KeyError: torch.complex64 when attempting to save PyTorch model #450

Closed
NielsRogge opened this issue Mar 11, 2024 · 4 comments

NielsRogge commented Mar 11, 2024

System Info

safetensors v0.4.2
huggingface_hub v0.22.0.dev0

Information

  • The official example scripts
  • My own modified scripts

Reproduction

We recently switched to using safetensors by default for the `PyTorchModelHubMixin` class in huggingface_hub (huggingface/huggingface_hub#2033). This mixin is a minimal class that adds `from_pretrained` and `push_to_hub` methods to any custom `nn.Module`.

However, when trying out this class on Google's Gemma models, I get the following error when calling `push_to_hub`, which first saves the tensors in safetensors format and then uploads the files to the Hub:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-8-eac8a21155c9> in <cell line: 1>()
----> 1 model.push_to_hub(f"nielsr/gemma-2b-it")

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py in inner_f(*args, **kwargs)
     99                     message += "\n\n" + custom_message
    100                 warnings.warn(message, FutureWarning)
--> 101             return f(*args, **kwargs)
    102 
    103         return inner_f

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py in _inner_fn(*args, **kwargs)
    117             kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
    118 
--> 119         return fn(*args, **kwargs)
    120 
    121     return _inner_fn  # type: ignore

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in push_to_hub(self, repo_id, config, commit_message, private, token, branch, create_pr, allow_patterns, ignore_patterns, delete_patterns, api_endpoint)
    517         with SoftTemporaryDirectory() as tmp:
    518             saved_path = Path(tmp) / repo_id
--> 519             self.save_pretrained(saved_path, config=config)
    520             return api.upload_folder(
    521                 repo_id=repo_id,

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in save_pretrained(self, save_directory, config, repo_id, push_to_hub, **push_to_hub_kwargs)
    247 
    248         # save model weights/files (framework-specific)
--> 249         self._save_pretrained(save_directory)
    250 
    251         # save config (if provided)

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in _save_pretrained(self, save_directory)
    590         """Save weights from a Pytorch model to a local directory."""
    591         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
--> 592         save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
    593 
    594     @classmethod

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in save_model(model, filename, metadata, force_contiguous)
    153     """
    154     state_dict = model.state_dict()
--> 155     to_removes = _remove_duplicate_names(state_dict)
    156 
    157     for kept_name, to_remove_group in to_removes.items():

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in _remove_duplicate_names(state_dict, preferred_names, discard_names)
     98     to_remove = defaultdict(list)
     99     for shared in shareds:
--> 100         complete_names = set([name for name in shared if _is_complete(state_dict[name])])
    101         if not complete_names:
    102             raise RuntimeError(

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in <listcomp>(.0)
     98     to_remove = defaultdict(list)
     99     for shared in shareds:
--> 100         complete_names = set([name for name in shared if _is_complete(state_dict[name])])
    101         if not complete_names:
    102             raise RuntimeError(

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in _is_complete(tensor)
     79 
     80 def _is_complete(tensor: torch.Tensor) -> bool:
---> 81     return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor)
     82 
     83 

KeyError: torch.complex64
```

Here's a notebook for reproduction.
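The last frame shows where this fails: `_is_complete` multiplies the element count by `_SIZE[tensor.dtype]`, and in safetensors v0.4.2 that table evidently has no entry for `torch.complex64`. As an illustration only, here is a minimal sketch of the same completeness check that takes the byte width from `Tensor.element_size()` instead of a hand-maintained table. It assumes PyTorch >= 2.0 (for `Tensor.untyped_storage()`), the function name `is_complete` is hypothetical, and it is not necessarily how safetensors resolves this; it only addresses the `KeyError` in this check, not whether the rest of the save path accepts complex dtypes.

```python
import torch


def is_complete(tensor: torch.Tensor) -> bool:
    """Hypothetical table-free variant of the check in the last traceback frame.

    True when `tensor` starts at the beginning of its storage and covers all of it.
    Assumes PyTorch >= 2.0 for Tensor.untyped_storage().
    """
    storage = tensor.untyped_storage()
    # element_size() is the byte width of one element (8 for complex64,
    # 16 for complex128), so no dtype -> size lookup can raise KeyError.
    return (
        tensor.data_ptr() == storage.data_ptr()
        and tensor.nelement() * tensor.element_size() == storage.nbytes()
    )


full = torch.randn(5, dtype=torch.complex64)
print(full.element_size())    # 8
print(is_complete(full))      # True: the tensor spans its whole storage
print(is_complete(full[:2]))  # False: a prefix view covers only part of the storage
```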

Expected behavior

This model has some tensors of type `torch.complex64`; it would be great if safetensors could save those.
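Until complex dtypes can be saved directly, one possible workaround is to store each complex tensor as a real tensor of (real, imaginary) pairs and convert it back after loading. The sketch below relies on `torch.view_as_real` and `torch.view_as_complex`; the key suffix `COMPLEX_SUFFIX` and the helpers `split_complex` / `merge_complex` are hypothetical, and files written this way only round-trip with code that applies the same convention.

```python
import torch

# Hypothetical naming convention: complex tensor "name" is stored as "name" + COMPLEX_SUFFIX.
COMPLEX_SUFFIX = ".__complex_as_real__"


def split_complex(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of state_dict in which every complex tensor is replaced by a real one."""
    out: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        if tensor.is_complex():
            # view_as_real appends a trailing dim of size 2 holding (real, imag):
            # complex64 -> float32, complex128 -> float64.
            # resolve_conj() materializes a pending conjugation (view_as_real rejects it);
            # clone() gives the result its own storage instead of aliasing the model's tensor.
            out[name + COMPLEX_SUFFIX] = torch.view_as_real(tensor.resolve_conj()).clone()
        else:
            out[name] = tensor
    return out


def merge_complex(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Inverse of split_complex: rebuild complex tensors from their (real, imag) pairs."""
    out: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        if name.endswith(COMPLEX_SUFFIX):
            # view_as_complex needs a contiguous trailing dim of size 2;
            # float32 -> complex64, float64 -> complex128.
            out[name[: -len(COMPLEX_SUFFIX)]] = torch.view_as_complex(tensor.contiguous())
        else:
            out[name] = tensor
    return out
```

For example, `safetensors.torch.save_file(split_complex(model.state_dict()), "model.safetensors")` would write only real dtypes, and `model.load_state_dict(merge_complex(safetensors.torch.load_file("model.safetensors")))` would restore the complex tensors. This bypasses `save_model`, so any tied or shared weights would still need to be handled separately.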

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label Apr 11, 2024
github-actions bot closed this as not planned Apr 17, 2024
NielsRogge reopened this Apr 17, 2024
github-actions bot removed the Stale label Apr 18, 2024

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label May 18, 2024
github-actions bot closed this as not planned May 23, 2024
npuichigo commented:

same issue here

NielsRogge (Author) commented:

Friendly pinging @Narsil here
