support group norm, and improve batch and layer norms
zewenli98 committed Sep 28, 2023
1 parent 9dc5e5d commit fd820e6
Showing 4 changed files with 208 additions and 30 deletions.
46 changes: 46 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -67,6 +67,52 @@ def aten_ops_layer_norm(
)


@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default) # type: ignore[misc]
def aten_ops_native_group_norm(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.native_group_norm(
network,
target,
SourceIR.ATEN,
name,
input=args[0],
weight=args[1],
bias=args[2],
N=args[3],
C=args[4],
HxW=args[5],
group=args[6],
eps=args[7],
)


@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc]
def aten_ops_group_norm(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.group_norm(
network,
target,
SourceIR.ATEN,
name,
input=args[0],
num_groups=args[1],
weight=args_bounds_check(args, 2, None),
bias=args_bounds_check(args, 3, None),
eps=args_bounds_check(args, 4, 1e-05),
cudnn_enabled=args_bounds_check(args, 5, True),
)


def embedding_param_validator(embedding_node: Node) -> bool:
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)
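For context, the two converters above unpack their positional arguments in the order defined by the ATen ops they target. A minimal eager-mode sketch of those call signatures, inferred from the converter bodies rather than quoted from this commit, looks like this:

import torch

x = torch.randn(1, 6, 4, 4)
weight = torch.ones(6)
bias = torch.zeros(6)

# aten.group_norm: (input, num_groups, weight, bias, eps, cudnn_enabled)
y = torch.ops.aten.group_norm(x, 2, weight, bias, 1e-5, True)

# aten.native_group_norm: (input, weight, bias, N, C, HxW, group, eps),
# returning the (out, mean, rstd) triple the converter has to mimic
out, mean, rstd = torch.ops.aten.native_group_norm(x, weight, bias, 1, 6, 16, 2, 1e-5)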
136 changes: 107 additions & 29 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -6,13 +6,12 @@
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
get_trt_plugin,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
@@ -188,79 +187,77 @@ def layer_norm_no_plugin(

shape = weight.shape
broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
gamma = to_numpy(weight.reshape(*shape))
beta = to_numpy(bias.reshape(*shape))
gamma = to_numpy(weight).reshape(shape)
beta = to_numpy(bias).reshape(shape)

axes = 0
for d in range(len(shape)):
axes |= 1 << (len(input.shape) - d - 1)
dims = list(range(len(input.shape) - len(shape), len(input.shape)))
axes = get_axes_for_reduce_op(dims)

# E[x]
mean_expected_layer = network.add_reduce(
input, trt.ReduceOperation.AVG, axes, keep_dims=True
)
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir)

# X-E[x]
sub_trt = convert_binary_elementwise(
# X - E[x]
sub_trt = impl.elementwise.sub(
network,
target,
source_ir,
f"{name}_sub",
trt.ElementWiseOperation.SUB,
input,
mean_expected_layer.get_output(0),
)
# Variance = mean(pow(x_sub_mean,2))

# variance = mean(pow(x_sub_mean, 2))
pow_tensor = network.add_constant(
(1,) * len(input.shape),
trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
)
pow_tensor.name = f"{name}_power"
pow_var = convert_binary_elementwise(
pow_var = impl.elementwise.pow(
network,
target,
source_ir,
f"{name}_pow_var",
trt.ElementWiseOperation.POW,
sub_trt,
pow_tensor.get_output(0),
)
mean_trt_layer = network.add_reduce(
pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
)
set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir)
# Variance + eps

# var + eps
eps_tensor = network.add_constant(
(1,) * len(input.shape),
trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
)
eps_tensor.name = f"{name}_eps"
add_trt = convert_binary_elementwise(

# sqrt((var + eps))
add_trt = impl.elementwise.add(
network,
target,
source_ir,
f"{name}_add",
trt.ElementWiseOperation.SUM,
mean_trt_layer.get_output(0),
eps_tensor.get_output(0),
)
# SQRT((Var + eps))
sqrt_trt = convert_unary(
sqrt_trt = impl.unary.sqrt(
network,
target,
source_ir,
f"{name}_sqrt",
trt.UnaryOperation.SQRT,
add_trt,
)
# (x - E[x]) / sqrt((var + eps))
div_trt = convert_binary_elementwise(

# (X - E[X]) / sqrt((var + eps))
div_trt = impl.elementwise.div(
network,
target,
source_ir,
f"{name}_div_trt",
trt.ElementWiseOperation.DIV,
sub_trt,
sqrt_trt,
)
@@ -270,32 +267,113 @@ def layer_norm_no_plugin(
gamma.shape, trt.Weights(np.ascontiguousarray(gamma))
)
gamma_tensor.name = f"{name}_gamma"

assert beta is not None
beta_tensor = network.add_constant(
gamma.shape, trt.Weights(np.ascontiguousarray(beta))
)
beta_tensor.name = f"{name}_beta"

# y * gamma + beta
scale_layer = convert_binary_elementwise(
scaled_y = impl.elementwise.mul(
network,
target,
source_ir,
f"{name}_scale",
trt.ElementWiseOperation.PROD,
div_trt,
gamma_tensor.get_output(0),
)
return convert_binary_elementwise(
return impl.elementwise.add(
network,
target,
source_ir,
name,
trt.ElementWiseOperation.SUM,
scale_layer,
scaled_y,
beta_tensor.get_output(0),
)


def native_group_norm(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
N: int,
C: int,
HxW: int,
group: int,
eps: float,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return group_norm(
network,
target,
source_ir,
name,
input,
group,
weight,
bias,
eps,
cudnn_enabled=True,
)


def group_norm(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
num_groups: int,
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
eps: float,
cudnn_enabled: bool,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if not isinstance(input, trt.tensorrt.ITensor):
raise RuntimeError(
f"LayerNorm received input {input} that is not part "
"of the TensorRT region!"
)

if weight is None:
weight = to_numpy(1.0)

if bias is None:
bias = to_numpy(0.0)

scale = get_trt_tensor(network, weight, "scale")
bias = get_trt_tensor(network, bias, "bias")

eps_field = trt.PluginField(
"eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_field = trt.PluginField(
"num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([eps_field, num_groups_field])

try:
# Here's the schema of the plugin:
# https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
_LOGGER.error(
"Unable to find group norm plugin, fall back to TensorRT implementation."
)

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor


def softmax(
network: TRTNetwork,
target: Target,
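The rewritten layer_norm_no_plugin path above expresses the usual layer-norm formula, y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with reduce, elementwise, and unary TensorRT layers. A small eager-mode sanity check of that formula (illustrative only, not part of this commit) is:

import torch

x = torch.randn(2, 3, 4)
normalized_shape = (3, 4)
gamma = torch.randn(normalized_shape)
beta = torch.randn(normalized_shape)
eps = 1e-5

# Reduce over the trailing dims covered by normalized_shape,
# mirroring what get_axes_for_reduce_op computes for the TRT reduce layers.
dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
mean = x.mean(dim=dims, keepdim=True)
var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps) * gamma + beta

reference = torch.nn.functional.layer_norm(x, normalized_shape, gamma, beta, eps)
assert torch.allclose(manual, reference, atol=1e-6)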
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/op_evaluators.py
@@ -2,6 +2,7 @@
import operator
from typing import Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

@@ -18,7 +19,8 @@ def getitem_validator(getitem_node: Node) -> bool:


# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.detach.default) # type: ignore[misc]
def generic_evaluator(
network: TRTNetwork,
target: Target,
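Registering torch.ops.aten.detach.default on the generic evaluator works because detach is a value-level no-op at inference time, so the op can simply be evaluated instead of being converted to a TensorRT layer. A minimal illustration of that equivalence (not taken from this commit):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.ops.aten.detach.default(x)

# Same values, no autograd edge
assert torch.equal(x.detach(), y)
assert y.requires_grad is False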
52 changes: 52 additions & 0 deletions tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -0,0 +1,52 @@
import torch
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestGroupNormConverter(DispatchTestCase):
def test_groupnorm(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.gn = torch.nn.GroupNorm(2, 6)

def forward(self, x):
return self.gn(x)

inputs = [torch.randn(1, 6, 224, 224)]
self.run_test(
TestModule(),
inputs,
expected_ops={torch.ops.aten.native_group_norm.default},
disable_passes=True,
)

def test_groupnorm_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.gn = torch.nn.GroupNorm(2, 6)

def forward(self, x):
return self.gn(x)

input_specs = [
Input(
shape=(-1, 6, 5),
dtype=torch.float32,
shape_ranges=[((2, 6, 5), (6, 6, 5), (10, 6, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
expected_ops={torch.ops.aten.native_group_norm.default},
disable_passes=True,
)


if __name__ == "__main__":
run_tests()
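
Beyond these unit tests, a minimal end-to-end sketch of exercising the new group-norm support through the dynamo frontend might look like the following; the torch_tensorrt.compile(..., ir="dynamo") invocation is assumed to be available in this build and is not part of the commit itself:

import torch
import torch_tensorrt

class GNModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(2, 6)

    def forward(self, x):
        return self.gn(x)

model = GNModel().eval().cuda()
inputs = [torch.randn(1, 6, 224, 224).cuda()]

# The converter added in this commit is picked up when
# aten.native_group_norm appears in the lowered graph.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)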
