diff --git a/examples/int8/datasets/cifar10.cpp b/examples/int8/datasets/cifar10.cpp index 10c52973f8..dff4716dac 100644 --- a/examples/int8/datasets/cifar10.cpp +++ b/examples/int8/datasets/cifar10.cpp @@ -50,7 +50,7 @@ std::pair read_batch(const std::string& path) { labels.push_back(label); auto image_tensor = torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8)) - .to(torch::kF32); + .to(torch::kF32).div(255); images.push_back(image_tensor); } diff --git a/examples/int8/ptq/main.cpp b/examples/int8/ptq/main.cpp index eb96adb98f..752d3a84fe 100644 --- a/examples/int8/ptq/main.cpp +++ b/examples/int8/ptq/main.cpp @@ -140,4 +140,5 @@ int main(int argc, const char* argv[]) { auto trt_runtimes = benchmark_module(trt_mod, dims[0]); print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]); + trt_mod.save("/tmp/ptq_vgg16.trt.ts"); } diff --git a/examples/int8/qat/main.cpp b/examples/int8/qat/main.cpp index 33e5e295bb..50db43ec1e 100644 --- a/examples/int8/qat/main.cpp +++ b/examples/int8/qat/main.cpp @@ -124,5 +124,6 @@ int main(int argc, const char* argv[]) { auto trt_runtimes = benchmark_module(trt_mod, dims[0]); print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]); + trt_mod.save("/tmp/qat_vgg16.trt.ts"); }