Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle dynamic shapes in where ops #2853

Merged
merged 10 commits into from
Jun 12, 2024

Conversation

keehyuna
Copy link
Collaborator

@keehyuna keehyuna commented May 20, 2024

Description

Dynamic shape cannot be used in where ops because of exception in torch.broadcast_shapes().
Proposed fix removes expand() for static shape input and only performs prepend ones to have same rank size for static/dynamic shape input. I think broadcast is not required as it is applied to ISelectLayer(addSelect)

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 20, 2024
@github-actions github-actions bot requested a review from apbose May 20, 2024 08:43
@github-actions github-actions bot added the component: tests Issues re: Tests label May 20, 2024
@@ -364,9 +364,13 @@ def example_tensor(
)

if isinstance(self.shape, dict):
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If dtype is torch.bool in input_spec, torch.rand() returns random numbers on [0,1) and casting it to bool value is True

@keehyuna keehyuna requested a review from peri044 May 21, 2024 00:53
@keehyuna keehyuna self-assigned this May 21, 2024
@keehyuna keehyuna marked this pull request as ready for review May 24, 2024 07:14
@chohk88 chohk88 self-requested a review May 29, 2024 05:33
Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please mark torch.where converter in aten_ops_converters.py with supports_dynamic_shapes=True flag. Example : https://github.com/pytorch/TensorRT/blob/release/2.3/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py#L62

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py Outdated Show resolved Hide resolved
tests/py/dynamo/conversion/test_where_aten.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@chohk88 chohk88 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have left some comments. It looks like there aren't any major issues with the functionality.


def get_axes_for_reduce_op(
dim: Union[int, Sequence[int]],
has_implicit_batch_dimension: bool = False,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default parameter was used to merge below function
get_axes_for_reduce_op = functools.partial(
get_axes_for_reduce_op, has_implicit_batch_dimension=False
)

@chohk88
Copy link
Collaborator

chohk88 commented Jun 3, 2024

Looks good to me!

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peri044 I think it's reasonable to deprecate FX converter_utils in Dynamo converter implementation. In my opinion, when copying these helper functions from FX utils, it would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext for two reasons: 1) consistent with other helper functions 2) convenient to call them, just passing in ctx.


def get_axes_for_reduce_op(
dim: Union[int, Sequence[int]],
has_implicit_batch_dimension: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peri044 do we still have has_implicit_batch_dimension? Is it possible to remove the arg in the dynamo's converter_utils?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We should remove has_implicit_batch_dimension.
  2. It would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext.

Yes we should make these changes and not use FX code/ data structures as much as possible. If we use them, we should be consistent with dynamo APIs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks. Removed has_implicit_batch_dimension flag in get_axes_for_reduce_op()
first arg of boardcast()/prepend_ones() is ctx: ConversionContext, instead of network: TRTNetwork

py/torch_tensorrt/dynamo/conversion/converter_utils.py Outdated Show resolved Hide resolved
@keehyuna keehyuna force-pushed the where_dynamic_shape branch 2 times, most recently from f980697 to a12754b Compare June 6, 2024 01:25
@keehyuna
Copy link
Collaborator Author

Moved "chore: Better random bool values for example_tensor" change into PR2878 to combine random input changes to one PR.
436e99b

Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. added a few more minor changes.

py/torch_tensorrt/dynamo/types.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@peri044 peri044 merged commit ac702b7 into pytorch:main Jun 12, 2024
46 of 51 checks passed
@keehyuna keehyuna deleted the where_dynamic_shape branch August 19, 2024 03:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants