Commit

refactor: Reorganizing to reduce code duplication and separating TRT implementation, example changes with ReLU

Signed-off-by: Naren Dasan <naren@narendasan.com>
narendasan committed Apr 28, 2023
1 parent c5cc6e3 commit dfd98a5
Showing 7 changed files with 147 additions and 87 deletions.
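
At a glance (paraphrasing the diff below): the acc, aten, and new nn ReLU converters all delegate to one shared implementation in the new impl package, and TensorRT layer names are now tagged with the IR that produced them:

    acc_ops_converters.acc_ops_relu    --\
    aten_ops_converters.aten_ops_relu  ---+--> impl.activation.convert_relu --> convert_activation
    nn_ops_converters.relu (new file)  --/

    converter_utils.py gains the SourceIR enum; its add_activation_layer is
    superseded by impl.activation.convert_activation.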
6 changes: 3 additions & 3 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -26,6 +26,7 @@
trt_transposed_matmul,
)
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
from torch_tensorrt.fx.converters.impl import activation

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -1004,9 +1005,8 @@ def acc_ops_relu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.RELU
return add_activation_layer(network, input_val, operation_type, target, name)

return activation.convert_relu(network, target, kwargs, name, SourceIR.ACC)


@tensorrt_converter(acc_ops.leaky_relu)
39 changes: 0 additions & 39 deletions py/torch_tensorrt/fx/converters/activation.py
@@ -9,45 +9,6 @@
from .converter_utils import mark_as_int8_layer


def common_activation(
network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
):
layer = network.add_activation(input=input_val, type=activation_type)
layer.name = layer_name

if input_val.dynamic_range:
dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
mark_as_int8_layer(layer, dyn_range)

return layer.get_output(0)


@tensorrt_converter(torch.nn.functional.relu)
@tensorrt_converter(torch.nn.modules.activation.ReLU)
def relu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0
input_val = kwargs["input"]

if not isinstance(input_val, trt.tensorrt.ITensor):
raise RuntimeError(
f"ReLU received input {input_val} that is not part "
"of the TensorRT region!"
)

def activation_dyn_range_fn(dyn_range):
return max(0, dyn_range[0]), max(0, dyn_range[1])

return common_activation(
network,
submod,
input_val,
trt.ActivationType.RELU,
activation_dyn_range_fn,
layer_name,
)


@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
def sigmoid(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
5 changes: 4 additions & 1 deletion py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -22,6 +22,7 @@

from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -293,7 +294,9 @@ def aten_ops_relu(
kwargs_new = {
"input": args[0],
}
return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
return activation.convert_relu(
network, target, kwargs_new, name, source_ir=SourceIR.ATEN
)


@tensorrt_converter(torch.ops.aten.sub.Tensor)
73 changes: 29 additions & 44 deletions py/torch_tensorrt/fx/converters/converter_utils.py
@@ -2,6 +2,7 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from enum import Enum, auto
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
from ..utils import torch_dtype_from_trt


class SourceIR(Enum):
NN = auto()
ACC = auto()
ATEN = auto()
PRIM = auto()
UNKNOWN = auto()

def __str__(self):
if self == SourceIR.NN:
return "nn"
elif self == SourceIR.ACC:
return "acc"
elif self == SourceIR.ATEN:
return "aten"
elif self == SourceIR.PRIM:
return "prim"
else:
return "unknown_ir"


def get_trt_plugin(
plugin_name: str,
field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
return dim


def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
def set_layer_name(
layer: TRTLayer, target: Target, name: str, source_ir: SourceIR = SourceIR.UNKNOWN
) -> None:
"""
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
@@ -87,7 +110,11 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
the node represents.
name (str): Consists of fx node.name with optional suffix.
"""
target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
target_name = (
f"{source_ir}_ops.{target}"
if isinstance(target, str)
else f"{source_ir}_ops.{target.__name__}"
)
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"


@@ -560,48 +587,6 @@ def add_unary_layer(
return layer.get_output(0)


def add_activation_layer(
network: TRTNetwork,
input_val: TRTTensor,
operation_type: trt.ActivationType,
target: Target,
name: str,
alpha: Optional[Any] = None,
beta: Optional[Any] = None,
) -> TRTTensor:
"""
Add a TensorRT Activation layer to `network`.
Args:
network (TRTNetwork): TensorRT network object.
input_val (TRTTensor): Input to the activation op.
Must be a TensorRT tensor.
op_type (trt.ElementWiseOperation): Type of the TensorRT activation
operation.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
alpha (Optional[Any]): If not None, we will use it to set the alpha
attribute of the created TensorRT activation layer.
beta (Optional[Any]): If not None, we will use it to set the beta
attribute of the created TensorRT activation layer.
Returns:
The output of TensorRT Activation layer.
"""
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"{operation_type} received input {input_val} that is not part "
"of the TensorRT region!"
)
layer = network.add_activation(input_val, operation_type)
if alpha is not None:
layer.alpha = alpha
if beta is not None:
layer.beta = beta
set_layer_name(layer, target, name)
return layer.get_output(0)


def add_reduce_layer(
network: TRTNetwork,
target: Target,
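
For illustration only (not part of the diff): with the SourceIR enum and the updated set_layer_name, layer names now carry the originating IR. The FX node name "relu_1" below is a made-up example.

    from torch_tensorrt.fx.converters.converter_utils import SourceIR

    str(SourceIR.ACC)      # "acc"
    str(SourceIR.UNKNOWN)  # "unknown_ir"

    # With source_ir=SourceIR.ACC and target=acc_ops.relu on an FX node named
    # "relu_1", set_layer_name labels the activation layer roughly as:
    #   "[ACTIVATION]-[acc_ops.relu]-[relu_1]"
    # Previously the "acc_ops." prefix was hard-coded, even for layers created
    # from the aten or nn paths.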
py/torch_tensorrt/fx/converters/impl/__init__.py — empty file (new).
87 changes: 87 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/activation.py
@@ -0,0 +1,87 @@
import numpy as np
import operator
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Argument, Target


from torch_tensorrt.fx.converters.converter_utils import mark_as_int8_layer
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.converters.converter_utils import SourceIR

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
)


def convert_activation(
network: TRTNetwork,
input_val: TRTTensor,
operation_type: trt.ActivationType,
target: Target,
name: str,
alpha: Optional[Any] = None,
beta: Optional[Any] = None,
dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
"""
Add a TensorRT Activation layer to `network`.
Args:
network (TRTNetwork): TensorRT network object.
input_val (TRTTensor): Input to the activation op.
Must be a TensorRT tensor.
operation_type (trt.ActivationType): Type of the TensorRT activation
operation.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
alpha (Optional[Any]): If not None, we will use it to set the alpha
attribute of the created TensorRT activation layer.
beta (Optional[Any]): If not None, we will use it to set the beta
attribute of the created TensorRT activation layer.
dyn_range_fn (Optional[Callable]): A function which takes the dynamic range
of the input TensorRT tensor and returns the output dynamic range.
source_ir (SourceIR): Source IR that produced the node, used to prefix
the layer name.
Returns:
The output of TensorRT Activation layer.
"""
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"{operation_type} received input {input_val} that is not part "
"of the TensorRT region!"
)
layer = network.add_activation(input_val, operation_type)
if alpha is not None:
layer.alpha = alpha
if beta is not None:
layer.beta = beta
set_layer_name(layer, target, name, source_ir)

if input_val.dynamic_range is not None:
dyn_range = dyn_range_fn(input_val.dynamic_range)
mark_as_int8_layer(layer, dyn_range)
return layer.get_output(0)


def convert_relu(
network: TRTNetwork,
target: Target,
kwargs: Dict[str, Any],
name: str,
source_ir: SourceIR = SourceIR.UNKNOWN,
):
input_val = kwargs["input"]
operation_type = trt.ActivationType.RELU

def relu_dyn_range_fn(dyn_range):
return max(0, dyn_range[0]), max(0, dyn_range[1])

return convert_activation(
    network,
    input_val,
    operation_type,
    target,
    name,
    dyn_range_fn=relu_dyn_range_fn,
    source_ir=source_ir,
)
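
To show the intended reuse, here is a hypothetical sketch (not part of this commit) of how another activation such as sigmoid could be added to impl/activation.py, relying only on the imports already at the top of the file. Only the dynamic-range rule is activation-specific; the TensorRT plumbing is shared through convert_activation:

    def convert_sigmoid(
        network: TRTNetwork,
        target: Target,
        kwargs: Dict[str, Any],
        name: str,
        source_ir: SourceIR = SourceIR.UNKNOWN,
    ):
        input_val = kwargs["input"]
        operation_type = trt.ActivationType.SIGMOID

        def sigmoid_dyn_range_fn(dyn_range):
            # Sigmoid is monotonic, so the output range is sigmoid of each bound.
            return tuple(1.0 / (1.0 + np.exp(-x)) for x in dyn_range)

        return convert_activation(
            network,
            input_val,
            operation_type,
            target,
            name,
            dyn_range_fn=sigmoid_dyn_range_fn,
            source_ir=source_ir,
        )

The existing acc_ops_sigmoid converter could then shrink to a one-line call, mirroring the ReLU changes in this commit.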
24 changes: 24 additions & 0 deletions py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -0,0 +1,24 @@
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.fx.converters.converter_utils import SourceIR


@tensorrt_converter(torch.nn.functional.relu)
@tensorrt_converter(torch.nn.modules.activation.ReLU)
def relu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0

return activation.convert_relu(
network=network,
target="torch.nn.functional.relu",
kwargs=kwargs,
name=layer_name,
source_ir=SourceIR.NN,
)
