Skip to content

Commit 7c3c8a6

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Add RANGE_ITERATOR_MATCH to properly guard on range iterators (pytorch#141902)
A subsequeunt patch attempts to fix a side-effect issue for range iterators, which in turn exposed an exising issue on guards for range iterators -- the following test started failing: ``` PYTORCH_TEST_WITH_DYNAMO=1 python test/test_tensor_creation_ops.py TestTensorCreationCPU.test_hstack_column_stack_cpu_int16 ``` This patch adds a `RANGE_ITERATOR_MATCH` guard to make sure that we properly guard on range iterators, and adds a regression test. Pull Request resolved: pytorch#141902 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141713, pytorch#141714, pytorch#141715
1 parent ff3f4a1 commit 7c3c8a6

File tree

5 files changed

+146
-9
lines changed

5 files changed

+146
-9
lines changed

test/dynamo/test_misc.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,6 +1778,28 @@ def fn(a):
17781778
expected_ops=9,
17791779
)
17801780

1781+
def test_range_iter_guards(self):
1782+
@torch.compile()
1783+
def func():
1784+
@torch._dynamo.disable(recursive=False)
1785+
def run(n):
1786+
# For python <= 3.11, list comprehension is implemented by
1787+
# desugaring to:
1788+
# 1. creation of an iterator object
1789+
# 2. calling a new `listcomp` function with (1)
1790+
#
1791+
# In this test we force Dynamo to trace through (2) as the root
1792+
# frame, thereby ensuring we have the right guards for range
1793+
# iterators.
1794+
xs = [torch.ones(1) for i in range(n)]
1795+
return torch.concat(xs)
1796+
1797+
return run(2), run(3)
1798+
1799+
res2, res3 = func()
1800+
self.assertTrue(same(res2, torch.ones(2)))
1801+
self.assertTrue(same(res3, torch.ones(3)))
1802+
17811803
def test_build_tuple_unpack(self):
17821804
def fn1(a, b, c):
17831805
return a - b / c

torch/_dynamo/guards.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
istype,
128128
key_is_id,
129129
key_to_id,
130+
normalize_range_iter,
130131
orig_code_map,
131132
tensor_always_has_static_shape,
132133
tuple_iterator_getitem,
@@ -419,6 +420,7 @@ def _get_closure_vars():
419420
"___dict_version": dict_version,
420421
"___dict_contains": lambda a, b: a in b,
421422
"___tuple_iterator_len": tuple_iterator_len,
423+
"___normalize_range_iter": normalize_range_iter,
422424
"___tuple_iterator_getitem": tuple_iterator_getitem,
423425
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
424426
"__math_isnan": math.isnan,
@@ -1700,6 +1702,24 @@ def TUPLE_ITERATOR_LEN(self, guard):
17001702
tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard)
17011703
)
17021704

1705+
def RANGE_ITERATOR_MATCH(self, guard):
1706+
ref = self.arg_ref(guard)
1707+
value = self.get(guard.name)
1708+
t = type(value)
1709+
1710+
code = []
1711+
normalized_range_iter = normalize_range_iter(value)
1712+
code.append(f"___normalize_range_iter({ref}) == {normalized_range_iter}")
1713+
self._set_guard_export_info(guard, code)
1714+
1715+
t = type(value)
1716+
obj_id = self.id_ref(t, f"type({guard.name})")
1717+
1718+
start, stop, step = normalized_range_iter
1719+
self.get_guard_manager(guard).add_range_iterator_match_guard(
1720+
start, stop, step, obj_id, get_verbose_code_parts(code, guard)
1721+
)
1722+
17031723
# TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
17041724
def DUPLICATE_INPUT(self, guard, source_b):
17051725
ref_a = self.arg_ref(guard)

torch/_dynamo/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,16 @@ def tuple_iterator_getitem(it, index):
18261826
iter_next = next
18271827

18281828

1829+
def normalize_range_iter(range_iter) -> Tuple[int, int, int]:
1830+
_, (range_obj,), maybe_idx = range_iter.__reduce__()
1831+
# In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been
1832+
# already incremented by the current index.
1833+
start = range_obj.start + (maybe_idx or 0)
1834+
stop = range_obj.stop
1835+
step = range_obj.step
1836+
return (start, stop, step)
1837+
1838+
18291839
def to_subclass(t, cls):
18301840
return t.as_subclass(cls)
18311841

torch/_dynamo/variables/builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,8 +1346,9 @@ def wrap_tuple_iterator(self, value: tuple_iterator):
13461346
return self.set_source_and_track_mutable(value, result)
13471347

13481348
def wrap_range_iterator(self, value: range_iterator):
1349-
self.install_guards(GuardBuilder.TYPE_MATCH)
1350-
# Get all the values from the range iterator
1349+
self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)
1350+
# Get all the values from the range iterator; no need to install guards
1351+
# on items since `RANGE_ITERATOR_MATCH` guarantees the same items.
13511352
items = [ConstantVariable.create(v) for v in copy.deepcopy(value)]
13521353
return ListIteratorVariable(items, mutation_type=ValueMutationNew())
13531354

torch/csrc/dynamo/guards.cpp

Lines changed: 91 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,20 @@
3232
#include <tuple>
3333
#include <utility>
3434

35-
// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
36-
// underlying tuple and access the item. Before Python 3.12 version, the
37-
// datastructure is in tupleobject.c file -
35+
// Certain CPython data structures are defined in `.c` files in earlier Python
36+
// versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to
37+
// retrieve the underlying tuple and access the item. Before Python 3.12
38+
// version, the data structure is in tupleobject.c file -
3839
// https://github.com/python/cpython/blob/9afc6d102d16080535325f645849cd84eb04d57d/Objects/tupleobject.c#L1058-L1062
39-
// To handle this, we manually copy the struct here and manually cast it to this
40-
// new struct. From 3.12, the struct is included in the header file.
40+
//
41+
// To handle the older python versions, we manually copy the struct here and
42+
// manually cast it to this new struct. For newer versions, the struct is
43+
// included in the header file.
4144
#if IS_PYTHON_3_12_PLUS
4245

4346
#define Py_BUILD_CORE
44-
// Bring _PyTupleIterObject from the header file
45-
#include <internal/pycore_tuple.h>
47+
#include <internal/pycore_range.h> // _PyRangeIterObject
48+
#include <internal/pycore_tuple.h> // _PyTupleIterObject
4649
#undef Py_BUILD_CORE
4750

4851
#else
@@ -54,6 +57,19 @@ typedef struct {
5457
PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */
5558
} _PyTupleIterObject;
5659

60+
// Copied from CPython, and given a unified name for different Python verions.
61+
// https://github.com/python/cpython/blob/7f71003b222ad398713514c2b55d34dc05dba6bc/Objects/rangeobject.c#L765-L771
62+
typedef struct {
63+
PyObject_HEAD
64+
// NOTE for Python 3.12+, `index` is removed, and `start` is updated in place
65+
// instead, upon each `next(...)` call. See
66+
// https://github.com/python/cpython/pull/27986
67+
long index;
68+
long start;
69+
long step;
70+
long len;
71+
} _PyRangeIterObject;
72+
5773
#endif // IS_PYTHON_3_12_PLUS
5874

5975
namespace torch::dynamo {
@@ -1142,6 +1158,52 @@ class EQUALS_MATCH : public LeafGuard {
11421158
PyTypeObject* _value_type;
11431159
};
11441160

1161+
class RANGE_ITERATOR_MATCH : public LeafGuard {
1162+
public:
1163+
RANGE_ITERATOR_MATCH(
1164+
py::object start,
1165+
py::object stop,
1166+
py::object step,
1167+
py::object type_id,
1168+
py::object verbose_code_parts)
1169+
: LeafGuard(std::move(verbose_code_parts)),
1170+
_type_id(py::cast<intptr_t>(std::move(type_id))) {
1171+
PyObject* start_obj = start.ptr();
1172+
PyObject* stop_obj = stop.ptr();
1173+
PyObject* step_obj = step.ptr();
1174+
_start = THPUtils_unpackLong(start_obj);
1175+
_stop = THPUtils_unpackLong(stop_obj);
1176+
_step = THPUtils_unpackLong(step_obj);
1177+
TORCH_CHECK(
1178+
!PyErr_Occurred(), "values of start/stop/step must fit in a long type");
1179+
}
1180+
1181+
bool check_nopybind(PyObject* value) override { // borrowed ref
1182+
// Do a type match first.
1183+
// NOLINTNEXTLINE(performance-no-int-to-ptr)
1184+
if (Py_TYPE(value) != (void*)_type_id) {
1185+
return false;
1186+
}
1187+
_PyRangeIterObject* iter = (_PyRangeIterObject*)value;
1188+
1189+
#if IS_PYTHON_3_12_PLUS
1190+
long start = iter->start;
1191+
#else
1192+
long start = iter->start + iter->index * iter->step;
1193+
#endif // IS_PYTHON_3_12_PLUS
1194+
1195+
long stop = iter->start + iter->len * iter->step;
1196+
return start == _start && stop == _stop && iter->step == _step;
1197+
}
1198+
1199+
private:
1200+
intptr_t _type_id;
1201+
// Normalized representation of a range iterator.
1202+
long _start;
1203+
long _stop;
1204+
long _step;
1205+
};
1206+
11451207
class TUPLE_ITERATOR_LEN : public LeafGuard {
11461208
public:
11471209
TUPLE_ITERATOR_LEN(
@@ -4382,6 +4444,12 @@ PyObject* torch_c_dynamo_guards_init() {
43824444
std::shared_ptr<TUPLE_ITERATOR_LEN>>(py_m, "TUPLE_ITERATOR_LEN")
43834445
.def(py::init<py::object, py::object, py::list>())
43844446
.def("__call__", &TUPLE_ITERATOR_LEN::check);
4447+
py::class_<
4448+
RANGE_ITERATOR_MATCH,
4449+
LeafGuard,
4450+
std::shared_ptr<RANGE_ITERATOR_MATCH>>(py_m, "RANGE_ITERATOR_MATCH")
4451+
.def(py::init<py::object, py::object, py::object, py::object, py::list>())
4452+
.def("__call__", &RANGE_ITERATOR_MATCH::check);
43854453
py::class_<GLOBAL_STATE, LeafGuard, std::shared_ptr<GLOBAL_STATE>>(
43864454
py_m, "GLOBAL_STATE")
43874455
.def(py::init<py::list>())
@@ -4607,6 +4675,22 @@ PyObject* torch_c_dynamo_guards_init() {
46074675
std::move(type_id),
46084676
std::move(verbose_code_parts)));
46094677
})
4678+
.def(
4679+
"add_range_iterator_match_guard",
4680+
[](GuardManager& self,
4681+
py::object start,
4682+
py::object stop,
4683+
py::object step,
4684+
py::object type_id,
4685+
py::object verbose_code_parts) -> void {
4686+
SKIP_IF_GUARD_ALREADY_PRESENT("RANGE_ITERATOR_MATCH");
4687+
self.add_leaf_guard(std::make_shared<RANGE_ITERATOR_MATCH>(
4688+
std::move(start),
4689+
std::move(stop),
4690+
std::move(step),
4691+
std::move(type_id),
4692+
std::move(verbose_code_parts)));
4693+
})
46104694
.def(
46114695
"add_default_device_guard",
46124696
[](GuardManager& self, py::object verbose_code_parts) -> void {

0 commit comments

Comments
 (0)