empty_permute decomposition #2698

Merged 1 commit on Apr 17, 2024
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

@@ -37,6 +37,7 @@
     aten.elu_backward,
     aten._embedding_bag,
     aten.embedding_dense_backward,
+    aten.empty_like,
     aten._euclidean_dist.default,
     aten.expand_as,
     aten.eye,
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py

@@ -162,6 +162,18 @@ def var_decomposition(
     return variance
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
+    empty_size = args[0]
+    empty_permute = args[1]
+    perm = [0] * len(empty_size)
+    for permute_index, permute_element in enumerate(empty_permute):
+        perm[permute_element] = permute_index
+    return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
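For reference, a minimal sketch (not part of the diff; names and values below are illustrative) of what this decomposition computes: empty_permuted allocates a tensor whose dimension order in memory follows the given physical layout, and the inverted permutation views it back in logical order.

import torch

# Assumed example inputs, not from the PR: a 3-D logical shape and a
# physical layout saying dim 2 is outermost and dim 1 is innermost.
size = [2, 3, 4]
physical_layout = [2, 0, 1]

# Invert the permutation, mirroring the loop in empty_permuted_decomposition:
# perm[physical_layout[i]] = i
perm = [0] * len(size)
for i, p in enumerate(physical_layout):
    perm[p] = i  # perm becomes [1, 2, 0]

# Allocate in physical order ([4, 2, 3]), then permute back to logical order.
out = torch.empty([size[p] for p in physical_layout]).permute(perm)
assert list(out.shape) == size  # logical shape [2, 3, 4] is restored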
65 changes: 65 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py

@@ -420,6 +420,71 @@ def forward(self, x):
f"MaxPool3d TRT outputs don't match with the original model.",
)

def test_lowering_empty_like_module(self):
class emptyLike(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x):
c = torch.ops.aten.add(x, x)
y = torch.ops.aten.empty_like.default(c)
d = y + c
return d

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.add.Tensor}
unexpected_ops = {
torch.ops.aten.empty_like.default,
torch.ops.aten.empty_permuted.default,
}

inputs = [torch.zeros(3, 2).cuda()]

fx_graph = torch.fx.symbolic_trace(emptyLike())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)
Comment on lines +443 to +450

Collaborator:

Could you show a printout of what the original and final graphs look like in this case? I want to verify that there is not a circular issue where empty_permuted generates empty_like, and vice versa.

apbose (Collaborator Author) commented on Apr 12, 2024:

With the empty_permute decomposition, the graphs are:
Pre-AOT Autograd graph:=============

graph():
   %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
   %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
   return (add_1,)

Post-AOT Autograd graph:=======

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
   %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 2],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
   %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute, %add), kwargs = {})
   return (add_1,)

Graph after constant folding:

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)

Post-lowering passes Autograd graph:=======

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)

Without the decomposition, the graphs are:
Pre-AOT Autograd graph:=============

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
    %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
    return (add_1,)

Post-AOT Autograd graph:=======

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%empty_permuted, %add), kwargs = {})
    return (add_1,)

Graph after constant folding:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

Post-lowering passes Autograd graph:=======

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

So empty_like decomposes into empty_permuted, which decomposes into empty.memory_format. The above test does not raise an error, even though empty.memory_format is not supported, because constant folding removes the op.

I am working on empty.memory_format in PR #2745
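A minimal sketch (assuming PyTorch's torch._decomp API; not part of the PR) of how one could confirm that both ops have stock decompositions, which is what produces the chain above:

import torch
from torch._decomp import get_decompositions

# Assumed check: if an op has an entry in PyTorch's stock decomposition
# table, AOT Autograd can lower it further, giving the
# empty_like -> empty_permuted -> empty.memory_format chain shown above.
decomps = get_decompositions(
    [torch.ops.aten.empty_like, torch.ops.aten.empty_permuted]
)
for op in decomps:
    print(op)  # expect overloads such as aten.empty_like.default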

Collaborator:

In the above example, the Pre-AOT graph shows:

   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})

Since there is only one argument in args, what does empty_permute = args[1] evaluate to in the decomposition for that case?

apbose (Collaborator Author) commented on Apr 16, 2024:

In the above case, with the AOT decomposition, the operation decomposes to

 %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})

The args[1] in this case is [0, 1], since it keeps the shape in its original order. I am not sure how it arrives at exactly [0, 1], but I assume it comes from the internal AOT lowering heuristics?
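A quick sketch (an assumption about where the layout comes from, not something verified in this PR): ordering dimensions by decreasing stride yields the identity permutation for a contiguous input, which would explain the [0, 1]:

import torch

# Hypothetical illustration: derive a physical dim order from strides.
x = torch.zeros(3, 2)  # contiguous, strides (2, 1)
order = sorted(range(x.dim()), key=lambda d: x.stride(d), reverse=True)
print(order)  # [0, 1] -> identity permutation for a contiguous tensor

y = x.permute(1, 0)  # non-contiguous view, strides (1, 2)
order = sorted(range(y.dim()), key=lambda d: y.stride(d), reverse=True)
print(order)  # [1, 0]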


+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"empty_like TRT outputs don't match with the original model.",
+        )


 if __name__ == "__main__":
     run_tests()