@@ -431,11 +431,6 @@ def install_guards(self, *guards):
431
431
install_guard (* [source .make_guard (guard ) for guard in guards ], skip = 1 )
432
432
return {}
433
433
434
- def set_source_and_track_mutable (self , value , var ):
435
- assert isinstance (var , VariableTracker )
436
- var .source = self .source
437
- return self .tx .output .side_effects .track_mutable (value , var )
438
-
439
434
@classmethod
440
435
def _type_dispatch (cls ):
441
436
return cls ._type_dispatch_impl (config .trace_numpy )
@@ -607,7 +602,6 @@ def create_2d_tma_descriptor():
607
602
elif CustomizedDictVariable .is_matching_cls_hf (type (value )):
608
603
self .install_guards (GuardBuilder .TYPE_MATCH )
609
604
result = CustomizedDictVariable .wrap (self , value )
610
- result .source = self .source
611
605
return self .tx .output .side_effects .track_object_existing (value , result )
612
606
elif istype (value , (dict , collections .defaultdict , collections .OrderedDict )):
613
607
self .install_guards (GuardBuilder .SEQUENCE_LENGTH )
@@ -671,7 +665,7 @@ def build_key_value(i, k, v):
671
665
result , user_cls = type (value ), source = self .source
672
666
)
673
667
674
- return self .set_source_and_track_mutable (value , result )
668
+ return self .tx . output . side_effects . track_mutable (value , result )
675
669
elif isinstance (value , torch .nn .Module ):
676
670
return self .wrap_module (value )
677
671
elif ConstantVariable .is_literal (value ): # non-atomic literals
@@ -1137,7 +1131,7 @@ def build_key_value(i, k, v):
1137
1131
)
1138
1132
elif RestrictedListSubclassVariable .is_matching_cls (type (value )):
1139
1133
self .install_guards (GuardBuilder .SEQUENCE_LENGTH )
1140
- return self .set_source_and_track_mutable (
1134
+ return self .tx . output . side_effects . track_mutable (
1141
1135
value ,
1142
1136
RestrictedListSubclassVariable (
1143
1137
[
@@ -1148,6 +1142,7 @@ def build_key_value(i, k, v):
1148
1142
],
1149
1143
user_cls = type (value ),
1150
1144
user_cls_source = AttrSource (self .source , "__class__" ),
1145
+ source = self .source ,
1151
1146
),
1152
1147
)
1153
1148
elif TorchScriptObjectVariable .is_matching_cls (type (value )):
@@ -1326,9 +1321,9 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
1326
1321
)
1327
1322
tensor_list_proxy .node .meta ["grapharg" ] = grapharg
1328
1323
1329
- result = BaseListVariable .cls_for_instance (value )(output )
1324
+ result = BaseListVariable .cls_for_instance (value )(output , source = self . source )
1330
1325
if istype (value , (list , collections .deque )):
1331
- return self .set_source_and_track_mutable (value , result )
1326
+ return self .tx . output . side_effects . track_mutable (value , result )
1332
1327
return result
1333
1328
1334
1329
def wrap_tuple_iterator (self , value : tuple_iterator ):
@@ -1339,11 +1334,8 @@ def wrap_tuple_iterator(self, value: tuple_iterator):
1339
1334
)
1340
1335
for i in range (tuple_iterator_len (value ))
1341
1336
]
1342
- result = TupleIteratorVariable (
1343
- output , mutation_type = ValueMutationNew (), source = self .source
1344
- )
1345
-
1346
- return self .set_source_and_track_mutable (value , result )
1337
+ result = TupleIteratorVariable (output , source = self .source )
1338
+ return self .tx .output .side_effects .track_mutable (value , result )
1347
1339
1348
1340
def wrap_range_iterator (self , value : range_iterator ):
1349
1341
self .install_guards (GuardBuilder .RANGE_ITERATOR_MATCH )
@@ -1512,7 +1504,7 @@ def wrap_literal(self, value):
1512
1504
self .install_guards (GuardBuilder .CONSTANT_MATCH )
1513
1505
result = ConstantVariable .create (value = value , source = self .source )
1514
1506
if isinstance (value , (list , set )):
1515
- return self .set_source_and_track_mutable (value , result )
1507
+ return self .tx . output . side_effects . track_mutable (value , result )
1516
1508
return result
1517
1509
1518
1510
def assert_not_wrapped_by_this_graph (self , value : torch .Tensor ):
@@ -2403,7 +2395,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
2403
2395
elif istype (example_value , tuple ):
2404
2396
return TupleVariable (unpacked , ** options )
2405
2397
elif istype (example_value , (list , immutable_list )):
2406
- return ListVariable (unpacked , mutation_type = ValueMutationNew (), ** options )
2398
+ return ListVariable (unpacked , ** options )
2407
2399
else :
2408
2400
assert example_value .__class__ .__module__ == "torch.return_types" or hasattr (
2409
2401
example_value , "_fields"
0 commit comments