
empty_permute decomposition #2698

Merged
merged 1 commit into main on Apr 17, 2024

Conversation

@apbose (Collaborator) commented Mar 19, 2024

This adds a decomposition for empty_permuted, as an extension to support aten::empty_like.
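
For context, a minimal sketch of what the decomposition could look like, assuming empty_permuted receives the tensor size as args[0] and the physical layout as args[1] (as discussed in the review below); the actual registration mechanism in the repository may differ:

import torch

# Sketch: decompose aten.empty_permuted into aten.empty + aten.permute.
# args[0] is the logical size, args[1] the physical layout permutation.
def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
    empty_size = args[0]
    empty_permute = args[1]
    # Build the inverse permutation: physical dimension i maps back to
    # logical dimension empty_permute[i].
    perm = [0] * len(empty_size)
    for permute_index, permute_element in enumerate(empty_permute):
        perm[permute_element] = permute_index
    return torch.empty(
        [empty_size[i] for i in empty_permute], **kwargs
    ).permute(perm)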

@github-actions bot added labels: component: tests, component: lowering, component: api [Python], component: dynamo on Mar 19, 2024
@github-actions bot requested a review from gs-olive on March 19, 2024 21:02
@apbose force-pushed the empty_permuted_decomposition branch from dcfe61d to 6abe7ce on April 5, 2024 00:15
Comment on lines +443 to +450
fx_graph = torch.fx.symbolic_trace(emptyLike())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
    fx_graph,
    inputs,
    expected_ops=expected_ops,
    unexpected_ops=unexpected_ops,
    min_block_size=1,
)
@gs-olive (Collaborator) commented:
Could you show a printout of what the original and final graphs look like in this case? I want to verify that there is not a circular issue where empty_permuted generates empty_like, and vice versa.

@apbose (Collaborator, Author) commented Apr 12, 2024:
With the empty_permuted decomposition, the graph is the following.

Pre-AOT Autograd graph:

graph():
   %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
   %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
   return (add_1,)

Post-AOT Autograd graph:

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
   %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 2],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
   %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute, %add), kwargs = {})
   return (add_1,)

Graph after constant folding:

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)

Post-lowering passes Autograd graph:

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)

Without the decomposition, the graph is
Pre-AOT Autograd graph:

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
    %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
    return (add_1,)

Post-AOT Autograd graph:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%empty_permuted, %add), kwargs = {})
    return (add_1,)

Graph after constant folding:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

Post-lowering passes Autograd graph:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

So empty_like decomposes into empty_permuted, which in turn decomposes into empty.memory_format. The test above does not raise an error even though empty.memory_format is unsupported, because constant folding removes the op.

I am working on empty.memory_format support in PR #2745.
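
(For reference, a module shaped roughly like the following would produce the Pre-AOT graph shown above; this is an illustrative reconstruction, and the actual emptyLike test class in the PR may differ.)

import torch

# Hypothetical reconstruction of the traced test module.
class emptyLike(torch.nn.Module):
    def forward(self, x):
        c = torch.add(x, x)                        # -> aten.add
        y = torch.ops.aten.empty_like.default(c)   # -> aten.empty_like.default
        return y + c                               # -> operator.add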

@gs-olive (Collaborator) commented:

In the above example, the Pre-AOT graph shows:

   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})

Since there is only one argument in args, what is empty_permute = args[1] defined as in the decomposition for that case?

@apbose (Collaborator, Author) commented Apr 16, 2024:

In the above case, with the AOT decomposition, the operation decomposes to

 %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})

The args[1] in this case is [0, 1], since it keeps the shape in its original order.
I am not sure how it arrives at exactly [0, 1], but I assume it comes from the internal AOT lowering heuristics.
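
(As a quick, hedged illustration of the identity layout: for a contiguous input, empty_permuted receives the identity order as its physical layout, which is why args[1] is [0, 1] in this 2-D case.)

import torch

# For a contiguous 2-D tensor, the physical layout handed to
# empty_permuted is the identity permutation [0, 1].
y = torch.ops.aten.empty_permuted.default([3, 2], [0, 1], dtype=torch.float32)
assert y.shape == torch.Size([3, 2])
assert y.is_contiguous()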

@gs-olive (Collaborator) left a review:

Overall looks good to me - added one clarifying question


@apbose merged commit 0b29987 into main on Apr 17, 2024
16 of 21 checks passed
peri044 pushed a commit that referenced this pull request Apr 19, 2024
laikhtewari pushed a commit that referenced this pull request May 24, 2024