Skip to content

Commit

Permalink
chore: Remove has_implicit_batch_dimension flag in get_axes_for_reduce_op()
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Jun 7, 2024
1 parent 225d069 commit 070a03a
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ def broadcast(

def get_axes_for_reduce_op(
dim: Union[int, Sequence[int]],
has_implicit_batch_dimension: bool = False,
) -> int:
"""
TensorRT reduce layer relies on the binary representation of axes to
Expand All @@ -736,8 +735,6 @@ def get_axes_for_reduce_op(
Args:
dim (Union[int, Sequence[int]]): An integer or a sequence of integers
that will be used to generate axes for TensorRT.
has_implicit_batch_dimension (bool): Whether the TensorRT network is
using implicit batch dimension.
Returns:
An integer which binary form can be used as axes for TensorRT reduce
Expand All @@ -746,12 +743,9 @@ def get_axes_for_reduce_op(
if isinstance(dim, int):
dim = (dim,)

if has_implicit_batch_dimension:
assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."

axes = 0
for d in dim:
axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0))
axes |= 1 << d

return axes

Expand Down

0 comments on commit 070a03a

Please sign in to comment.