Skip to content

Commit

Permalink
fix(qat): Rescale input data for C++ application
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Aug 12, 2021
1 parent b9b7f63 commit 9dc6061
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/int8/datasets/cifar10.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
labels.push_back(label);
auto image_tensor =
torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
.to(torch::kF32);
.to(torch::kF32).div(255);
images.push_back(image_tensor);
}

Expand Down
1 change: 1 addition & 0 deletions examples/int8/ptq/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,5 @@ int main(int argc, const char* argv[]) {

auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
trt_mod.save("/tmp/ptq_vgg16.trt.ts");
}
1 change: 1 addition & 0 deletions examples/int8/qat/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,6 @@ int main(int argc, const char* argv[]) {

auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
trt_mod.save("/tmp/qat_vgg16.trt.ts");
}

0 comments on commit 9dc6061

Please sign in to comment.