Reorg for converters in (FX Converter Refactor [1/N]) #1867

Merged: 1 commit, merged on May 23, 2023
py/torch_tensorrt/fx/converters/acc_ops_converters.py (76 changes: 59 additions & 17 deletions)
@@ -26,6 +26,7 @@
trt_transposed_matmul,
)
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
from torch_tensorrt.fx.converters.impl import activation

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -1004,9 +1005,14 @@ def acc_ops_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.RELU
return add_activation_layer(network, input_val, operation_type, target, name)

return activation.relu(
network,
target,
SourceIR.ACC,
name,
kwargs["input"],
)


@tensorrt_converter(acc_ops.leaky_relu)
@@ -1020,8 +1026,14 @@ def acc_ops_leaky_relu(
input_val = kwargs["input"]
negative_slope = kwargs["negative_slope"]
operation_type = trt.ActivationType.LEAKY_RELU
return add_activation_layer(
network, input_val, operation_type, target, name, negative_slope
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
alpha=negative_slope,
)


@@ -1036,7 +1048,9 @@ def acc_ops_elu(
input_val = kwargs["input"]
alpha = kwargs["alpha"]
operation_type = trt.ActivationType.ELU
return add_activation_layer(network, input_val, operation_type, target, name, alpha)
return activation.convert_activation(
network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
)


@tensorrt_converter(acc_ops.selu)
@@ -1049,7 +1063,14 @@ def acc_ops_selu(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.SELU
return add_activation_layer(network, input_val, operation_type, target, name)
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
)


@tensorrt_converter(acc_ops.softsign)
@@ -1062,7 +1083,14 @@ def acc_ops_softsign(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.SOFTSIGN
return add_activation_layer(network, input_val, operation_type, target, name)
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
)


@tensorrt_converter(acc_ops.sin)
@@ -1140,7 +1168,14 @@ def acc_ops_tanh(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.TANH
return add_activation_layer(network, input_val, operation_type, target, name)
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
)


@tensorrt_converter(acc_ops.asin)
@@ -3137,12 +3172,13 @@ def acc_ops_hard_sigmoid(
"of the TensorRT region!"
)

return add_activation_layer(
return activation.convert_activation(
network,
input_val,
trt.ActivationType.HARD_SIGMOID,
target,
SourceIR.ACC,
name,
trt.ActivationType.HARD_SIGMOID,
input_val,
alpha=1 / 6,
beta=0.5,
)
@@ -3164,8 +3200,13 @@ def acc_ops_sigmoid(
"of the TensorRT region!"
)

return add_activation_layer(
network, input_val, trt.ActivationType.SIGMOID, target, name
return activation.convert_activation(
network,
target,
SourceIR.ACC,
name,
trt.ActivationType.SIGMOID,
input_val,
)


@@ -3557,12 +3598,13 @@ def acc_ops_hardtanh(
"of the TensorRT region!"
)

return add_activation_layer(
return activation.convert_activation(
network,
input_val,
trt.ActivationType.CLIP,
target,
SourceIR.ACC,
name,
trt.ActivationType.CLIP,
input_val,
alpha=kwargs["min_val"],
beta=kwargs["max_val"],
)
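All of the rewritten call sites above go through the new torch_tensorrt.fx.converters.impl.activation module, which is not included in this diff. As context, here is a minimal sketch of what that module presumably provides, reconstructed only from the call sites above and from the removed add_activation_layer helper further down; the names, signatures, and import paths are assumptions, not the actual implementation.

# Sketch of py/torch_tensorrt/fx/converters/impl/activation.py (assumed).
from typing import Any, Optional

import tensorrt as trt
from torch.fx.node import Target

from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def convert_activation(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    operation_type: trt.ActivationType,
    input_val: TRTTensor,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
) -> TRTTensor:
    # Same behavior as the removed add_activation_layer, but tagged with the
    # producing IR so layer names reflect the frontend.
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha
    if beta is not None:
        layer.beta = beta
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def relu(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    # Thin wrapper so per-IR converters only need to pass the input tensor.
    return convert_activation(
        network, target, source_ir, name, trt.ActivationType.RELU, input_val
    )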
py/torch_tensorrt/fx/converters/activation.py (39 changes: 0 additions & 39 deletions)
@@ -9,45 +9,6 @@
from .converter_utils import mark_as_int8_layer


def common_activation(
network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
):
layer = network.add_activation(input=input_val, type=activation_type)
layer.name = layer_name

if input_val.dynamic_range:
dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
mark_as_int8_layer(layer, dyn_range)

return layer.get_output(0)


@tensorrt_converter(torch.nn.functional.relu)
@tensorrt_converter(torch.nn.modules.activation.ReLU)
def relu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0
input_val = kwargs["input"]

if not isinstance(input_val, trt.tensorrt.ITensor):
raise RuntimeError(
f"ReLU received input {input_val} that is not part "
"of the TensorRT region!"
)

def activation_dyn_range_fn(dyn_range):
return max(0, dyn_range[0]), max(0, dyn_range[1])

return common_activation(
network,
submod,
input_val,
trt.ActivationType.RELU,
activation_dyn_range_fn,
layer_name,
)


@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
def sigmoid(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
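The nn-module-path relu converter and its common_activation helper are deleted above. For illustration only (not part of this PR), an equivalent converter expressed through the shared helper could look like the sketch below; the function name is hypothetical, it is not registered anywhere, and the int8 dynamic-range handling that common_activation performed is not reproduced here.

# Hypothetical nn-module relu converter routed through the shared helper.
from torch_tensorrt.fx.converters.converter_utils import SourceIR
from torch_tensorrt.fx.converters.impl import activation


def nn_relu(network, submod, args, kwargs, layer_name):
    # args/kwargs are assumed to already be normalized to kwargs, as before.
    return activation.relu(
        network,
        "relu",  # set_layer_name also accepts plain-string targets
        SourceIR.NN,
        layer_name,
        kwargs["input"],
    )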
py/torch_tensorrt/fx/converters/aten_ops_converters.py (13 changes: 9 additions & 4 deletions)
@@ -22,6 +22,7 @@

from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -290,10 +291,14 @@ def aten_ops_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
}
return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)

return activation.relu(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@tensorrt_converter(torch.ops.aten.sub.Tensor)
py/torch_tensorrt/fx/converters/converter_utils.py (77 changes: 33 additions & 44 deletions)
@@ -2,6 +2,7 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from enum import Enum, auto
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
from ..utils import torch_dtype_from_trt


class SourceIR(Enum):
NN = auto()
ACC = auto()
ATEN = auto()
PRIM = auto()
UNKNOWN = auto()

def __str__(self):
if self == SourceIR.NN:
return "nn"
elif self == SourceIR.ACC:
return "acc"
elif self == SourceIR.ATEN:
return "aten"
elif self == SourceIR.PRIM:
return "prim"
else:
return "unknown_ir"


def get_trt_plugin(
plugin_name: str,
field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
return dim


def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
def set_layer_name(
layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
) -> None:
"""
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"

@@ -86,8 +109,16 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
target (Target): A fx node.target. For call_function node, it's the function that
the node represents.
name (str): Consists of fx node.name with optional suffix.
source_ir: (Optional[SourceIR]): The IR producing the op.
"""
target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"

source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN

target_name = (
f"{source_ir}_ops.{target}"
if isinstance(target, str)
else f"{source_ir}_ops.{target.__name__}"
)
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
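
With the optional source_ir argument, layer names now encode which frontend produced the op instead of always using the hard-coded "acc_ops." prefix. A small illustration of the new behavior, with example values (the node name and layer type below are made up for demonstration):

# The enum's __str__ drives the prefix used in layer names.
from torch_tensorrt.fx.converters.converter_utils import SourceIR

assert str(SourceIR.ACC) == "acc"
assert str(SourceIR.ATEN) == "aten"
assert str(SourceIR.UNKNOWN) == "unknown_ir"

# For an fx node named "relu_1" converted from the aten IR, an activation
# layer would be named roughly "[ACTIVATION]-[aten_ops.relu]-[relu_1]";
# when source_ir is omitted, the prefix falls back to "unknown_ir_ops.".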


@@ -560,48 +591,6 @@ def add_unary_layer(
return layer.get_output(0)


def add_activation_layer(
network: TRTNetwork,
input_val: TRTTensor,
operation_type: trt.ActivationType,
target: Target,
name: str,
alpha: Optional[Any] = None,
beta: Optional[Any] = None,
) -> TRTTensor:
"""
Add a TensorRT Activation layer to `network`.

Args:
network (TRTNetwork): TensorRT network object.
input_val (TRTTensor): Input to the activation op.
Must be a TensorRT tensor.
op_type (trt.ElementWiseOperation): Type of the TensorRT activation
operation.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
alpha (Optional[Any]): If not None, we will use it to set the alpha
attribute of the created TensorRT activation layer.
beta (Optional[Any]): If not None, we will use it to set the beta
attribute of the created TensorRT activation layer.

Returns:
The output of TensorRT Activation layer.
"""
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"{operation_type} received input {input_val} that is not part "
"of the TensorRT region!"
)
layer = network.add_activation(input_val, operation_type)
if alpha is not None:
layer.alpha = alpha
if beta is not None:
layer.beta = beta
set_layer_name(layer, target, name)
return layer.get_output(0)


def add_reduce_layer(
network: TRTNetwork,
target: Target,
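For downstream code that still calls the removed helper, the mapping to its replacement is summarized below (a sketch based on the call-site changes in this PR; the new helper's parameter names are assumptions):

# Old, removed from converter_utils.py:
#   add_activation_layer(network, input_val, operation_type, target, name,
#                        alpha=None, beta=None)
#
# New, in torch_tensorrt.fx.converters.impl.activation:
#   convert_activation(network, target, source_ir, name, operation_type,
#                      input_val, alpha=None, beta=None)
#
# That is, target and name move ahead of the operation type and input
# tensor, and a SourceIR tag identifying the producing frontend is added.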
Empty file.