Skip to content

Commit ff3f4a1

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Fix aliasing issue for dict.copy that escapes the graph (pytorch#141715)
Dynamo accidentally passed the original `ConstDictVariable.source` to the result of `dict.copy(...)`, which caused aliasing issue when the result escapes the graph (e.g., is a return value). This patch fixes that and adds a regression test. Pull Request resolved: pytorch#141715 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141713, pytorch#141714
1 parent 9eb0520 commit ff3f4a1

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/dynamo/test_misc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,6 +2833,19 @@ def fn(d):
28332833
self.assertEqual(cnts.frame_count, 1)
28342834
self.assertEqual(cnts.op_count, 1)
28352835

2836+
def test_dict_copy_alias(self):
2837+
@torch.compile(backend="eager", fullgraph=True)
2838+
def run(x, d0):
2839+
d1 = d0.copy()
2840+
d1[0] = 1
2841+
return x + 1, d1
2842+
2843+
d0 = {}
2844+
res, d1 = run(torch.zeros(1), d0)
2845+
self.assertTrue(same(res, torch.ones(1)))
2846+
self.assertEqual(d0, {})
2847+
self.assertEqual(d1, {0: 1})
2848+
28362849
def test_dict_subclass_get_method(self):
28372850
class dotdict(dict):
28382851
"""dot.notation access to dictionary attributes"""

torch/_dynamo/variables/dicts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ def call_method(
312312
return DictValues(self)
313313
elif name == "copy":
314314
assert not (args or kwargs)
315-
return self.clone(items=self.items.copy(), mutation_type=ValueMutationNew())
315+
return self.clone(
316+
items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
317+
)
316318
elif name == "__len__":
317319
assert not (args or kwargs)
318320
return ConstantVariable.create(len(self.items))

0 commit comments

Comments
 (0)