Skip to content

Commit 0efd184

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Fix side effects for range iterator that escapes the graph (pytorch#141716)
`wrap_range_iterator` mistakenly used `ValueMutationNew`, when it should've used `ValueMutationExisting`, because this code path always has a source. Pull Request resolved: pytorch#141716 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141713, pytorch#141714, pytorch#141715, pytorch#141902
1 parent 7c3c8a6 commit 0efd184

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,6 +1800,17 @@ def run(n):
18001800
self.assertTrue(same(res2, torch.ones(2)))
18011801
self.assertTrue(same(res3, torch.ones(3)))
18021802

1803+
def test_range_iter_side_effects(self):
1804+
@torch.compile(backend="eager", fullgraph=True)
1805+
def run(x, it):
1806+
n = next(it)
1807+
return x + n
1808+
1809+
it = iter(range(1, 3))
1810+
res = run(torch.zeros(1), it)
1811+
self.assertTrue(same(res, torch.ones(1)))
1812+
self.assertEqual(next(it), 2)
1813+
18031814
def test_build_tuple_unpack(self):
18041815
def fn1(a, b, c):
18051816
return a - b / c

torch/_dynamo/side_effects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def codegen_update_mutated(self, cg: PyCodegen):
748748
cg(value)
749749
cg(var.source)
750750
suffixes.append([create_instruction("STORE_ATTR", argval=name)])
751-
elif isinstance(var, variables.TupleIteratorVariable):
751+
elif isinstance(var, variables.ListIteratorVariable):
752752
for _ in range(var.index):
753753
cg.add_push_null(
754754
lambda: cg.load_import_from(utils.__name__, "iter_next")

torch/_dynamo/variables/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1350,7 +1350,8 @@ def wrap_range_iterator(self, value: range_iterator):
13501350
# Get all the values from the range iterator; no need to install guards
13511351
# on items since `RANGE_ITERATOR_MATCH` guarantees the same items.
13521352
items = [ConstantVariable.create(v) for v in copy.deepcopy(value)]
1353-
return ListIteratorVariable(items, mutation_type=ValueMutationNew())
1353+
result = ListIteratorVariable(items, source=self.source)
1354+
return self.tx.output.side_effects.track_mutable(value, result)
13541355

13551356
def wrap_slice_range(self, value: Union[slice, range]):
13561357
items = [

0 commit comments

Comments
 (0)