From 9367f0d529b675983e52bb4edb8099e309512d2d Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 24 Jun 2024 14:08:48 -0700 Subject: [PATCH 1/3] fix typo --- examples/dynamo/vgg16_fp8_ptq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dynamo/vgg16_fp8_ptq.py b/examples/dynamo/vgg16_fp8_ptq.py index e71152935e..6a151de26a 100644 --- a/examples/dynamo/vgg16_fp8_ptq.py +++ b/examples/dynamo/vgg16_fp8_ptq.py @@ -235,10 +235,10 @@ def calibrate_loop(model): loss = 0.0 class_probs = [] class_preds = [] - model.eval() + trt_model.eval() for data, labels in testing_dataloader: data, labels = data.cuda(), labels.cuda(non_blocking=True) - out = model(data) + out = trt_model(data) loss += crit(out, labels) preds = torch.max(out, 1)[1] class_probs.append([F.softmax(i, dim=0) for i in out]) From 8a40219fdc41f8ca8d09f3edab6d04fdbbcc279a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 24 Jun 2024 16:03:36 -0700 Subject: [PATCH 2/3] drop the last incomplete batch --- examples/dynamo/vgg16_fp8_ptq.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/dynamo/vgg16_fp8_ptq.py b/examples/dynamo/vgg16_fp8_ptq.py index 6a151de26a..5f3b80f470 100644 --- a/examples/dynamo/vgg16_fp8_ptq.py +++ b/examples/dynamo/vgg16_fp8_ptq.py @@ -155,7 +155,11 @@ def vgg16(num_classes=1000, init_weights=False): ), ) training_dataloader = torch.utils.data.DataLoader( - training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2 + training_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=2, + drop_last=True, ) data = iter(training_dataloader) @@ -211,8 +215,12 @@ def calibrate_loop(model): ) testing_dataloader = torch.utils.data.DataLoader( - testing_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2 -) + testing_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=2, + drop_last=True, +) # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()` with torch.no_grad(): with export_torch_mode(): From 2f85858a71d0d6815be679a16c75c4d9ce484cd4 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 24 Jun 2024 16:42:28 -0700 Subject: [PATCH 3/3] remove eval() --- examples/dynamo/vgg16_fp8_ptq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dynamo/vgg16_fp8_ptq.py b/examples/dynamo/vgg16_fp8_ptq.py index 5f3b80f470..21c44efcee 100644 --- a/examples/dynamo/vgg16_fp8_ptq.py +++ b/examples/dynamo/vgg16_fp8_ptq.py @@ -243,7 +243,6 @@ def calibrate_loop(model): loss = 0.0 class_probs = [] class_preds = [] - trt_model.eval() for data, labels in testing_dataloader: data, labels = data.cuda(), labels.cuda(non_blocking=True) out = trt_model(data)