Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add decomposition for aten.addmm #1953

Merged
merged 1 commit into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,14 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
return x


@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
def addmm_replacement(
input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
) -> torch.Tensor:
return torch.add(
torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
)


def get_decompositions():
return DECOMPOSITIONS
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from torch_tensorrt.dynamo import compile
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT


class TestLowering(TestCase):
Expand Down Expand Up @@ -109,6 +111,74 @@ def forward(self, x):
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

def test_lowering_addmm(self):
class AddMM(torch.nn.Module):
def forward(self, x, y, z):
return torch.addmm(x, y, z, beta=16, alpha=5)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mm.default,
}
unexpected_ops = {torch.ops.aten.addmm.default}

inputs = [
torch.rand(
1,
1,
).cuda(),
torch.rand(
7,
8,
).cuda(),
torch.rand(
8,
9,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(AddMM())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"AddMM TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/common_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

COSINE_THRESHOLD = 0.99
DECIMALS_OF_AGREEMENT = 5


def cosine_similarity(gt_tensor, pred_tensor):
Expand Down