fix: convert_module_to_trt_engine (#2728)
zewenli98 authored and peri044 committed Apr 26, 2024
1 parent bec91fb commit 08f1636
Showing 4 changed files with 41 additions and 37 deletions.
2 changes: 2 additions & 0 deletions docsrc/py_api/dynamo.rst
@@ -22,6 +22,8 @@ Functions
 
 .. autofunction:: export
 
+.. autofunction:: convert_module_to_trt_engine
+
 
 
 Classes
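
For context, a minimal sketch of the newly documented entry point. The Add module, shapes, and variable names are illustrative, and a CUDA-enabled torch_tensorrt build is assumed:

    import torch
    import torch_tensorrt

    class Add(torch.nn.Module):  # toy module for illustration
        def forward(self, a, b):
            return torch.add(a, b)

    inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())

    # As of this commit, the function takes an ExportedProgram directly
    exp_program = torch.export.export(Add().eval().cuda(), inputs)
    engine_bytes = torch_tensorrt.dynamo.convert_module_to_trt_engine(
        exp_program, inputs=inputs
    )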
16 changes: 12 additions & 4 deletions py/torch_tensorrt/_compile.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import collections.abc
 import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set
@@ -237,8 +238,6 @@ def compile(
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
         # Prepare torch and torchtrt inputs
-        import collections.abc
-
         from torch_tensorrt.dynamo.utils import prepare_inputs
 
         if not isinstance(input_list, collections.abc.Sequence):
@@ -342,10 +341,19 @@ def convert_method_to_trt_engine(
             "convert_method_to_trt_engine call is not supported for ir=fx"
         )
     elif target_ir == _IRType.dynamo:
+        # Prepare torch and torchtrt inputs
+        from torch_tensorrt.dynamo.utils import prepare_inputs
+
+        if not isinstance(inputs, collections.abc.Sequence):
+            inputs = [inputs]
+
+        # Export the module
+        torchtrt_inputs = prepare_inputs(inputs)
+        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+
         return dynamo_convert_module_to_trt_engine(  # type: ignore[no-any-return]
-            module,
+            exp_program,
             inputs=inputs,
-            method_name=method_name,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
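
The high-level frontend keeps its signature; with ir="dynamo" it now traces the module through torch_tensorrt.dynamo.trace before converting. A hedged sketch, reusing the toy Add module above (kwargs beyond those visible in this diff are assumptions):

    import torch
    import torch_tensorrt

    model = Add().eval().cuda()
    # convert_method_to_trt_engine prepares inputs, exports the module,
    # then delegates to dynamo_convert_module_to_trt_engine
    engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
        model,
        inputs=[torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()],
        ir="dynamo",
        enabled_precisions={torch.float32},
    )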
50 changes: 21 additions & 29 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -422,8 +422,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
 
 def convert_module_to_trt_engine(
-    module: torch.fx.GraphModule,
-    method_name: str = "forward",
+    exported_program: ExportedProgram,
     inputs: Optional[Sequence[Input | torch.Tensor]] = None,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -453,15 +452,15 @@ def convert_module_to_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
 ) -> bytes:
-    """Convert a GraphModule module method to a serialized TensorRT engine
+    """Convert an ExportedProgram to a serialized TensorRT engine
 
-    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+    Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings
 
     Arguments:
-        module (torch.fx.GraphModule): Source module
+        exported_program (torch.export.ExportedProgram): Source module
 
     Keyword Args:
-        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+        inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
             torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
             to select device type. ::
@@ -476,30 +475,11 @@ def convert_module_to_trt_engine(
                 ), # Dynamic input shape for input #2
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
 
-        method_name (str): Name of method to convert
-        input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
-            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
-
-                input_signature=([
-                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
-                    torch_tensorrt.Input(
-                        min_shape=(1, 224, 224, 3),
-                        opt_shape=(1, 512, 512, 3),
-                        max_shape=(1, 1024, 1024, 3),
-                        dtype=torch.int32
-                        format=torch.channel_last
-                    ), # Dynamic input shape for input #2
-                ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
-
         device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
 
             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
 
         enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
         debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
-        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
         max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
         version_compatible (bool): Provide version forward-compatibility for engine plan files
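
Per the annotation fix above, torch_executed_ops is a set of op names. A small sketch, reusing exp_program and inputs from the earlier example (the op-string format is an assumption, not confirmed by this diff):

    # Force an op to stay in PyTorch rather than be converted to TensorRT
    engine_bytes = torch_tensorrt.dynamo.convert_module_to_trt_engine(
        exp_program,
        inputs=inputs,
        torch_executed_ops={"torch.ops.aten.add.Tensor"},  # assumed format
    )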
@@ -566,13 +546,25 @@ def convert_module_to_trt_engine(
         "dla_global_dram_size": dla_global_dram_size,
     }
 
+    # Decompose the exported program
+    exported_program = exported_program.run_decompositions(
+        get_decompositions(enable_experimental_decompositions)
+    )
+    gm = exported_program.module()
+    logger.debug("Input graph: " + str(gm.graph))
+
+    # Apply lowering on the graph module
+    torch_inputs = get_torch_inputs(input_list, device)
+    gm = apply_lowering_passes(gm, torch_inputs)
+    logger.debug("Lowered Input graph: " + str(gm.graph))
+
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     try:
-        interpreter_result = interpret_module_to_result(module, input_list, settings)
+        interpreter_result = interpret_module_to_result(gm, input_list, settings)
     except UnsupportedOperatorException:
         logger.error(
-            f"Conversion of module {module} not currently fully supported or convertible!",
+            f"Conversion of module {gm} not currently fully supported or convertible!",
             exc_info=True,
         )
     except Exception as e:
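
The new body mirrors the front half of compile(): decompose the ExportedProgram, recover a GraphModule, then run lowering passes before interpretation. A standalone sketch of the same decompose-then-lower pattern using public torch APIs (the decomposition table choice is illustrative; torch_tensorrt uses its own get_decompositions):

    import torch
    from torch._decomp import core_aten_decompositions

    ep = torch.export.export(model, inputs)
    # Rewrite composite ATen ops using a decomposition table
    ep = ep.run_decompositions(core_aten_decompositions())
    gm = ep.module()  # GraphModule ready for lowering passes
    print(gm.graph)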
10 changes: 6 additions & 4 deletions
@@ -7,7 +7,7 @@
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 
-class TestConvertMethodToTrtEngine(unittest.TestCase):
+class TestConvertModuleToTrtEngine(unittest.TestCase):
     def test_convert_module(self):
         class Test(torch.nn.Module):
             def forward(self, a, b):
@@ -18,19 +18,21 @@ def forward(self, a, b):
 
         # Create a model
         model = Test()
-        symbolic_traced_gm = torch.fx.symbolic_trace(model)
+        exp_program = torch.export.export(model, (input_data_0, input_data_1))
 
         # Convert to TensorRT engine
         trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
-            symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
+            exp_program, inputs=(input_data_0, input_data_1)
         )
 
         # Deserialize the TensorRT engine
         with trt.Logger() as logger, trt.Runtime(logger) as runtime:
             engine = runtime.deserialize_cuda_engine(trt_engine_str)
 
         # Inference on TRT Engine
-        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
+        py_trt_module = PythonTorchTensorRTModule(
+            engine, ["arg0_1", "arg1_1"], ["output0"]
+        )
         trt_output = py_trt_module(input_data_0, input_data_1).cpu()
 
         # Inference on PyTorch model
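
Since the converter returns plain bytes, the engine from the test can also be persisted and reloaded with the standalone TensorRT runtime (the file path is hypothetical):

    import tensorrt as trt

    # Save the serialized engine produced by convert_module_to_trt_engine
    with open("add_model.engine", "wb") as f:
        f.write(trt_engine_str)

    # Reload it later, as in the test's deserialization step
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open("add_model.engine", "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())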
