feat: support aten.resize_ converter #2874

Merged
merged 5 commits into main on Jun 19, 2024

Conversation

chohk88
Collaborator

@chohk88 chohk88 commented May 31, 2024

Description

Support converter for aten.resize_ operation: https://pytorch.org/docs/stable/generated/torch.Tensor.resize_.html#torch.Tensor.resize_

One critical aspect of this operation is handling cases where the target (output) size is larger than the input tensor size. In PyTorch, the additional elements are left uninitialized, so their values are unpredictable: often close to zero, but not guaranteed to be zero.


In the converter developed for this PR, the additional elements are initialized to zero using numpy.zeros. While this approach ensures a predictable output, it does not exactly replicate PyTorch's behavior, where the additional elements are left uninitialized and can hold arbitrary values.
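The zero-fill strategy can be sketched as a standalone NumPy model (a rough illustration of the behavior described above, not the actual converter code; the function name `resize_with_zero_fill` is made up):

```python
import numpy as np

def resize_with_zero_fill(x: np.ndarray, target_shape: tuple) -> np.ndarray:
    """Model of the converter's behavior: copy the existing elements in
    row-major order and zero-fill any extra elements. Eager PyTorch,
    by contrast, leaves the extra elements uninitialized."""
    flat = x.flatten()
    n_target = int(np.prod(target_shape))
    out = np.zeros(n_target, dtype=x.dtype)  # deterministic zero fill
    n_copy = min(flat.size, n_target)
    out[:n_copy] = flat[:n_copy]
    return out.reshape(target_shape)
```

When shrinking, both this model and PyTorch simply keep the leading elements; the two only diverge in the enlarged tail.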

Fixes # (issue)

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@chohk88 chohk88 added the `component: converters` (Issues re: Specific op converters) label May 31, 2024
@chohk88 chohk88 self-assigned this May 31, 2024
@github-actions github-actions bot added the `component: tests`, `component: conversion`, `component: api [Python]`, and `component: dynamo` labels May 31, 2024
Collaborator

@zewenli98 zewenli98 left a comment

Overall looks good, just left some comments for your reference.

Although the additional elements are unpredictable, they are close enough to zero, as you said, so I think it should be able to pass our checkpoints since the tolerance is acceptable.
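The tolerance point can be illustrated in isolation (the values and tolerances below are invented for demonstration, not the test suite's actual settings):

```python
import numpy as np

# Near-zero uninitialized values (around 1e-40) are indistinguishable
# from an all-zeros reference under a typical absolute tolerance.
ref = np.zeros(4, dtype=np.float32)
out = np.array([0.0, 1e-40, -3e-41, 2e-39], dtype=np.float32)

close = np.allclose(out, ref, rtol=1e-3, atol=1e-3)
```

This is exactly why the check only breaks down when the uninitialized memory happens to contain large values.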

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py Outdated Show resolved Hide resolved
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py Outdated Show resolved Hide resolved
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py Outdated Show resolved Hide resolved
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py Outdated Show resolved Hide resolved
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py Outdated Show resolved Hide resolved
py/torch_tensorrt/dynamo/conversion/impl/shuffle.py Outdated Show resolved Hide resolved
tests/py/dynamo/conversion/test_resize_aten.py Outdated Show resolved Hide resolved
@chohk88
Collaborator Author

chohk88 commented Jun 10, 2024

@apbose @zewenli98 The converter for resize_ has been fully implemented functionally.

However, while investigating the cause of the CI/CD failure, I found that tensors at uninitialized locations sometimes contain extremely large values, whereas the PR description always showed values very close to zero (on the order of 10^-40) or exactly zero. Do you have any suggestions on how to handle tests in such cases?


@narendasan
Collaborator

Use run_test's custom compare and check the original elements and the shape to verify correctness

@apbose
Collaborator

apbose commented Jun 14, 2024

@chohk88 I was talking about the function run_test_custom_compare_results(). You can pass the attributes of the TRTTensor and the ref Tensor that you want to compare via the comparators arg.
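Per the thread, each comparators entry is a (compare_fn, extra_args) pair, invoked as compare_fn(output_trt, output_cpu, *extra_args). A rough sketch of such a comparator (using NumPy arrays for illustration; the helper name and the specific check are examples, not the PR's actual test code):

```python
import numpy as np

def shape_and_prefix_match(output_trt, output_cpu, n_original):
    # The enlarged tail is uninitialized in eager PyTorch, so only the
    # shape and the first n_original elements (row-major) are checked.
    if output_trt.shape != output_cpu.shape:
        return False
    return bool(np.array_equal(
        np.asarray(output_trt).flatten()[:n_original],
        np.asarray(output_cpu).flatten()[:n_original],
    ))

# One (compare_fn, extra_args) entry per output:
comparators = [(shape_and_prefix_match, [3])]
```

This sidesteps the unpredictable tail entirely: correctness is judged only on what resize_ is guaranteed to preserve.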

tests/py/dynamo/conversion/test_resize_aten.py Outdated Show resolved Hide resolved
):
    comp_func = comparator[0]
    args = comparator[1]
    self.assertTrue(comp_func(output_trt, output_cpu, *args))

Collaborator

Is there a specific case where len(cuda_inputs) == 1 is required? In general I assume len(cuda_inputs) would be 1 in most cases. And since it is conditioned on res_trt, could you highlight the cases where it would be required?

Collaborator Author

Thank you for your comment!

As you mentioned, in most cases, the length of cuda_inputs is 1. Below are the outputs when I set a breakpoint in the original code for two cases where torch.ops.aten.resize_.default(x, target_shape) returns tensors with shapes (3,) and (10, 15, 10).


In both cases, the length of cuda_inputs is also 1. When cuda_inputs has length 1, res_trt and res_cpu are not lists of length 1 but torch.Tensors with shapes (3,) and (10, 15, 10), respectively. Therefore, when we zip them in for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators) and the comparators list has length 1, fewer elements are compared than intended, because the zip iterates over the tensors' first dimension rather than over whole outputs.
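The dimension-loss issue above can be demonstrated in isolation (toy values; the comparators list is just a one-element placeholder here):

```python
import numpy as np

res_trt = np.arange(3)           # a single output returned as a bare array, shape (3,)
res_cpu = np.arange(3)
comparators = [("comp_fn", [])]  # one comparator entry

# Zipping the bare arrays iterates over their first dimension, so the
# single comparator is paired with element 0 of the tensor, not with
# the whole tensor:
bare = list(zip(res_trt, res_cpu, comparators))

# Wrapping each result in a list pairs the whole tensor with the comparator:
wrapped = list(zip([res_trt], [res_cpu], comparators))
```

With the bare arrays, bare[0][0] is the scalar element 0; with the wrapped lists, wrapped[0][0] is the full shape-(3,) array, which is what the comparator should receive.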

@chohk88 chohk88 linked an issue Jun 19, 2024 that may be closed by this pull request
@chohk88 chohk88 merged commit 950b791 into main Jun 19, 2024
55 of 61 checks passed
Labels
cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: tests
Development

Successfully merging this pull request may close these issues.

aten.resize_
6 participants