Skip to content

Commit

Permalink
Fix inference_mode (#885)
Browse files Browse the repository at this point in the history
Summary:
Fixes: #875

Test Plan:
Test locally with tutorials/quantize_vit/run_vit_b_quant.py
with:
```
with torch.inference_mode():
    benchmark_model(model, 20, inputs)
```

but can't repro the issue in unit tests

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Sep 13, 2024
1 parent 3fa38aa commit 90c8cbd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ def _register_aqt_quantized_linear_dispatches():

_register_aqt_quantized_linear_dispatches()

@implements(torch.nn.functional.linear)
@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/linear_activation_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def to(self, *args, **kwargs):

implements = LinearActivationQuantizedTensor.implements

@implements(torch.nn.functional.linear)
@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
Expand Down

0 comments on commit 90c8cbd

Please sign in to comment.