refactor: Reorganizing to reduce code duplication and separating the TRT implementation; example changes with ReLU

Signed-off-by: Naren Dasan <naren@narendasan.com>
1 parent c5cc6e3 · commit dfd98a5
Showing 7 changed files with 147 additions and 87 deletions.
@@ -0,0 +1,87 @@
from typing import Any, Callable, Dict, Optional, Tuple

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from torch.fx.node import Target

from torch_tensorrt.fx.converters.converter_utils import (
    SourceIR,
    mark_as_int8_layer,
    set_layer_name,
)
from torch_tensorrt.fx.types import (
    TRTNetwork,
    TRTTensor,
)


def convert_activation(
    network: TRTNetwork,
    input_val: TRTTensor,
    operation_type: trt.ActivationType,
    target: Target,
    name: str,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
    dyn_range_fn: Optional[Callable[[Tuple[float, float]], Tuple[float, float]]] = None,
    source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
    """
    Add a TensorRT Activation layer to `network`.

    Args:
        network (TRTNetwork): TensorRT network object.
        input_val (TRTTensor): Input to the activation op.
            Must be a TensorRT tensor.
        operation_type (trt.ActivationType): Type of the TensorRT activation
            operation.
        target (Target): Target of the fx node.
        name (str): The name we want to assign to the created TensorRT layer.
        alpha (Optional[Any]): If not None, we will use it to set the alpha
            attribute of the created TensorRT activation layer.
        beta (Optional[Any]): If not None, we will use it to set the beta
            attribute of the created TensorRT activation layer.
        dyn_range_fn (Optional[Callable]): A function which takes the dynamic
            range (min, max) of the input tensor and returns the output
            dynamic range. Required when the input carries a dynamic range.
        source_ir (SourceIR): The IR this op was lowered from.

    Returns:
        The output of the TensorRT Activation layer.
    """
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha
    if beta is not None:
        layer.beta = beta
    set_layer_name(layer, target, name, source_ir)

    if input_val.dynamic_range is not None:
        assert dyn_range_fn is not None, (
            "dyn_range_fn is required to propagate the dynamic range of a "
            "quantized input"
        )
        dyn_range = dyn_range_fn(input_val.dynamic_range)
        mark_as_int8_layer(layer, dyn_range)
    return layer.get_output(0)


def convert_relu(
    network: TRTNetwork,
    target: Target,
    kwargs: Dict[str, Any],
    name: str,
    source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
    input_val = kwargs["input"]
    operation_type = trt.ActivationType.RELU

    def relu_dyn_range_fn(dyn_range):
        # ReLU clamps negatives to zero, so the output range is the input
        # range clamped at zero.
        return max(0, dyn_range[0]), max(0, dyn_range[1])

    # dyn_range_fn and source_ir must be passed by keyword; passed
    # positionally they would bind to the alpha/beta parameters of
    # convert_activation.
    return convert_activation(
        network,
        input_val,
        operation_type,
        target,
        name,
        dyn_range_fn=relu_dyn_range_fn,
        source_ir=source_ir,
    )
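The point of this split is that new activations can reuse convert_activation instead of duplicating the layer plumbing. As a hypothetical sketch (not part of this commit), a sigmoid impl could look like the following; convert_sigmoid and sigmoid_dyn_range_fn are assumed names, and numpy is needed in addition to the imports above:

import numpy as np  # assumed extra import for the dynamic-range math


def convert_sigmoid(
    network: TRTNetwork,
    target: Target,
    kwargs: Dict[str, Any],
    name: str,
    source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
    input_val = kwargs["input"]

    def sigmoid_dyn_range_fn(dyn_range):
        # Sigmoid is monotonic, so mapping the two endpoints maps the range.
        def sigmoid_fn(x):
            return 1.0 / (1.0 + np.exp(-x))

        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])

    return convert_activation(
        network,
        input_val,
        trt.ActivationType.SIGMOID,
        target,
        name,
        dyn_range_fn=sigmoid_dyn_range_fn,
        source_ir=source_ir,
    )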
@@ -0,0 +1,24 @@
import torch

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import SourceIR
from torch_tensorrt.fx.converters.impl import activation


@tensorrt_converter(torch.nn.functional.relu)
@tensorrt_converter(torch.nn.modules.activation.ReLU)
def relu(network, submod, args, kwargs, layer_name):
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0

    return activation.convert_relu(
        network=network,
        target="torch.nn.functional.relu",
        kwargs=kwargs,
        name=layer_name,
        source_ir=SourceIR.NN,
    )
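Exposing a new impl through the registry then takes only a thin wrapper like the one above. A hypothetical sketch of the same pattern for the sigmoid impl sketched earlier (activation.convert_sigmoid is an assumed helper, not part of this commit):

@tensorrt_converter(torch.nn.functional.sigmoid)
@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
def sigmoid(network, submod, args, kwargs, layer_name):
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0

    # convert_sigmoid is the hypothetical impl sketched above, not part of
    # this commit.
    return activation.convert_sigmoid(
        network=network,
        target="torch.nn.functional.sigmoid",
        kwargs=kwargs,
        name=layer_name,
        source_ir=SourceIR.NN,
    )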