diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index f3196bbb3a..523befea1c 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -65,6 +65,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-torchscript-fe @@ -99,6 +100,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-dynamo-converters @@ -126,6 +128,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-dynamo-fe @@ -154,6 +157,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-dynamo-serde @@ -181,6 +185,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-torch-compile-be @@ -210,6 +215,7 @@ jobs: package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main with: job-name: tests-py-dynamo-core @@ -238,7 +244,9 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-core repository: "pytorch/tensorrt" diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 1c0e24ade4..ec7f75f599 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,3 +1,3 @@ __version__: "2.4.0.dev0" __cuda_version__: "12.1" -__tensorrt_version__: "10.0.1" +__tensorrt_version__: "10.0.1" \ No newline at end of file diff --git a/docsrc/index.rst b/docsrc/index.rst index cc586a053e..df3f297162 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -114,7 +114,7 @@ Tutorials tutorials/_rendered_examples/dynamo/custom_kernel_plugins tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion - + tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq Python API Documenation ------------------------ diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index 7191c02fa0..bda997b96b 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -11,3 +11,4 @@ a number of ways you can leverage this backend to accelerate inference. 
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines +* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile`` diff --git a/examples/dynamo/vgg16_fp8_ptq.py b/examples/dynamo/vgg16_fp8_ptq.py new file mode 100644 index 0000000000..b2a82cc4f8 --- /dev/null +++ b/examples/dynamo/vgg16_fp8_ptq.py @@ -0,0 +1,251 @@ +""" +.. _vgg16_fp8_ptq: + +Torch Compile VGG16 with FP8 and PTQ +====================================================== + +This script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a VGG16 model with FP8 and PTQ. +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import argparse + +import modelopt.torch.quantization as mtq +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from modelopt.torch.quantization.utils import export_torch_mode + + +class VGG(nn.Module): + def __init__(self, layer_spec, num_classes=1000, init_weights=False): + super(VGG, self).__init__() + + layers = [] + in_channels = 3 + for l in layer_spec: + if l == "pool": + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + else: + layers += [ + nn.Conv2d(in_channels, l, kernel_size=3, padding=1), + nn.BatchNorm2d(l), + nn.ReLU(), + ] + in_channels = l + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.classifier = nn.Sequential( + nn.Linear(512 * 1 * 1, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + if init_weights: + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + +def vgg16(num_classes=1000, init_weights=False): + vgg16_cfg = [ + 64, + 64, + "pool", + 128, + 128, + "pool", + 256, + 256, + 256, + "pool", + 512, + 512, + 512, + "pool", + 512, + 512, + 512, + "pool", + ] + return VGG(vgg16_cfg, num_classes, init_weights) + + +PARSER = argparse.ArgumentParser( + description="Load pre-trained VGG model and then tune with FP8 and PTQ" +) +PARSER.add_argument( + "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint" +) +PARSER.add_argument( + "--batch-size", + default=128, + type=int, + help="Batch size for tuning the model with PTQ and FP8", +) + +args = PARSER.parse_args() + +model = vgg16(num_classes=10, init_weights=False) +model = model.cuda() + +# %% +# Load the pre-trained model weights +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +ckpt = torch.load(args.ckpt) +weights = ckpt["model_state_dict"] + +if torch.cuda.device_count() > 1: + from collections import OrderedDict + + new_state_dict = OrderedDict() + for k, v in weights.items(): + name 
= k[7:] # remove `module.` + new_state_dict[name] = v + weights = new_state_dict + +model.load_state_dict(weights) +# Don't forget to set the model to evaluation mode! +model.eval() + +# %% +# Load training dataset and define loss function for PTQ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +training_dataset = datasets.CIFAR10( + root="./data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + ), +) +training_dataloader = torch.utils.data.DataLoader( + training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2 +) + +data = iter(training_dataloader) +images, _ = next(data) + +crit = nn.CrossEntropyLoss() + +# %% +# Define Calibration Loop for quantization +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +def calibrate_loop(model): + # calibrate over the training dataset + total = 0 + correct = 0 + loss = 0.0 + for data, labels in training_dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + total += labels.size(0) + correct += (preds == labels).sum().item() + + print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total)) + + +# %% +# Tune the pre-trained model with FP8 and PTQ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +quant_cfg = mtq.FP8_DEFAULT_CFG +# PTQ with in-place replacement to quantized modules +mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) +# model has FP8 qdq nodes at this point + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Load the testing dataset +testing_dataset = datasets.CIFAR10( + root="./data", + train=False, + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + ), +) + +testing_dataloader = torch.utils.data.DataLoader( + testing_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2 +) + +with torch.no_grad(): + with export_torch_mode(): + # Compile the model with Torch-TensorRT Dynamo backend + input_tensor = images.cuda() + exp_program = torch.export.export(model, (input_tensor,)) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float8_e4m3fn}, + min_block_size=1, + debug=False, + ) + + # Inference compiled Torch-TensorRT model over the testing dataset + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + model.eval() + for data, labels in testing_dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + test_loss = loss / total + test_acc = correct / total + print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) diff --git a/examples/int8/training/vgg16/requirements.txt b/examples/int8/training/vgg16/requirements.txt index d02af2c616..3b0b03f5d7 100644 --- a/examples/int8/training/vgg16/requirements.txt +++ b/examples/int8/training/vgg16/requirements.txt @@ -4,3 +4,5 @@ nvidia-pyindex --extra-index-url https://pypi.nvidia.com pytorch-quantization tqdm 
+nvidia-modelopt +--extra-index-url https://pypi.nvidia.com diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 044fbabc1f..9be5931b59 100755 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -3,7 +3,6 @@ # Install dependencies python3 -m pip install pyyaml yum install -y ninja-build gettext -TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \ && mv bazelisk-linux-amd64 /usr/bin/bazel \ && chmod +x /usr/bin/bazel diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index ed0b00a109..befc22d474 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -5,11 +5,10 @@ from typing import Any, Optional, Type, Union import numpy as np +import tensorrt as trt import torch from torch_tensorrt._features import ENABLED_FEATURES -import tensorrt as trt - class dtype(Enum): """Enum to set supported dtypes in the compiler""" @@ -24,9 +23,9 @@ class dtype(Enum): f32 = auto() f64 = auto() b = auto() + bf16 = auto() - # TODO: Enable FP8 - # f8 = auto() + f8 = auto() uint8 = u8 int8 = i8 @@ -36,6 +35,9 @@ class dtype(Enum): long = i64 int64 = i64 + float8 = f8 + fp8 = f8 + half = f16 fp16 = f16 float16 = f16 @@ -48,10 +50,6 @@ class dtype(Enum): fp64 = f64 float64 = f64 - # TODO: Enable when FP8 is enabled - # float8 = f8 - # fp8 = f8 - bfloat16 = bf16 @staticmethod @@ -79,6 +77,8 @@ def _from( return dtype.i64 elif t == torch.int32: return dtype.i32 + elif t == torch.float8_e4m3fn: + return dtype.f8 elif t == torch.half: return dtype.f16 elif t == torch.float: @@ -103,6 +103,8 @@ def _from( return dtype.u8 elif t == trt.DataType.INT8: return dtype.i8 + elif t == trt.DataType.FP8: + return dtype.f8 elif t == trt.DataType.INT32: return dtype.i32 elif t == trt.DataType.INT64: @@ -210,6 +212,8 @@ def to( return torch.int elif self == dtype.i64: return torch.long + elif self == dtype.f8: + return torch.float8_e4m3fn elif self == dtype.f16: return torch.half elif self == dtype.f32: @@ -235,6 +239,8 @@ def to( return trt.DataType.INT8 elif self == dtype.i32: return trt.DataType.INT32 + elif self == dtype.f8: + return trt.DataType.FP8 elif self == dtype.i64: return trt.DataType.INT64 elif self == dtype.f16: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c3cca50f65..a4cd703422 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -28,7 +28,11 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) -from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) from torch_tensorrt.dynamo.utils import ( get_torch_inputs, parse_complex_tensor_structs, @@ -167,6 +171,7 @@ def compile( # Prepare torch_trt inputs inputs = prepare_inputs(inputs) + torch_inputs = get_torch_inputs(inputs, device) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -174,15 +179,14 @@ def compile( raise AssertionError( f"Input graph should be an ExportedProgram but got type {type(exported_program)}" ) + exported_program = pre_export_lowering(exported_program, torch_inputs) exported_program = exported_program.run_decompositions( get_decompositions(enable_experimental_decompositions) ) gm = 
exported_program.module() logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module - torch_inputs = get_torch_inputs(inputs, device) - gm = apply_lowering_passes(gm, torch_inputs) - + gm = post_lowering(gm, torch_inputs) logger.debug("Lowered Input graph: " + str(gm.graph)) compilation_options = { @@ -553,7 +557,7 @@ def convert_module_to_trt_engine( # Prepare torch_trt inputs input_list = prepare_inputs(input_list) device = to_torch_tensorrt_device(device) - + torch_inputs = get_torch_inputs(input_list, device) enabled_precisions = {dtype._from(e) for e in enabled_precisions} compilation_options = { @@ -583,6 +587,7 @@ def convert_module_to_trt_engine( "dla_global_dram_size": dla_global_dram_size, } + exported_program = pre_export_lowering(exported_program, torch_inputs) # Decompose the exported program exported_program = exported_program.run_decompositions( get_decompositions(enable_experimental_decompositions) @@ -591,8 +596,7 @@ def convert_module_to_trt_engine( logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module - torch_inputs = get_torch_inputs(input_list, device) - gm = apply_lowering_passes(gm, torch_inputs) + gm = post_lowering(gm, torch_inputs) logger.debug("Lowered Input graph: " + str(gm.graph)) settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index c6612617c8..3bfda9a8b7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -27,7 +27,7 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16} +SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index dbb900009a..da9da6e02c 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -11,8 +11,9 @@ from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._compiler import compile_module from torch_tensorrt.dynamo.lowering import ( - apply_lowering_passes, get_decompositions, + post_lowering, + remove_detach, remove_sym_nodes, repair_input_aliasing, ) @@ -82,6 +83,9 @@ def _pretraced_backend( input for input in sample_inputs if isinstance(input, torch.Tensor) ] + # Remove detach nodes + remove_detach(gm, torch_inputs) + # Invoke AOTAutograd to translate operators to aten gm = aot_export_joint_simple( gm, @@ -94,7 +98,7 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - gm = apply_lowering_passes(gm, torch_inputs) + gm = post_lowering(gm, sample_inputs) logger.debug("Lowered Input graph:\n " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index b37dddc4c1..3d2d661446 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -106,8 +106,6 @@ def __init__( [dtype._from(o) for o in output_dtypes] if output_dtypes else None ) - _LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}") - def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -243,6 +241,9 @@ def _populate_trt_builder_config( if dtype.int8 in self.compilation_settings.enabled_precisions: 
builder_config.set_flag(trt.BuilderFlag.INT8) + if dtype.fp8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP8) + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.BF16) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index cb90697921..d37beb3db6 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -580,6 +580,34 @@ def aten_ops_neg( ) +try: + import modelopt.torch.quantization as mtq + + assert torch.ops.trt.quantize_fp8.default +except Exception as e: + _LOGGER.warning( + "Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models" + ) +else: + + @dynamo_tensorrt_converter(torch.ops.trt.quantize_fp8.default) + def aten_ops_quantize_fp8( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.quantize.quantize_fp8( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index ca71cb0b0c..a18155d6be 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -18,6 +18,7 @@ pad, permutation, pool, + quantize, reduce, select, shape, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py new file mode 100644 index 0000000000..de78385ce9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -0,0 +1,45 @@ +from typing import Optional + +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def quantize_fp8( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + scale: np.ndarray, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based + on the output_type set and dequantizes them back. + """ + if (isinstance(input_tensor, TRTTensor)) and not ( + input_tensor.dtype == trt.float32 or input_tensor.dtype == trt.float16 + ): + raise ValueError( + f"quantize_fp8 converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16" + ) + + scale = get_trt_tensor(ctx, scale, name + "_scale") + # Add Q node + quantize_layer = ctx.net.add_quantize(input_tensor, scale) + quantize_layer.set_output_type(0, trt.DataType.FP8) + set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + q_output = quantize_layer.get_output(0) + # Add DQ node + dequantize_layer = ctx.net.add_dequantize(q_output, scale) + set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) + # Set DQ layer precision to FP8 + dequantize_layer.precision = trt.DataType.FP8 + dq_output = dequantize_layer.get_output(0) + + return dq_output diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index a89780ded4..83c85855fe 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -5,4 +5,5 @@ from ._decompositions import get_decompositions # noqa: F401 from ._remove_sym_nodes import remove_sym_nodes from ._repair_input_aliasing import repair_input_aliasing -from .passes import apply_lowering_passes +from .passes import post_lowering, pre_export_lowering +from .passes.remove_detach import remove_detach diff --git a/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py index e85117a423..8adebc87f8 100644 --- a/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py +++ b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py @@ -10,17 +10,18 @@ def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: dynamic=True behavior """ # Extract SymInt placeholder Tensors - placeholders = [ + placeholder_sym_ints = [ node for node in gm.graph.nodes if ( node.op == "placeholder" and isinstance(node.type, type) and issubclass(node.type, torch.SymInt) + and not node.users ) ] - for node in placeholders: + for node in placeholder_sym_ints: gm.graph.erase_node(node) gm.graph.lint() diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 489805cb43..3d1663fe0b 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -8,12 +8,13 @@ from .lower_linear import lower_linear from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager +from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices from .view_to_reshape import view_to_reshape -ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist( +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ remove_input_alias_fixing_clones, constant_fold, @@ -26,6 +27,12 @@ ] ) +ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + [ + remove_detach, + ] +) + logger = logging.getLogger(__name__) @@ -48,9 +55,9 @@ def _aten_lowering_pass( def add_lowering_pass( lowering_pass: LoweringPassSignature, ) -> LoweringPassSignature: - ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + f"Added lowering pass {lowering_pass} to list at index 
{index}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return lowering_pass @@ -72,23 +79,35 @@ def add_lowering_pass( def _remove_lowering_pass(*, index: int) -> None: """Removes a lowering pass at a specific index from the registry""" - ATEN_LOWERING_PASSES.remove_pass_with_index(index) + ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index) logger.debug( - f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return -def apply_lowering_passes( +def post_lowering( gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] ) -> torch.fx.GraphModule: - """Applies the lowering passes to a graph module, returns the modified GraphModule""" + """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule""" + logging.debug( + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}" + ) + return ATEN_POST_LOWERING_PASSES(gm, sample_inputs) + + +def pre_export_lowering( + ep: torch.export.ExportedProgram, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule""" logging.debug( - f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}" + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}" ) - return ATEN_LOWERING_PASSES(gm, sample_inputs) + gm = ep.graph_module + gm = ATEN_PRE_LOWERING_PASSES(gm, sample_inputs) + return ep def dump_lowering_passes() -> str: """Returns a string containing the lowering passes""" - return str(ATEN_LOWERING_PASSES) + return str(ATEN_POST_LOWERING_PASSES) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py new file mode 100644 index 0000000000..5f1ab3738d --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py @@ -0,0 +1,25 @@ +import logging +from typing import Sequence + +import torch + +logger = logging.getLogger(__name__) + + +def remove_detach( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Remove detach ops in the graph""" + count = 0 + for node in gm.graph.nodes: + # node.target = "detach" in torch.compile workflow + if node.target == torch.ops.aten.detach.default or node.target == "detach": + # Detach node has only one input + node_input = node.all_input_nodes[0] + node.replace_all_uses_with(node_input) + gm.graph.erase_node(node) + count += 1 + + logger.debug(f"Removed {count} detach nodes:\n{gm.graph}") + + return gm diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 28673f63df..826d74ac96 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -15,7 +15,11 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes -from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from 
torch_tensorrt.dynamo.utils import get_torch_inputs @@ -207,14 +211,16 @@ def generate_graph( torch_inputs = get_torch_inputs(original_inputs, _defaults.DEVICE) if use_dynamo_tracer: exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) + exported_program = pre_export_lowering(exported_program, torch_inputs) exported_program = exported_program.run_decompositions( get_decompositions(False) ) fx_module = exported_program.module() else: fx_module = torch.fx.symbolic_trace(mod) + if enable_passes: - fx_module = apply_lowering_passes(fx_module, torch_inputs) + fx_module = post_lowering(fx_module, original_inputs) if propagate_shapes: # TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843 diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 8a6b2fb726..9fdab1a9d0 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -4,13 +4,12 @@ import pytest import timm import torch +import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity from transformers import BertModel from transformers.utils.fx import symbolic_trace as transformers_trace -import torch_tensorrt as torchtrt - assertions = unittest.TestCase() @@ -183,3 +182,50 @@ def test_resnet18_half(ir): # Clean up model env torch._dynamo.reset() + + +@unittest.skipIf( + torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9, + "FP8 compilation in Torch-TRT is not supported on cards older than Hopper", +) +@pytest.mark.unit +def test_base_fp8(ir): + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,)) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float8_e4m3fn}, + min_block_size=1, + debug=True, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2) diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 742b9fc1a3..e0112ae523 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -8,8 +8,8 @@ from torch._functorch.aot_autograd import aot_export_joint_simple from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo.lowering import ( - apply_lowering_passes, get_decompositions, + post_lowering, repair_input_aliasing, ) @@ -50,7 +50,7 @@ def fx_dynamo_testing_backend( decompositions=get_decompositions(), ) - gm = apply_lowering_passes(gm, sample_inputs) + gm = post_lowering(gm, 
sample_inputs) trt_compiled = custom_backend( gm, diff --git a/versions.py b/versions.py index 81dbe72794..9936e9b225 100644 --- a/versions.py +++ b/versions.py @@ -10,7 +10,6 @@ __cuda_version__ = "0.0" __tensorrt_version__ = "0.0" - LEADING_V_PATTERN = re.compile("^v") TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$") LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
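
Note: below is a minimal, self-contained sketch of the user-facing FP8 flow these changes enable, condensed from the new examples/dynamo/vgg16_fp8_ptq.py example and the test_base_fp8 test above (FP8 compilation requires a Hopper-class GPU per that test's skip condition); the tiny Sequential model and single-batch calibration loop are illustrative stand-ins, not part of the patch.

    # Sketch: ModelOpt FP8 PTQ followed by Torch-TensorRT Dynamo compilation,
    # mirroring test_base_fp8 in tests/py/dynamo/models/test_models_export.py.
    import modelopt.torch.quantization as mtq
    import torch
    import torch_tensorrt as torchtrt
    from modelopt.torch.quantization.utils import export_torch_mode

    # Illustrative stand-in model and input (any FP32/FP16 CUDA model works).
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1)
    ).eval().cuda()
    input_tensor = torch.randn(1, 10).cuda()

    # PTQ: insert FP8 Q/DQ nodes in place, calibrating with a simple forward loop.
    mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=lambda m: m(input_tensor))

    with torch.no_grad(), export_torch_mode():
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            # torch.float8_e4m3fn maps to dtype.f8 / trt.DataType.FP8 via the
            # _enums changes above and is routed to the new quantize_fp8 converter.
            enabled_precisions={torch.float8_e4m3fn},
            min_block_size=1,
        )
        print(trt_model(input_tensor))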