Make wrapper tensor subclasses work in serialization #2440

Merged Aug 30, 2024 (18 commits). Changes from 8 commits.
1 change: 1 addition & 0 deletions src/huggingface_hub/serialization/__init__.py
@@ -22,4 +22,5 @@
    save_torch_model,
    save_torch_state_dict,
    split_torch_state_dict_into_shards,
    torch_version_at_least,
)
69 changes: 57 additions & 12 deletions src/huggingface_hub/serialization/_torch.py
@@ -20,7 +20,7 @@
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -32,6 +32,26 @@
    import torch


def parse_version(version_string):
    # Extract just the X.Y.Z part from the version string
    match = re.match(r"(\d+\.\d+\.\d+)", version_string)
    if match:
        version = match.group(1)
        return [int(x) for x in version.split(".")]
    else:
        raise ValueError(f"Invalid version string format: {version_string}")


def compare_versions(v1, v2):
    v1_parts = parse_version(v1)
    v2_parts = parse_version(v2)
    return (v1_parts > v2_parts) - (v1_parts < v2_parts)


def torch_version_at_least(min_version):
    try:
        # torch is only imported under TYPE_CHECKING at module level, so import it here at runtime.
        import torch

        return compare_versions(torch.__version__, min_version) >= 0
    except Exception:
        return False
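
As a quick illustration (not part of the diff, and assuming re is imported at module scope as in the real file), these helpers only look at the leading X.Y.Z component, so development and CUDA-tagged builds compare as expected:

    # Sketch only: exercising the version helpers above.
    assert parse_version("2.4.0.dev20240601+cu121") == [2, 4, 0]
    assert compare_versions("2.1.0", "2.0.1") == 1  # 1: greater, -1: smaller, 0: equal
    # torch_version_at_least("2.1.0") then reduces to a ">= 0" check against torch.__version__.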

def save_torch_model(
    model: "torch.nn.Module",
    save_directory: Union[str, Path],
@@ -335,18 +355,18 @@ def split_torch_state_dict_into_shards(
        get_storage_id=get_torch_storage_id,
    )


def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    """Returns a unique id for a plain tensor, or a (potentially nested) tuple of unique ids
    for the flattened inner tensors if the input is a wrapper tensor subclass.
    """
    if torch_version_at_least("2.1.0"):
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            attrs, _ = tensor.__tensor_flatten__()
            unique_id = tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
            return unique_id

    if tensor.device.type == "xla" and is_torch_tpu_available():
        # NOTE: xla tensors dont have storage
        # use some other unique id to distinguish.
@@ -358,13 +378,33 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
    else:
        unique_id = storage_ptr(tensor)

    return unique_id
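
A rough sketch of what this returns (not part of the diff; assumes torch >= 2.1 on CPU and the TwoTensor test subclass used in the tests below):

    # Sketch only: a plain tensor yields a single int, a wrapper subclass yields a tuple.
    import torch
    from torch.testing._internal.two_tensor import TwoTensor

    a = torch.tensor([1.0, 2.0])
    wrapped = TwoTensor(a, a)
    _get_unique_id(a)        # an int: the pointer of the underlying storage
    _get_unique_id(wrapped)  # a tuple with one entry per flattened inner tensor (nested for nested subclasses)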

def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
    """
    Return unique identifier to a tensor storage.

    Multiple different tensors can share the same underlying storage. For example, "meta" tensors all share the same
    storage, and thus their identifier will all be equal. This identifier is guaranteed to be unique and constant for
    this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.

    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
    """
    return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
Contributor:
If I understand correctly, two "meta" tensors can have the exact same _get_unique_id(tensor) and the exact same tensor.device but still be different, correct? If so, how can we be sure their storage size distinguishes them? Can it happen that they randomly happen to have the same storage size?

Contributor (author):
Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I think we'd need to reimplement the higher-level sharding logic in PyTorch in the end. I added some PoC in Slack, let me make a quick intro there.

Member:
> Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I don't think so, since we never had to serialize meta tensors. The only use case that could benefit from it is in accelerate (finding tied parameters in the meta model). Right now, this is how we handle meta tensors: https://github.com/huggingface/accelerate/blob/726140cad2f2361d79da7786a7b96d0bee591c48/src/accelerate/utils/modeling.py#L677
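
For what it's worth, a minimal sketch of the collision being discussed (not part of the diff; assumes current torch behaviour for meta tensors):

    # Sketch only: meta tensors have no real allocation, so their storage pointer resolves to 0.
    x = torch.empty(5, device="meta")
    y = torch.empty(5, device="meta")  # a different tensor object...
    # ...yet device, unique id and storage size all coincide, so get_torch_storage_id
    # would (incorrectly) report the two as the same storage:
    get_torch_storage_id(x) == get_torch_storage_id(y)  # likely True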



def get_torch_storage_size(tensor: "torch.Tensor") -> int:
    """
    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
    """
    if torch_version_at_least("2.1.0"):
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            # Sum the storage sizes of the flattened inner tensors (recursing into nested subclasses).
            attrs, _ = tensor.__tensor_flatten__()
            return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)

    try:
        return tensor.untyped_storage().nbytes()
    except AttributeError:
@@ -398,10 +438,15 @@ def is_torch_tpu_available(check_device=True):
    return False


def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    """
    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
    """
    if torch_version_at_least("2.1.0"):
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            return _get_unique_id(tensor)

    try:
        return tensor.untyped_storage().data_ptr()
    except Exception:
105 changes: 105 additions & 0 deletions tests/test_serialization.py
@@ -13,6 +13,7 @@
    save_torch_model,
    save_torch_state_dict,
    split_state_dict_into_shards_factory,
    torch_version_at_least,
)
from huggingface_hub.serialization._base import parse_size_to_int

@@ -58,6 +59,25 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]:
        pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]:
    try:
        import torch
        from torch.testing._internal.two_tensor import TwoTensor

        t = torch.tensor([4])
        return {
            "layer_1": torch.tensor([4]),
            "layer_2": torch.tensor([10]),
            "layer_3": torch.tensor([30]),
            "layer_4": torch.tensor([2]),
            "layer_5": torch.tensor([2]),
            "layer_6": TwoTensor(t, t),
        }
    except ImportError:
        pytest.skip("torch is not available")
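
(TwoTensor is the reference traceable wrapper subclass shipped with PyTorch's own test utilities: it wraps two inner tensors, a and b, and implements __tensor_flatten__, which is exactly the hook the serialization changes above rely on.)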


@pytest.fixture
def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
    try:
@@ -75,6 +95,52 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
        pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]:
    try:
        import torch
        from torch.testing._internal.two_tensor import TwoTensor

        t = torch.tensor([4])
        tensor_subclass_tensor = TwoTensor(t, t)

        t = torch.tensor([4])
        shared_tensor_subclass_tensor = TwoTensor(t, t)
        return {
            "layer_1": torch.tensor([4]),
            "layer_2": torch.tensor([10]),
            "layer_3": torch.tensor([30]),
            "layer_4": torch.tensor([2]),
            "layer_5": torch.tensor([2]),
            "layer_6": tensor_subclass_tensor,
            "ts_shared_1": shared_tensor_subclass_tensor,
            "ts_shared_2": shared_tensor_subclass_tensor,
        }
    except ImportError:
        pytest.skip("torch is not available")
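
(Note that "ts_shared_1" and "ts_shared_2" reference the same TwoTensor object, so get_torch_storage_id should return the same (device, unique id, storage size) triple for both, which is what lets the sharding logic treat them as a single shared storage.)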


@pytest.fixture
def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
    try:
        import torch
        from torch.testing._internal.two_tensor import TwoTensor

        if torch_version_at_least("2.1.0"):
            shared_layer = TwoTensor(torch.tensor([4]), torch.tensor([4]))
        else:
            shared_layer = torch.tensor([4])

        return {
            "shared_1": shared_layer,
            "unique_1": torch.tensor([10]),
            "unique_2": torch.tensor([30]),
            "shared_2": shared_layer,
        }
    except ImportError:
        pytest.skip("torch is not available")


def test_single_shard(dummy_state_dict):
    state_dict_split = split_state_dict_into_shards_factory(
        dummy_state_dict,
@@ -170,6 +236,17 @@ def test_get_torch_storage_size():
    assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


@requires("torch")
def test_get_torch_storage_size_wrapper_tensor_subclass():
    import torch

    if torch_version_at_least("2.1.0"):
        from torch.testing._internal.two_tensor import TwoTensor

        t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)
        assert get_torch_storage_size(TwoTensor(t, t)) == 5 * 8 * 2
        t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
        assert get_torch_storage_size(TwoTensor(t, TwoTensor(t, t))) == 5 * 2 * 3
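
For reference, the expected values follow directly from the recursive sum in get_torch_storage_size: a float64 tensor with 5 elements occupies 5 * 8 = 40 bytes and TwoTensor(t, t) flattens to two such leaves, hence 5 * 8 * 2 = 80 bytes; TwoTensor(t, TwoTensor(t, t)) flattens to three float16 leaves of 5 * 2 = 10 bytes each, hence 5 * 2 * 3 = 30 bytes.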


def test_parse_size_to_int():
    assert parse_size_to_int("1KB") == 1 * 10**3
    assert parse_size_to_int("2MB") == 2 * 10**6
@@ -247,6 +324,34 @@ def test_save_torch_state_dict_unsafe_not_sharded(
    assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded(
    tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
    """Save a state dict containing wrapper tensor subclasses as pickle, without sharding."""
    if not torch_version_at_least("2.1.0"):
        return
    with caplog.at_level("WARNING"):
        save_torch_state_dict(torch_state_dict_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False)
        assert "we strongly recommend using safe serialization" in caplog.text

    assert (tmp_path / "pytorch_model.bin").is_file()
    assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded(
    tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
    """Save a state dict with shared wrapper tensor subclasses as pickle, without sharding."""
    if not torch_version_at_least("2.1.0"):
        return
    with caplog.at_level("WARNING"):
        save_torch_state_dict(torch_state_dict_shared_layers_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False)
        assert "we strongly recommend using safe serialization" in caplog.text

    assert (tmp_path / "pytorch_model.bin").is_file()
    assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_unsafe_sharded(
    tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"]
) -> None: