
Commit ff73e2e

StrongerXi authored and pytorchmergebot committed
[dynamo] Validate mutation_type and source in VariableTracker.__init__ (pytorch#141717)
As the title says. This validation also uncovered a few invalid use cases: the ones that caused errors are fixed in separate patches earlier in the stack, and the rest are fixed in this patch. This patch also moves a few `.source` mutations into variable construction, to increase the coverage of the validation.

Fixes pytorch#133027.

Pull Request resolved: pytorch#141717
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#141713, pytorch#141714, pytorch#141715, pytorch#141902, pytorch#141716
1 parent 0efd184 commit ff73e2e

File tree

6 files changed: +35 −24 lines

torch/_dynamo/variables/base.py

Lines changed: 17 additions & 0 deletions
@@ -479,6 +479,23 @@ def __init__(
         self.source = source
         self.mutation_type = mutation_type

+        # NOTE: sometimes mutation_type is set afterwards for implementation
+        # convenience; we don't validate those cases at the moment.
+        if mutation_type is not None:
+            if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
+                # If this fails, it's either because
+                # 1. one mistakenly passed in a source, or
+                # 2. `mutation_type` is incorrect.
+                assert source is None
+            else:
+                assert isinstance(
+                    mutation_type, (ValueMutationExisting, AttributeMutationExisting)
+                )
+                # If this fails, it's either because
+                # 1. one forgot to pass in a source, or
+                # 2. `mutation_type` is incorrect.
+                assert source is not None
+

 def typestr(*objs):
     if len(objs) == 1:
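To make the new check concrete, here is the enforced invariant as a minimal, self-contained sketch; the stub classes stand in for Dynamo's real mutation-type markers, and this is an illustration rather than the actual implementation:

class ValueMutationNew: ...          # stub: value created inside the compiled region
class AttributeMutationNew: ...      # stub: object created inside the compiled region
class ValueMutationExisting: ...     # stub: pre-existing value mutated in the region
class AttributeMutationExisting: ... # stub: pre-existing object mutated in the region

def validate(mutation_type, source):
    # A freshly created value has no pre-existing source to guard on, so it
    # must not carry one; a pre-existing value needs a source so Dynamo can
    # guard on it and reconstruct it.
    if mutation_type is None:
        return  # sometimes set afterwards for convenience; not validated yet
    if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
        assert source is None
    else:
        assert isinstance(
            mutation_type, (ValueMutationExisting, AttributeMutationExisting)
        )
        assert source is not None

validate(ValueMutationNew(), None)           # OK: new value, no source
validate(ValueMutationExisting(), "L['x']")  # OK: existing value with a source
# validate(ValueMutationNew(), "L['x']")     # AssertionError: mistaken source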

torch/_dynamo/variables/builder.py

Lines changed: 9 additions & 17 deletions
@@ -431,11 +431,6 @@ def install_guards(self, *guards):
         install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
         return {}

-    def set_source_and_track_mutable(self, value, var):
-        assert isinstance(var, VariableTracker)
-        var.source = self.source
-        return self.tx.output.side_effects.track_mutable(value, var)
-
     @classmethod
     def _type_dispatch(cls):
         return cls._type_dispatch_impl(config.trace_numpy)
@@ -607,7 +602,6 @@ def create_2d_tma_descriptor():
         elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
             self.install_guards(GuardBuilder.TYPE_MATCH)
             result = CustomizedDictVariable.wrap(self, value)
-            result.source = self.source
             return self.tx.output.side_effects.track_object_existing(value, result)
         elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
             self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
@@ -671,7 +665,7 @@ def build_key_value(i, k, v):
                 result, user_cls=type(value), source=self.source
             )

-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         elif isinstance(value, torch.nn.Module):
             return self.wrap_module(value)
         elif ConstantVariable.is_literal(value):  # non-atomic literals
@@ -1137,7 +1131,7 @@ def build_key_value(i, k, v):
             )
         elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
             self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
-            return self.set_source_and_track_mutable(
+            return self.tx.output.side_effects.track_mutable(
                 value,
                 RestrictedListSubclassVariable(
                     [
@@ -1148,6 +1142,7 @@ def build_key_value(i, k, v):
                     ],
                     user_cls=type(value),
                     user_cls_source=AttrSource(self.source, "__class__"),
+                    source=self.source,
                 ),
             )
         elif TorchScriptObjectVariable.is_matching_cls(type(value)):
@@ -1326,9 +1321,9 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
             )
             tensor_list_proxy.node.meta["grapharg"] = grapharg

-        result = BaseListVariable.cls_for_instance(value)(output)
+        result = BaseListVariable.cls_for_instance(value)(output, source=self.source)
         if istype(value, (list, collections.deque)):
-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         return result

     def wrap_tuple_iterator(self, value: tuple_iterator):
@@ -1339,11 +1334,8 @@ def wrap_tuple_iterator(self, value: tuple_iterator):
             )
             for i in range(tuple_iterator_len(value))
         ]
-        result = TupleIteratorVariable(
-            output, mutation_type=ValueMutationNew(), source=self.source
-        )
-
-        return self.set_source_and_track_mutable(value, result)
+        result = TupleIteratorVariable(output, source=self.source)
+        return self.tx.output.side_effects.track_mutable(value, result)

     def wrap_range_iterator(self, value: range_iterator):
         self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)
@@ -1512,7 +1504,7 @@ def wrap_literal(self, value):
         self.install_guards(GuardBuilder.CONSTANT_MATCH)
         result = ConstantVariable.create(value=value, source=self.source)
         if isinstance(value, (list, set)):
-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         return result

     def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
@@ -2403,7 +2395,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
     elif istype(example_value, tuple):
         return TupleVariable(unpacked, **options)
     elif istype(example_value, (list, immutable_list)):
-        return ListVariable(unpacked, mutation_type=ValueMutationNew(), **options)
+        return ListVariable(unpacked, **options)
     else:
         assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
             example_value, "_fields"

torch/_dynamo/variables/builtin.py

Lines changed: 3 additions & 1 deletion
@@ -1375,7 +1375,9 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
                 arg, user_cls, mutation_type=ValueMutationNew()
             )
         elif isinstance(arg, variables.ConstDictVariable):
-            return arg.clone(user_cls=user_cls, mutation_type=ValueMutationNew())
+            return arg.clone(
+                user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
+            )
         elif isinstance(
             arg,
             (
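One subtlety in this change: `clone` starts from the existing variable's attributes, so without an explicit `source=None` the old source would be carried over onto a variable marked `ValueMutationNew` and trip the new constructor check. A minimal sketch of that failure mode, using hypothetical stand-ins rather than Dynamo's real classes:

class ValueMutationNew: ...  # stand-in for Dynamo's mutation-type marker

class Var:  # hypothetical stand-in for VariableTracker
    def __init__(self, source=None, mutation_type=None):
        if isinstance(mutation_type, ValueMutationNew):
            assert source is None  # the validation added by this patch
        self.source, self.mutation_type = source, mutation_type

    def clone(self, **kwargs):
        # clone() reuses the current attributes unless explicitly overridden
        attrs = {"source": self.source, "mutation_type": self.mutation_type}
        attrs.update(kwargs)
        return Var(**attrs)

existing = Var(source="L['d']")  # a tracked dict that has a source
ok = existing.clone(source=None, mutation_type=ValueMutationNew())  # passes
# existing.clone(mutation_type=ValueMutationNew())  # would fail the assert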

torch/_dynamo/variables/dicts.py

Lines changed: 1 addition & 1 deletion
@@ -845,7 +845,7 @@ def wrap(cls, builder, obj):
             if val is not None:
                 key = ConstantVariable.create(key)
                 items[key] = var
-        return cls(items, user_cls)
+        return cls(items, user_cls, source=builder.source)

     def __init__(self, items, user_cls, **options) -> None:
         super().__init__(items, user_cls, **options)

torch/_dynamo/variables/misc.py

Lines changed: 2 additions & 2 deletions
@@ -721,11 +721,11 @@ def visit(node):

     def call_backward(self, tx: "InstructionTranslator", args, kwargs):
         fn = self.fn_cls.backward
-        self.source = AttrSource(self.source, "backward")
         assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
         assert isinstance(fn, types.FunctionType)

-        return variables.UserFunctionVariable(fn, source=self.source).call_function(
+        fn_source = AttrSource(self.source, "backward")
+        return variables.UserFunctionVariable(fn, source=fn_source).call_function(
             tx, args, kwargs
         )
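Beyond enabling the constructor validation, computing a local `fn_source` also avoids a compounding mutation: the old code wrapped `self.source` in a new `AttrSource` on every call, so repeated `call_backward` invocations would nest the source deeper each time. A small illustration with a stub `AttrSource` (the real one lives in torch/_dynamo/source.py):

from dataclasses import dataclass

@dataclass
class AttrSource:  # stub; not the real class
    base: object
    member: str

class AutogradFnVar:  # hypothetical stand-in for the variable in misc.py
    def __init__(self, source):
        self.source = source

    def call_backward_old(self):
        # buggy pattern: mutates self.source, compounding on each call
        self.source = AttrSource(self.source, "backward")
        return self.source

    def call_backward_new(self):
        # fixed pattern: derive a local source; self.source stays untouched
        return AttrSource(self.source, "backward")

v = AutogradFnVar(source="L['fn']")
v.call_backward_old()
v.call_backward_old()
print(v.source)  # nested twice: AttrSource(base=AttrSource(...), member='backward')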

torch/_dynamo/variables/torch.py

Lines changed: 3 additions & 3 deletions
@@ -1117,6 +1117,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True):
                 (data.as_proxy(), placeholder.as_proxy()),
                 {},
             ),
+            # In reconstruct() we should use the original parameter. The one
+            # returned by the graph will be an alias.
+            source=placeholder.source,
         )
         assert isinstance(result, variables.TensorVariable)
         result.class_type = torch.nn.Parameter
@@ -1127,9 +1130,6 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True):
         # has_grad_fn field to False to workaround the issue.
         result.has_grad_fn = False

-        # In reconstruct() should use the original parameter. The one returned by the graph will be an alias.
-        result.source = placeholder.source
-
         # TODO(jansel): if the new param falls out of scope, currently it won't get freed until
         # the end of the graph. We should fix this.
         return result
