Commit 78543e6

XuehaiPan authored and pytorchmergebot committed
[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (pytorch#137397)
Pull Request resolved: pytorch#137397 Approved by: https://github.com/jansel
1 parent 9990b47 commit 78543e6
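The commit makes tree_iter and tree_leaves from torch.utils._cxx_pytree (backed by the optree C extension) traceable by Dynamo through pure-Python polyfills instead of graph-breaking. Below is a minimal usage sketch, not taken from the PR itself, assuming optree is installed (torch.utils._cxx_pytree is only importable when it is):

import torch
import torch.utils._cxx_pytree as cxx_pytree

# With the polyfill registered, the C++-backed tree_leaves call is traced
# via its pure-Python substitute rather than triggering a graph break.
@torch.compile(backend="eager", fullgraph=True)
def sum_leaves(tree):
    return sum(cxx_pytree.tree_leaves(tree))

tree = {"a": torch.ones(2), "b": (torch.zeros(2), [torch.full((2,), 3.0)])}
print(sum_leaves(tree))  # tensor([4., 4.])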

8 files changed: +195 −61 lines

test/dynamo/test_misc.py

Lines changed: 25 additions & 32 deletions
@@ -32,7 +32,7 @@
 import torch._dynamo.testing
 import torch._inductor.test_case
 import torch.onnx.operators
-import torch.utils._pytree as pytree
+import torch.utils._pytree as python_pytree
 import torch.utils.cpp_extension
 from torch import Tensor
 from torch._C import FileCheck
@@ -89,9 +89,11 @@
 from torch.testing._internal.logging_utils import logs_to_string


-HAS_OPTREE = importlib.util.find_spec("optree")
+HAS_OPTREE = python_pytree._cxx_pytree_exists
 if HAS_OPTREE:
-    import optree
+    import torch.utils._cxx_pytree as cxx_pytree
+else:
+    cxx_pytree = None

 MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
 T = typing.TypeVar("T")
@@ -293,9 +295,9 @@ def fn(x):

     @unittest.skipIf(not HAS_OPTREE, "missing optree package")
     def test_optree_graph_break_message(self):
-        @torch.compile(
-            backend="eager",
-        )
+        import optree
+
+        @torch.compile(backend="eager")
         def fn(x):
             d = {"a": 1}
             optree.tree_flatten(d)
@@ -8722,9 +8724,9 @@ def fn():

     def test_tracing_py_tree(self):
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]

@@ -8734,12 +8736,10 @@ def fn(xs):
         self.assertEqual(counter.op_count, 3)

     def test_tracing_nested_py_tree(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = [xs, xs, xs, xs]
@@ -8752,12 +8752,10 @@ def fn(xs):
         self.assertEqual(counter.op_count, 12)

     def test_tracing_nested_py_tree_tuples(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = (xs, xs, xs, xs)
@@ -8770,12 +8768,10 @@ def fn(xs):
         self.assertEqual(counter.op_count, 12)

     def test_tracing_nested_py_tree_dicts(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = {
@@ -8808,12 +8804,10 @@ def fn(x):
         self.assertEqual(counter.op_count, 2)

     def test_tracing_nested_py_tree_mixed_all(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsa = (xs, xs)
@@ -8858,13 +8852,12 @@ def fn(x):
         self.assertEqual(cnt.frame_count, 2)

     def test_tracing_py_tree_tensor_subclass(self):
-        import torch.utils._pytree as pytree
         from torch.testing._internal.two_tensor import TwoTensor
         from torch.utils.checkpoint import checkpoint

         def fn(xs):
             nested_xs = [[xs]]
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             return flat_xs[0].clone()

         # use checkpoint to trigger a "sourceless" tensor subclass
@@ -8879,13 +8872,11 @@ def checkpoint_fn(xs):
         self.assertEqual(counter.op_count, 2)

     def test_tracing_tree_map_only(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
             def mapper(x):
                 return x.clone()

-            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
+            y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
             return y

         xs = [torch.tensor(i) for i in range(3)] + ["hi"]
@@ -10235,7 +10226,9 @@ def fn(x, y):
         self.assertEqual(actual, expected)

     def test_pytree_tree_leaves(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
+        if cxx_pytree is not None:
+            implemtations.append(("cxx", cxx_pytree))

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -10267,7 +10260,7 @@ def fn(x):
         self.assertEqual(actual, expected)

     def test_pytree_tree_flatten_unflatten(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -10316,7 +10309,7 @@ def fn(x, y):
         self.assertEqual(actual, expected)

     def test_pytree_tree_map(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
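The tests above now exercise both pytree implementations when the C++ one is available. A distilled, standalone sketch of that pattern (my example, not copied from the suite) looks like this:

import torch
import torch.utils._pytree as python_pytree

implementations = [("python", python_pytree)]
if python_pytree._cxx_pytree_exists:
    import torch.utils._cxx_pytree as cxx_pytree
    implementations.append(("cxx", cxx_pytree))

for name, module in implementations:
    # Each backend's tree_leaves is called inside a compiled function,
    # mirroring test_pytree_tree_leaves above.
    @torch.compile(backend="eager")
    def fn(tree):
        return [x + 1 for x in module.tree_leaves(tree)]

    print(name, fn({"a": torch.zeros(1), "b": [torch.ones(1)]}))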

test/dynamo/test_trace_rules.py

Lines changed: 10 additions & 4 deletions
@@ -323,10 +323,16 @@ def _check_set_equality(self, generated, used, rule_map, ignored_set):
     # or loaded in case there is typo in the strings.
     def test_skipfiles_inlinelist(self):
         for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
-            self.assertTrue(
-                isinstance(importlib.import_module(m), types.ModuleType),
-                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
-            )
+            try:
+                mod = importlib.import_module(m)
+            except ImportError:
+                continue
+            else:
+                self.assertTrue(
+                    isinstance(mod, types.ModuleType),
+                    f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST "
+                    "is not a python module, please check and correct it.",
+                )

     @unittest.skip(
         "This test keeps getting broken and our disable infra is not handling well. see #120627"

torch/_dynamo/guards.py

Lines changed: 4 additions & 3 deletions
@@ -2090,10 +2090,11 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None)
         obj_ref = None
         # Not necessary to have weakref for Enum type, but there is a bug that
         # makes hasattr(guarded_object.__class__, "__weakref__") return True.
+        supports_weakref = (
+            getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
+        )
         # See D64140537 for why we are checking for tuple.
-        if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
-            guarded_object, (enum.Enum, tuple)
-        ):
+        if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
             obj_ref = weakref.ref(guarded_object)

         guard.set_export_info(
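For context on the guards.py change: CPython exposes the C-level tp_weaklistoffset slot as __weakrefoffset__, and a nonzero value means instances of the class genuinely carry a weakref slot, which is more reliable than checking hasattr(cls, "__weakref__"). A small standalone illustration (my example, not part of the diff):

import weakref

class Slotted:
    __slots__ = ()  # no per-instance __weakref__ slot

class Regular:
    pass

# Zero means instances cannot be weak-referenced; builtins such as tuple
# report zero here as well.
assert Slotted.__weakrefoffset__ == 0
assert Regular.__weakrefoffset__ != 0
assert tuple.__weakrefoffset__ == 0

def supports_weakref(obj) -> bool:
    # Same check the diff above introduces in _set_guard_export_info.
    return getattr(type(obj), "__weakrefoffset__", 0) != 0

assert not supports_weakref(Slotted())
assert supports_weakref(Regular())
weakref.ref(Regular())  # fine; weakref.ref(Slotted()) would raise TypeError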

torch/_dynamo/polyfills/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
     itertools as itertools,
     operator as operator,
     os as os,
+    pytree as pytree,
     sys as sys,
 )

torch/_dynamo/polyfills/loader.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     "itertools",
     "operator",
     "os",
+    "pytree",
     "sys",
 )
 POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(

torch/_dynamo/polyfills/pytree.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+"""
+Python polyfills for torch.utils.pytree
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Iterable, TYPE_CHECKING
+
+import torch.utils._pytree as python_pytree
+
+from ..decorators import substitute_in_graph
+
+
+if TYPE_CHECKING:
+    from torch.utils._cxx_pytree import PyTree
+
+
+__all__: list[str] = []
+
+
+if python_pytree._cxx_pytree_exists:
+    import optree
+    import optree._C
+
+    import torch.utils._cxx_pytree as cxx_pytree
+
+    @substitute_in_graph(
+        optree._C.is_dict_insertion_ordered,
+        can_constant_fold_through=True,
+    )
+    def _(*args: Any, **kwargs: Any) -> bool:
+        # In namespace 'torch', the dictionary is always traversed in insertion order.
+        # This function returns True.
+        raise ValueError(
+            "Should not be called directly "
+            "because the original function will be called in the constant fold path."
+        )
+
+    __name = ""
+    for __name in (
+        "is_namedtuple",
+        "is_namedtuple_class",
+        "is_namedtuple_instance",
+        "is_structseq",
+        "is_structseq_class",
+        "is_structseq_instance",
+        "namedtuple_fields",
+        "structseq_fields",
+    ):
+        __func = getattr(optree, __name)
+        substitute_in_graph(__func, can_constant_fold_through=True)(
+            __func.__python_implementation__
+        )
+        del __func
+    del __name
+
+    @substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
+    def tree_is_leaf(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> bool:
+        if tree is None or (is_leaf is not None and is_leaf(tree)):
+            return True
+        if optree.register_pytree_node.get(type(tree), namespace="torch") is None:  # type: ignore[attr-defined]
+            return True
+        return False
+
+    @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
+    def tree_iter(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> Iterable[Any]:
+        stack = [tree]
+        while stack:
+            node = stack.pop()
+            if tree_is_leaf(node, is_leaf=is_leaf):
+                yield node
+                continue
+
+            children, *_ = optree.tree_flatten_one_level(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=True,
+                namespace="torch",
+            )
+            stack.extend(reversed(children))
+
+    __all__ += ["tree_iter"]
+
+    @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
+    def tree_leaves(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> list[Any]:
+        return list(tree_iter(tree, is_leaf=is_leaf))
+
+    __all__ += ["tree_leaves"]
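The new module registers pure-Python substitutes (via substitute_in_graph) for the optree-backed helpers so Dynamo can inline them. The core of tree_iter is an explicit stack-based depth-first traversal; the sketch below re-implements that traversal over plain list/tuple/dict containers so it runs without optree (the real polyfill instead consults optree's "torch" namespace registry to decide what counts as a leaf):

from typing import Any, Callable, Iterable, Optional

def tree_iter_sketch(
    tree: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Any]:
    # Depth-first, left-to-right traversal using an explicit LIFO stack,
    # mirroring the generator in the diff above.
    stack = [tree]
    while stack:
        node = stack.pop()
        if is_leaf is not None and is_leaf(node):
            yield node
            continue
        if isinstance(node, dict):
            children = list(node.values())
        elif isinstance(node, (list, tuple)):
            children = list(node)
        else:
            yield node  # anything unregistered is treated as a leaf
            continue
        # Reverse so the leftmost child is popped first.
        stack.extend(reversed(children))

def tree_leaves_sketch(tree: Any) -> list:
    return list(tree_iter_sketch(tree))

print(tree_leaves_sketch({"a": [1, 2], "b": (3, {"c": 4})}))  # [1, 2, 3, 4]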

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
@@ -3311,6 +3311,7 @@ def _module_dir(m: types.ModuleType):
     "torch.testing",
     "torch.utils._content_store",
     "torch.utils._contextlib",
+    "torch.utils._cxx_pytree",
     "torch.utils._device",
     "torch.utils._foreach_utils",
     "torch.utils._python_dispatch",
