Skip to content

Commit

Permalink
chore: Cast condition if type is not bool
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Jun 7, 2024
1 parent 6addbc3 commit 225d069
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcastable,
cast_trt_tensor,
get_trt_tensor,
prepend_ones,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
from torch_tensorrt.fx.types import TRTTensor


Expand All @@ -32,8 +35,12 @@ def where(
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))

if not isinstance(condition, TRTTensor):
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
condition = get_trt_tensor(ctx, condition, f"{name}_condition")

if condition.dtype != trt.bool:
condition = cast_trt_tensor(ctx, condition, trt.float32, f"{name}_cast")
condition = ne(ctx, target, source_ir, f"{name}_cond_zero", condition, 0)

diff = max_shape_len - len(condition_shape)
if diff > 0:
condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)
Expand Down

0 comments on commit 225d069

Please sign in to comment.