
Commit

Fix (tests): adapt test to new fx
Giuseppe5 authored and volcacius committed Jan 20, 2023
1 parent 8fc4035 commit 2aa8cca
Showing 2 changed files with 36 additions and 15 deletions.
31 changes: 21 additions & 10 deletions src/brevitas/fx/__init__.py
@@ -15,8 +15,10 @@
 from torch.overrides import is_tensor_method_or_property
 
 if torch_version < version.parse('1.8.1'):
-    from .backport.node import map_arg
-    from .backport.symbolic_trace import Tracer
+    from .backport.graph import Graph
+    from .backport.graph import magic_methods
+    from .backport.graph import reflectable_magic_methods
+    from .backport.graph import Target
     from .backport.graph_module import GraphModule
     from .backport.immutable_collections import immutable_dict
     from .backport.immutable_collections import immutable_list
@@ -36,19 +38,29 @@
     from .backport.symbolic_trace import map_aggregate
     from .backport.symbolic_trace import Tracer
 else:
+    from torch.fx import Graph
+    from torch.fx import GraphModule
     from torch.fx import map_arg
-    from torch.fx import Tracer, Graph, GraphModule, Proxy, Node
+    from torch.fx import Node
+    from torch.fx import Proxy
+    from torch.fx import Tracer
+    from torch.fx.graph import magic_methods
+    from torch.fx.graph import reflectable_magic_methods
     from torch.fx.graph import Target
     from torch.fx.proxy import base_types
-    from torch.fx.graph import magic_methods, reflectable_magic_methods
     try:
+        from torch.fx.immutable_collections import immutable_dict
+        from torch.fx.immutable_collections import immutable_list
+        from torch.fx.symbolic_trace import _autowrap_check
+        from torch.fx.symbolic_trace import _find_proxy
         from torch.fx.symbolic_trace import _orig_module_call
         from torch.fx.symbolic_trace import _orig_module_getattr
-        from torch.fx.symbolic_trace import _autowrap_check
-        from torch.fx.symbolic_trace import _Patcher, map_aggregate
-        from torch.fx.symbolic_trace import _wrapped_fns_to_patch, _wrapped_methods_to_patch
-        from torch.fx.symbolic_trace import _find_proxy, _patch_function, HAS_VARSTUFF
-        from torch.fx.immutable_collections import immutable_dict, immutable_list
+        from torch.fx.symbolic_trace import _patch_function
+        from torch.fx.symbolic_trace import _Patcher
+        from torch.fx.symbolic_trace import _wrapped_fns_to_patch
+        from torch.fx.symbolic_trace import _wrapped_methods_to_patch
+        from torch.fx.symbolic_trace import HAS_VARSTUFF
+        from torch.fx.symbolic_trace import map_aggregate
     except ImportError:
         from torch.fx._symbolic_trace import _orig_module_call
         from torch.fx._symbolic_trace import _orig_module_getattr
@@ -58,7 +70,6 @@
         from torch.fx._symbolic_trace import _find_proxy, _patch_function, HAS_VARSTUFF
         from torch.fx.immutable_collections import immutable_dict, immutable_list
 
-
 from .brevitas_tracer import brevitas_symbolic_trace
 from .brevitas_tracer import brevitas_value_trace
 from .brevitas_tracer import symbolic_trace
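The try/except around the symbolic_trace imports is the standard fallback for an upstream module rename: newer PyTorch releases moved the fx tracing internals to the private module torch.fx._symbolic_trace, so the old public path is attempted first and the private one is used once it is gone. A minimal standalone sketch of the pattern, reduced to a single symbol (illustrative, not part of the commit):

    # Try the 1.8-era public location first, then the renamed private module.
    try:
        from torch.fx.symbolic_trace import _orig_module_call
    except ImportError:
        # Newer PyTorch keeps these tracing internals in a private module.
        from torch.fx._symbolic_trace import _orig_module_call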
20 changes: 15 additions & 5 deletions tests/brevitas/graph/test_transforms.py
@@ -108,7 +108,9 @@ def forward(self, x):
     model = TestModel()
     graph_model = symbolic_trace(model)
     graph_model = FnToModule(torch.add, TestModel).apply(graph_model)
-    assert isinstance(graph_model.add_1, TestModel)
+    # Due to changes in fx after 1.8
+    attr_check = getattr(graph_model, 'add_1', None) or getattr(graph_model, 'add', None)
+    assert isinstance(attr_check, TestModel)
 
 
 def test_rewriter_max_pool_to_module():
@@ -122,7 +124,9 @@ def forward(self, x):
     graph_model = symbolic_trace(model)
     graph_model = FnToModule(torch.max_pool2d, nn.MaxPool2d).apply(graph_model)
     inp = torch.randn(2, 10, 10)
-    assert isinstance(graph_model.max_pool2d_1, nn.MaxPool2d)
+    # Due to changes in fx after 1.8
+    attr_check = getattr(graph_model, 'max_pool2d_1', None) or getattr(graph_model, 'max_pool2d', None)
+    assert isinstance(attr_check, nn.MaxPool2d)
     assert (model(inp) == graph_model(inp)).all().item()
 
 
@@ -142,7 +146,9 @@ def forward(self, x):
     graph_model = symbolic_trace(model)
     graph_model = MethodToModule('add', AddModule).apply(graph_model)
     inp = torch.randn(2, 10, 10)
-    assert isinstance(graph_model.add_1, AddModule)
+    # Due to changes in fx after 1.8
+    attr_check = getattr(graph_model, 'add_1', None) or getattr(graph_model, 'add', None)
+    assert isinstance(attr_check, AddModule)
     assert (model(inp) == graph_model(inp)).all().item()
 
 
@@ -162,7 +168,9 @@ def forward(self, x):
     graph_model = symbolic_trace(model)
     graph_model = FnToModule(operator.add, AddModule).apply(graph_model)
     inp = torch.randn(2, 10, 10)
-    assert isinstance(graph_model.add_1, AddModule)
+    # Due to changes in fx after 1.8
+    attr_check = getattr(graph_model, 'add_1', None) or getattr(graph_model, 'add', None)
+    assert isinstance(attr_check, AddModule)
     assert (model(inp) == graph_model(inp)).all().item()
 
 
@@ -177,5 +185,7 @@ def forward(self, x):
     graph_model = symbolic_trace(model)
     graph_model = MeanMethodToAdaptiveAvgPool2d().apply(graph_model)
     inp = torch.randn(2, 3, 10, 10)
-    assert isinstance(graph_model.mean_1, nn.AdaptiveAvgPool2d)
+    # Due to changes in fx after 1.8
+    attr_check = getattr(graph_model, 'mean_1', None) or getattr(graph_model, 'mean', None)
+    assert isinstance(attr_check, nn.AdaptiveAvgPool2d)
     assert (model(inp) == graph_model(inp)).all().item()
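For reference, the naming difference these tests now tolerate comes from how torch.fx assigns node names, which the rewriters reuse when registering replacement modules: 1.8-era fx suffixed even the first occurrence of an op (add_1), while later releases leave the first occurrence unsuffixed (add). A small standalone sketch of the behavior on a recent PyTorch (illustrative, not part of the commit):

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace


    class AddModel(nn.Module):

        def forward(self, x):
            return torch.add(x, x)


    gm = symbolic_trace(AddModel())
    # Recent fx names the first torch.add node 'add'; 1.8-era fx used 'add_1',
    # hence the getattr fallback across both candidate attribute names above.
    print([node.name for node in gm.graph.nodes])  # e.g. ['x', 'add', 'output']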
