32
32
#include < tuple>
33
33
#include < utility>
34
34
35
- // For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
36
- // underlying tuple and access the item. Before Python 3.12 version, the
37
- // datastructure is in tupleobject.c file -
35
+ // Certain CPython data structures are defined in `.c` files in earlier Python
36
+ // versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to
37
+ // retrieve the underlying tuple and access the item. Before Python 3.12
38
+ // version, the data structure is in tupleobject.c file -
38
39
// https://github.com/python/cpython/blob/9afc6d102d16080535325f645849cd84eb04d57d/Objects/tupleobject.c#L1058-L1062
39
- // To handle this, we manually copy the struct here and manually cast it to this
40
- // new struct. From 3.12, the struct is included in the header file.
40
+ //
41
+ // To handle the older Python versions, we manually copy the struct here and
42
+ // manually cast it to this new struct. For newer versions, the struct is
43
+ // included in the header file.
41
44
#if IS_PYTHON_3_12_PLUS
42
45
43
46
#define Py_BUILD_CORE
44
- // Bring _PyTupleIterObject from the header file
45
- #include < internal/pycore_tuple.h>
47
+ # include < internal/pycore_range.h > // _PyRangeIterObject
48
+ #include < internal/pycore_tuple.h> // _PyTupleIterObject
46
49
#undef Py_BUILD_CORE
47
50
48
51
#else
@@ -54,6 +57,19 @@ typedef struct {
54
57
PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */
55
58
} _PyTupleIterObject;
56
59
60
+ // Copied from CPython, and given a unified name for different Python versions.
61
+ // https://github.com/python/cpython/blob/7f71003b222ad398713514c2b55d34dc05dba6bc/Objects/rangeobject.c#L765-L771
62
+ typedef struct {
63
+ PyObject_HEAD
64
+ // NOTE for Python 3.12+, `index` is removed, and `start` is updated in place
65
+ // instead, upon each `next(...)` call. See
66
+ // https://github.com/python/cpython/pull/27986
67
+ long index; // pre-3.12 only: current value is `start + index * step`
68
+ long start; // first value of the range (not advanced in this pre-3.12 layout)
69
+ long step; // increment between consecutive values
70
+ long len; // element count: normalized stop is `start + len * step`
71
+ } _PyRangeIterObject; // field order must mirror CPython: objects are cast to it
72
+
57
73
#endif // IS_PYTHON_3_12_PLUS
58
74
59
75
namespace torch ::dynamo {
@@ -1142,6 +1158,52 @@ class EQUALS_MATCH : public LeafGuard {
1142
1158
PyTypeObject* _value_type;
1143
1159
};
1144
1160
1161
+ class RANGE_ITERATOR_MATCH : public LeafGuard {
1162
+ public:
1163
+ RANGE_ITERATOR_MATCH (
1164
+ py::object start,
1165
+ py::object stop,
1166
+ py::object step,
1167
+ py::object type_id,
1168
+ py::object verbose_code_parts)
1169
+ : LeafGuard(std::move(verbose_code_parts)),
1170
+ _type_id (py::cast<intptr_t >(std::move(type_id))) {
1171
+ PyObject* start_obj = start.ptr ();
1172
+ PyObject* stop_obj = stop.ptr ();
1173
+ PyObject* step_obj = step.ptr ();
1174
+ _start = THPUtils_unpackLong (start_obj);
1175
+ _stop = THPUtils_unpackLong (stop_obj);
1176
+ _step = THPUtils_unpackLong (step_obj);
1177
+ TORCH_CHECK (
1178
+ !PyErr_Occurred (), " values of start/stop/step must fit in a long type" );
1179
+ }
1180
+
1181
+ bool check_nopybind (PyObject* value) override { // borrowed ref
1182
+ // Do a type match first.
1183
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
1184
+ if (Py_TYPE (value) != (void *)_type_id) {
1185
+ return false ;
1186
+ }
1187
+ _PyRangeIterObject* iter = (_PyRangeIterObject*)value;
1188
+
1189
+ #if IS_PYTHON_3_12_PLUS
1190
+ long start = iter->start ;
1191
+ #else
1192
+ long start = iter->start + iter->index * iter->step ;
1193
+ #endif // IS_PYTHON_3_12_PLUS
1194
+
1195
+ long stop = iter->start + iter->len * iter->step ;
1196
+ return start == _start && stop == _stop && iter->step == _step;
1197
+ }
1198
+
1199
+ private:
1200
+ intptr_t _type_id;
1201
+ // Normalized representation of a range iterator.
1202
+ long _start;
1203
+ long _stop;
1204
+ long _step;
1205
+ };
1206
+
1145
1207
class TUPLE_ITERATOR_LEN : public LeafGuard {
1146
1208
public:
1147
1209
TUPLE_ITERATOR_LEN (
@@ -4382,6 +4444,12 @@ PyObject* torch_c_dynamo_guards_init() {
4382
4444
std::shared_ptr<TUPLE_ITERATOR_LEN>>(py_m, " TUPLE_ITERATOR_LEN" )
4383
4445
.def (py::init<py::object, py::object, py::list>())
4384
4446
.def (" __call__" , &TUPLE_ITERATOR_LEN::check);
4447
+ py::class_<
4448
+ RANGE_ITERATOR_MATCH,
4449
+ LeafGuard,
4450
+ std::shared_ptr<RANGE_ITERATOR_MATCH>>(py_m, " RANGE_ITERATOR_MATCH" )
4451
+ .def (py::init<py::object, py::object, py::object, py::object, py::list>())
4452
+ .def (" __call__" , &RANGE_ITERATOR_MATCH::check);
4385
4453
py::class_<GLOBAL_STATE, LeafGuard, std::shared_ptr<GLOBAL_STATE>>(
4386
4454
py_m, " GLOBAL_STATE" )
4387
4455
.def (py::init<py::list>())
@@ -4607,6 +4675,22 @@ PyObject* torch_c_dynamo_guards_init() {
4607
4675
std::move (type_id),
4608
4676
std::move (verbose_code_parts)));
4609
4677
})
4678
+ .def (
4679
+ " add_range_iterator_match_guard" ,
4680
+ [](GuardManager& self,
4681
+ py::object start,
4682
+ py::object stop,
4683
+ py::object step,
4684
+ py::object type_id,
4685
+ py::object verbose_code_parts) -> void {
4686
+ SKIP_IF_GUARD_ALREADY_PRESENT (" RANGE_ITERATOR_MATCH" );
4687
+ self.add_leaf_guard (std::make_shared<RANGE_ITERATOR_MATCH>(
4688
+ std::move (start),
4689
+ std::move (stop),
4690
+ std::move (step),
4691
+ std::move (type_id),
4692
+ std::move (verbose_code_parts)));
4693
+ })
4610
4694
.def (
4611
4695
" add_default_device_guard" ,
4612
4696
[](GuardManager& self, py::object verbose_code_parts) -> void {
0 commit comments