
chore: cherry-pick FP8 (#2892)
peri044 committed Jun 7, 2024
1 parent 29272fa commit 7f16bda
Showing 23 changed files with 491 additions and 44 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/build-test-linux.yml
@@ -65,6 +65,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-torchscript-fe
@@ -99,6 +100,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-dynamo-converters
@@ -126,6 +128,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-dynamo-fe
@@ -154,6 +157,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-dynamo-serde
@@ -181,6 +185,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-torch-compile-be
@@ -210,6 +215,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-dynamo-core
@@ -238,7 +244,9 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-core
repository: "pytorch/tensorrt"
2 changes: 1 addition & 1 deletion dev_dep_versions.yml
@@ -1,3 +1,3 @@
__version__: "2.4.0.dev0"
__cuda_version__: "12.1"
__tensorrt_version__: "10.0.1"
__tensorrt_version__: "10.0.1"
2 changes: 1 addition & 1 deletion docsrc/index.rst
@@ -114,7 +114,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion

tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq

Python API Documentation
------------------------
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -11,3 +11,4 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using the Torch-TensorRT Dynamo frontend
251 changes: 251 additions & 0 deletions examples/dynamo/vgg16_fp8_ptq.py
@@ -0,0 +1,251 @@
"""
.. _vgg16_fp8_ptq:
Torch Compile VGG16 with FP8 and PTQ
======================================================
This script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a VGG16 model with FP8 and PTQ.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64, 64, "pool",
        128, 128, "pool",
        256, 256, 256, "pool",
        512, 512, 512, "pool",
        512, 512, 512, "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)
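

# %%
# As a quick sanity check (a hedged sketch, not part of the original example):
# with five "pool" stages, a 32x32 CIFAR10 input is reduced to 1x1 before the
# classifier, so the architecture can be smoke-tested on CPU::
#
#    out = vgg16(num_classes=10).eval()(torch.randn(1, 3, 32, 32))
#    assert out.shape == (1, 10)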


PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)

args = PARSER.parse_args()
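
# %%
# A typical invocation looks like the following (hypothetical paths; a
# compatible checkpoint can be produced with the VGG16 training script under
# ``examples/int8/training/vgg16``)::
#
#    python vgg16_fp8_ptq.py --ckpt vgg16_ckpt.pth --batch-size 128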

model = vgg16(num_classes=10, init_weights=False)
model = model.cuda()

# %%
# Load the pre-trained model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

ckpt = torch.load(args.ckpt)
weights = ckpt["model_state_dict"]

# Checkpoints saved with nn.DataParallel prefix parameter names with `module.`;
# strip the prefix so the state dict matches the single-GPU model definition.
if torch.cuda.device_count() > 1:
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in weights.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    weights = new_state_dict

model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()

# %%
# Load training dataset and define loss function for PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
)

# Grab a single batch to use as example inputs for torch.export later on
data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

# %%
# Define Calibration Loop for quantization
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        # accumulate with .item() so the autograd graph is not retained
        # across the entire dataset
        loss += crit(out, labels).item()
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))


# %%
# Tune the pre-trained model with FP8 and PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point
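
# %%
# Optionally, inspect where quantizers were inserted before compiling. This is
# a hedged sketch: ``print_quant_summary`` is assumed to be available in the
# installed nvidia-modelopt release.
mtq.print_quant_summary(model)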

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
)

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float8_e4m3fn},
            min_block_size=1,
            debug=False,
        )

        # Run inference with the compiled Torch-TensorRT model over the
        # testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels).item()
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
2 changes: 2 additions & 0 deletions examples/int8/training/vgg16/requirements.txt
@@ -4,3 +4,5 @@ nvidia-pyindex
--extra-index-url https://pypi.nvidia.com
pytorch-quantization
tqdm
nvidia-modelopt
--extra-index-url https://pypi.nvidia.com
1 change: 0 additions & 1 deletion packaging/pre_build_script.sh
@@ -3,7 +3,6 @@
# Install dependencies
python3 -m pip install pyyaml
yum install -y ninja-build gettext
TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \
&& mv bazelisk-linux-amd64 /usr/bin/bazel \
&& chmod +x /usr/bin/bazel