fix: Handle dynamic shapes in where ops #2853
Conversation
Force-pushed f72e842 to 00d1d85 (compare)
py/torch_tensorrt/_Input.py (outdated)
@@ -364,9 +364,13 @@ def example_tensor(
            )

        if isinstance(self.shape, dict):
            return torch.rand(self.shape[optimization_profile_field]).to(
                dtype=self.dtype.to(torch.dtype, use_default=True)
If dtype is torch.bool in the input_spec, torch.rand() returns random numbers in [0, 1), and casting them to bool always yields True.
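A minimal sketch of how the bool case could be handled (the function name and signature are illustrative, not the PR's actual code):

```python
# Sketch only: torch.rand() samples from [0, 1), so casting to torch.bool is
# almost always True; sampling integers in {0, 1} gives both True and False.
import torch

def example_tensor(shape, dtype):
    if dtype == torch.bool:
        return torch.randint(0, 2, shape, dtype=torch.bool)
    return torch.rand(shape).to(dtype=dtype)

print(example_tensor((2, 3), torch.bool))
print(example_tensor((2, 3), torch.float16))
```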
Force-pushed fc6499e to 5e6f3bd (compare)
Also, please mark the torch.where converter in aten_ops_converters.py with the supports_dynamic_shapes=True flag. Example: https://github.com/pytorch/TensorRT/blob/release/2.3/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py#L62
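For reference, a hedged sketch of what that marking could look like (the import path and decorator arguments follow the linked release/2.3 example; the converter body is elided, so this is not the PR's actual change):

```python
# Sketch only: assumes a Torch-TensorRT version whose converter decorator
# accepts supports_dynamic_shapes, as in the linked aten_ops_converters.py.
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter


@dynamo_tensorrt_converter(torch.ops.aten.where.self, supports_dynamic_shapes=True)
def aten_ops_where(ctx, target, args, kwargs, name):
    # The real implementation dispatches to the where/select impl; omitted here.
    ...
```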
I have left some comments. It looks like there aren't any major issues with the functionality.
Force-pushed 5e6f3bd to ac4bf90 (compare)
def get_axes_for_reduce_op(
    dim: Union[int, Sequence[int]],
    has_implicit_batch_dimension: bool = False,
The default parameter was used so that the functools.partial wrapper below could be folded into the function itself:

get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
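As context for dropping the flag, a simplified illustration of what get_axes_for_reduce_op computes (not the library's exact code): TensorRT reduce layers take a bitmask of axes, so each reduced dim index sets one bit of the result.

```python
# Simplified illustration, not the library's exact implementation.
from typing import Sequence, Union


def get_axes_for_reduce_op(dim: Union[int, Sequence[int]]) -> int:
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    axes = 0
    for d in dims:
        axes |= 1 << d  # each reduced dimension becomes one bit in the mask
    return axes


assert get_axes_for_reduce_op(1) == 0b010
assert get_axes_for_reduce_op([0, 2]) == 0b101
```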
Looks good to me!
@peri044 I think it's reasonable to deprecate the FX converter_utils in the Dynamo converter implementation. In my opinion, when copying these helper functions from the FX utils, it would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext, for two reasons: 1) it is consistent with the other helper functions, and 2) it is convenient to call them by just passing in ctx.
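A hedged sketch of the suggested signature change (the import path and the ctx.net attribute name are assumptions for illustration, and the helper body is elided):

```python
# Sketch only: the helper takes the conversion context instead of the raw
# TensorRT network, and reads the network from it when needed.
from torch_tensorrt.dynamo.conversion import ConversionContext


def broadcast(ctx: ConversionContext, a, b, a_name: str, b_name: str):
    network = ctx.net  # assumed attribute; previously passed as network: TRTNetwork
    # ... broadcasting logic operates on `network` exactly as before ...
    ...
```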
def get_axes_for_reduce_op(
    dim: Union[int, Sequence[int]],
    has_implicit_batch_dimension: bool = False,
@peri044 do we still have has_implicit_batch_dimension? Is it possible to remove the arg in the Dynamo converter_utils?
- We should remove has_implicit_batch_dimension.
- It would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext.

Yes, we should make these changes and avoid using FX code and data structures as much as possible. If we do use them, we should stay consistent with the Dynamo APIs.
Thanks. Removed the has_implicit_batch_dimension flag from get_axes_for_reduce_op(); the first arg of broadcast()/prepend_ones() is now ctx: ConversionContext instead of network: TRTNetwork.
Force-pushed f980697 to a12754b (compare)
Force-pushed a12754b to 070a03a (compare)
This reverts commit b79bae8.
Thanks. Added a few more minor changes.
LGTM
Description
Dynamic shapes cannot be used in where ops because of an exception raised in torch.broadcast_shapes().
The proposed fix removes the expand() call for static-shape inputs and only prepends ones so that static- and dynamic-shape inputs have the same rank. I think an explicit broadcast is not required, since broadcasting is applied by the ISelectLayer (addSelect) itself. A minimal sketch of the rank-alignment idea follows below.
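The sketch below is illustrative only (the helper name and shapes are not the PR's exact code):

```python
# Illustrative only: instead of expand()-ing static inputs to a broadcast shape
# (which fails when other inputs have dynamic dims), each input is reshaped so
# all operands share the same rank; the TensorRT select layer then broadcasts.
def prepend_ones(shape: tuple, target_rank: int) -> tuple:
    return (1,) * (target_rank - len(shape)) + tuple(shape)

condition, x, y = (3,), (2, 3), (1, 2, 3)
rank = max(len(condition), len(x), len(y))
print(prepend_ones(condition, rank))  # (1, 1, 3)
print(prepend_ones(x, rank))          # (1, 2, 3)
print(prepend_ones(y, rank))          # (1, 2, 3)
```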
Fixes # (issue)