From 5c0d7373414c283d4b981f38d1dd2561048c5d8e Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Thu, 23 Apr 2020 17:00:13 -0700
Subject: [PATCH] feat(/cpp/api): Working INT8 Calibrator, also resolves #41

- Now creates output tensors of the correct type to accept data
- There still may be a data race in the creation of the dataloader iterator
- Quantization and Dynamic Shape currently don't play well together; a
  subsequent release of TRT may address this

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 README.md                                |  2 +-
 core/conversion/conversion.cpp           |  4 +--
 .../conversionctx/ConversionCtx.cpp      |  2 +-
 core/execution/register_trt_op.cpp       | 16 +++---------
 core/util/trt_util.cpp                   | 26 ++++++++++++++++---
 core/util/trt_util.h                     |  5 ++--
 cpp/ptq/main.cpp                         | 22 +++++++++++-----
 7 files changed, 48 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 213abc138a..fbb7fff688 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ auto results = trt_mod.forward({in_tensor});
 
 > Notes on running in lower precisions:
 > - Set precision with extra_info.op_precision
-> - The module should be left in FP32 before compilation
+> - The module should be left in FP32 before compilation (FP16 can support models with half-precision tensors)
 > - In FP16 only input tensors should be converted to FP16, other precisions use FP32
 
 ## Platform Support
diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index d71af6dbdc..a13e7b2865 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -133,7 +133,7 @@ void AddInputs(ConversionCtx* ctx,
                 "Expected dimension specifications for all input tensors" \
                     << ", but found " << input_tensors.size() \
                     << " input tensors and " \
-                    << input_dims.size() << "dimension specs (conversion.AddInputs)");
+                    << input_dims.size() << " dimension specs (conversion.AddInputs)");
 
   auto profile = ctx->builder->createOptimizationProfile();
 
@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
     if (!OpSupported(n)) {
       auto schema = n->maybeSchema();
       TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
-                              << " (conversion.AddLayer)");
+                              << " (conversion.VerifyConverterSupportForBlock)");
       std::stringstream ss;
      ss << *schema;
      unsupported_ops.insert(ss.str());
diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp
index ad5668a19a..2d2e321a83 100644
--- a/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/core/conversion/conversionctx/ConversionCtx.cpp
@@ -51,7 +51,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
   case nvinfer1::DataType::kINT8:
     TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
     cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-    input_type = nvinfer1::DataType::kINT8;
+    input_type = nvinfer1::DataType::kFLOAT; // If the calibrator is nullptr then TRT will use default quantization
     cfg->setInt8Calibrator(settings.calibrator);
     break;
diff --git a/core/execution/register_trt_op.cpp b/core/execution/register_trt_op.cpp
index 460aad51c5..495c01c0e6 100644
--- a/core/execution/register_trt_op.cpp
+++ b/core/execution/register_trt_op.cpp
@@ -17,20 +17,11 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
   contig_inputs.reserve(inputs.size());
   for (size_t i = 0; i < inputs.size(); i++) {
     TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
-    auto expected_type = torch::kF32;
-    switch (ctx->getEngine().getBindingDataType(i)) {
-    case nvinfer1::DataType::kHALF:
-      expected_type = torch::kF16;
-      break;
-    case nvinfer1::DataType::kFLOAT:
-    case nvinfer1::DataType::kINT8:
-    default:
-      expected_type = torch::kF32;
-    }
+    auto expected_type = util::toATenDType(ctx->getEngine().getBindingDataType(i));
     TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
     auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
     auto shape = core::util::toVec(dims);
-    contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous());
+    contig_inputs.push_back(inputs[i].view(shape).contiguous());
     LOG_DEBUG("In shape:" << shape);
     ctx->setBindingDimensions(i, dims);
     gpu_handles.push_back(contig_inputs.back().data_ptr());
@@ -43,7 +34,8 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     auto out_shape = ctx->getBindingDimensions(o);
     //LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape);
     auto dims = core::util::toVec(out_shape);
-    outputs.push_back(at::empty(dims, {at::kCUDA}).contiguous());
+    auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
+    outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
     gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr());
   }
 
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index 22424e408b..89214e5efd 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -15,18 +15,18 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
     LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
     return toDims(l);
   }
-  
+
   if (pad_to > nvinfer1::Dims::MAX_DIMS) {
     //TODO: Handle this with exceptions or whatever
     LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
   }
-  
+
   nvinfer1::Dims dims;
   dims.nbDims = pad_to;
   for (size_t i = 0; i < pad_to - l.size(); i++) {
     dims.d[i] = 1;
   }
-  
+
   for (size_t i = pad_to - l.size(); i < pad_to; i++) {
     dims.d[i] = l[i - (pad_to - l.size())];
   }
@@ -58,7 +58,7 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
   }
   return dims;
 }
-  
+
 std::vector<int64_t> toVec(nvinfer1::Dims d) {
   std::vector<int64_t> dims;
   for (int i = 0; i < d.nbDims; i++) {
@@ -110,8 +110,26 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
   };
   return at_trt_type_map;
 }
+
+const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
+  static const std::unordered_map<nvinfer1::DataType, at::ScalarType> trt_at_type_map = {
+    {nvinfer1::DataType::kFLOAT, at::kFloat},
+    {nvinfer1::DataType::kHALF, at::kHalf},
+    {nvinfer1::DataType::kINT32, at::kInt},
+    {nvinfer1::DataType::kINT8, at::kChar},
+  };
+  return trt_at_type_map;
+}
 } // namespace
 
+const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_map() {
+  return get_trt_at_type_map();
+}
+
+at::ScalarType toATenDType(nvinfer1::DataType t) {
+  return get_trt_aten_type_map().at(t);
+}
+
 const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map() {
   return get_at_trt_type_map();
 }
diff --git a/core/util/trt_util.h b/core/util/trt_util.h
index 50225fbdc5..bf8ea5b224 100644
--- a/core/util/trt_util.h
+++ b/core/util/trt_util.h
@@ -21,7 +21,7 @@ inline bool operator==(const nvinfer1::Dims& in1, const nvinfer1::Dims& in2) {
   }
 
   // TODO maybe look to support broadcasting comparisons
-  
+
   for (int64_t i = 0; i < in1.nbDims; i++) {
     if (in1.d[i] != in2.d[i]) {
       return false;
@@ -85,11 +85,12 @@ nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l);
 std::vector<int64_t> toVec(nvinfer1::Dims d);
 std::string toStr(nvinfer1::Dims d);
+at::ScalarType toATenDType(nvinfer1::DataType t);
 nvinfer1::DataType toTRTDataType(at::ScalarType t);
 c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
 const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map();
-  
+
 } // namespace util
 } // namespace core
 } // namespace trtorch
diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp
index 61eb856c98..c726e70444 100644
--- a/cpp/ptq/main.cpp
+++ b/cpp/ptq/main.cpp
@@ -13,7 +13,7 @@
 #include
 
 int main(int argc, const char* argv[]) {
-  trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO);
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
   if (argc < 3) {
     std::cerr << "usage: ptq \n";
     return -1;
@@ -50,11 +50,13 @@ int main(int argc, const char* argv[]) {
   // Configure settings for compilation
   auto extra_info = trtorch::ExtraInfo({input_shape});
   // Set operating precision to INT8
-  extra_info.op_precision = torch::kFI8;
+  extra_info.op_precision = torch::kI8;
   // Use the TensorRT Entropy Calibrator
   extra_info.ptq_calibrator = calibrator;
   // Set max batch size for the engine
   extra_info.max_batch_size = 32;
+  // Set a larger workspace
+  extra_info.workspace_size = 1 << 28;
 
   mod.eval();
 
@@ -82,6 +84,7 @@ int main(int argc, const char* argv[]) {
   std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
   // Compile Graph
+  std::cout << "Compiling and quantizing module" << std::endl;
   auto trt_mod = trtorch::CompileGraph(mod, extra_info);
 
   // Check the INT8 accuracy in TRT
@@ -91,22 +94,27 @@ int main(int argc, const char* argv[]) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
+    if (images.sizes()[0] < 32) {
+      // To handle smaller batches until Optimization profiles work with Int8
+      auto diff = 32 - images.sizes()[0];
+      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
+      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
+      images = torch::cat({images, img_padding}, 0);
+      targets = torch::cat({targets, target_padding}, 0);
+    }
+
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);
 
     if (predictions.sizes()[0] != targets.sizes()[0]) {
-      // To handle smaller batches util Optimization profiles work
+      // To handle smaller batches until Optimization profiles work with Int8
       predictions = predictions.slice(0, 0, targets.sizes()[0]);
     }
 
-    std:: cout << predictions << targets << std::endl;
-
     total += targets.sizes()[0];
     correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
-    std::cout << total << " " << correct << std::endl;
   }
-  std::cout << total << " " << correct << std::endl;
   std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
   // Time execution in INT8
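
For readers following the PTQ example touched by this patch, the configuration it exercises fits together roughly as sketched below. This is a minimal sketch, not part of the patch: the function name CompileInt8Sketch is hypothetical, the calibrator is assumed to have been built earlier in main.cpp (the diff only shows it being assigned to extra_info.ptq_calibrator), and the {32, 3, 32, 32} input shape is an assumption inferred from the CIFAR-10 padding code above.

#include <vector>

#include "torch/script.h"
#include "trtorch/trtorch.h"

#include "NvInfer.h"

// Minimal sketch (not part of the patch): compile a loaded TorchScript module to INT8
// using the ExtraInfo fields exercised in cpp/ptq/main.cpp above.
torch::jit::script::Module CompileInt8Sketch(torch::jit::script::Module mod,
                                             nvinfer1::IInt8Calibrator* calibrator /* assumed built elsewhere */) {
  // Assumed CIFAR-10 input shape; matches the {diff, 3, 32, 32} padding in the eval loop above.
  std::vector<int64_t> input_shape = {32, 3, 32, 32};

  auto extra_info = trtorch::ExtraInfo({input_shape});
  extra_info.op_precision = torch::kI8;      // request INT8 kernels
  extra_info.ptq_calibrator = calibrator;    // calibrator drives scale selection during PTQ
  extra_info.max_batch_size = 32;            // evaluation batches are padded to 32 above
  extra_info.workspace_size = 1 << 28;       // 256 MiB of builder workspace, as in the diff

  mod.eval();
  return trtorch::CompileGraph(mod, extra_info);  // returned module is backed by a TRT engine
}

Since INT8 and dynamic shapes do not yet play well together (per the commit message), the sketch uses a single fixed input shape and pads smaller evaluation batches up to the max batch size, mirroring the workaround added in main.cpp.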