feat(/cpp/api): Working INT8 Calibrator, also resolves #41
- Now creates output tensors of the correct type to accept data
- There still may be a data race in the creation of the dataloader
iterator
- Quantization and Dynamic Shape right now don't play well together; a subsequent release of TRT may address this

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 24, 2020
1 parent 5f36f47 commit 5c0d737
Showing 7 changed files with 48 additions and 29 deletions.
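
The headline change in this commit is that engine output tensors (and the input dtype checks) now take their type from the TensorRT engine's binding data type instead of assuming FP32. The sketch below condenses that pattern from the core/execution/register_trt_op.cpp hunk further down; the function name AllocateOutputs and the include path for trt_util.h are illustrative, not part of the commit.

#include <vector>
#include "ATen/ATen.h"
#include "NvInfer.h"
#include "core/util/trt_util.h" // declares toVec and the new toATenDType (path assumed)

// Allocate one CUDA output tensor per engine output binding, using the dtype
// TensorRT reports for that binding rather than hard-coding FP32.
std::vector<at::Tensor> AllocateOutputs(nvinfer1::IExecutionContext* ctx, int64_t num_inputs) {
  std::vector<at::Tensor> outputs;
  for (int64_t o = num_inputs; o < ctx->getEngine().getNbBindings(); o++) {
    auto dims = trtorch::core::util::toVec(ctx->getBindingDimensions(o));
    auto type = trtorch::core::util::toATenDType(ctx->getEngine().getBindingDataType(o));
    outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
  }
  return outputs;
}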
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ auto results = trt_mod.forward({in_tensor});
> Notes on running in lower precisions:
> - Set precision with extra_info.op_precision
> - The module should be left in FP32 before compilation
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
## Platform Support
4 changes: 2 additions & 2 deletions core/conversion/conversion.cpp
@@ -133,7 +133,7 @@ void AddInputs(ConversionCtx* ctx,
"Expected dimension specifications for all input tensors" \
<< ", but found " << input_tensors.size() \
<< " input tensors and " \
<< input_dims.size() << "dimension specs (conversion.AddInputs)");
<< input_dims.size() << " dimension specs (conversion.AddInputs)");

auto profile = ctx->builder->createOptimizationProfile();

@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
if (!OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.AddLayer)");
<< " (conversion.VerifyCoverterSupportForBloxk");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
@@ -51,7 +51,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
input_type = nvinfer1::DataType::kINT8;
input_type = nvinfer1::DataType::kFLOAT;
// If the calibrator is nullptr then TRT will use default quantization
cfg->setInt8Calibrator(settings.calibrator);
break;
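
To make the control flow of this hunk easier to follow outside the switch statement: for INT8 the builder config gets the INT8 flag and the calibrator, while the network input type is now left at FP32 (the change above), since TensorRT quantizes activations internally using the calibrator's scales. A hedged sketch, with variable names mirroring the ConversionCtx members; it is not a drop-in replacement for the real constructor logic.

#include "NvInfer.h"

void ConfigureInt8(nvinfer1::IBuilder* builder,
                   nvinfer1::IBuilderConfig* cfg,
                   nvinfer1::IInt8Calibrator* calibrator,
                   nvinfer1::DataType& input_type) {
  if (!builder->platformHasFastInt8()) {
    return; // the real code raises a TRTORCH_CHECK failure here instead
  }
  cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
  input_type = nvinfer1::DataType::kFLOAT; // inputs stay FP32; this commit changes this from kINT8
  cfg->setInt8Calibrator(calibrator);      // nullptr falls back to TensorRT's default quantization
}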
16 changes: 4 additions & 12 deletions core/execution/register_trt_op.cpp
@@ -17,20 +17,11 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
contig_inputs.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
auto expected_type = torch::kF32;
switch (ctx->getEngine().getBindingDataType(i)) {
case nvinfer1::DataType::kHALF:
expected_type = torch::kF16;
break;
case nvinfer1::DataType::kFLOAT:
case nvinfer1::DataType::kINT8:
default:
expected_type = torch::kF32;
}
auto expected_type = util::toATenDType(ctx->getEngine().getBindingDataType(i));
TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
auto shape = core::util::toVec(dims);
contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous());
contig_inputs.push_back(inputs[i].view(shape).contiguous());
LOG_DEBUG("In shape:" << shape);
ctx->setBindingDimensions(i, dims);
gpu_handles.push_back(contig_inputs.back().data_ptr());
@@ -43,7 +34,8 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
auto out_shape = ctx->getBindingDimensions(o);
//LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape);
auto dims = core::util::toVec(out_shape);
outputs.push_back(at::empty(dims, {at::kCUDA}).contiguous());
auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr());
}

26 changes: 22 additions & 4 deletions core/util/trt_util.cpp
@@ -15,18 +15,18 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
return toDims(l);
}

if (pad_to > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}

nvinfer1::Dims dims;
dims.nbDims = pad_to;
for (size_t i = 0; i < pad_to - l.size(); i++) {
dims.d[i] = 1;
}

for (size_t i = pad_to - l.size(); i < pad_to; i++) {
dims.d[i] = l[i - (pad_to - l.size())];
}
@@ -58,7 +58,7 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
}
return dims;
}

std::vector<int64_t> toVec(nvinfer1::Dims d) {
std::vector<int64_t> dims;
for (int i = 0; i < d.nbDims; i++) {
@@ -110,8 +110,26 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
};
return at_trt_type_map;
}

const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
static const std::unordered_map<nvinfer1::DataType, at::ScalarType> trt_at_type_map = {
{nvinfer1::DataType::kFLOAT, at::kFloat},
{nvinfer1::DataType::kHALF, at::kHalf},
{nvinfer1::DataType::kINT32, at::kInt},
{nvinfer1::DataType::kINT8, at::kChar},
};
return trt_at_type_map;
}
} // namespace

const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_map() {
return get_trt_at_type_map();
}

at::ScalarType toATenDType(nvinfer1::DataType t) {
return get_trt_aten_type_map().at(t);
}

const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map() {
return get_at_trt_type_map();
}
5 changes: 3 additions & 2 deletions core/util/trt_util.h
@@ -21,7 +21,7 @@ inline bool operator==(const nvinfer1::Dims& in1, const nvinfer1::Dims& in2) {
}

// TODO maybe look to support broadcasting comparisons

for (int64_t i = 0; i < in1.nbDims; i++) {
if (in1.d[i] != in2.d[i]) {
return false;
@@ -85,11 +85,12 @@ nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l);
std::vector<int64_t> toVec(nvinfer1::Dims d);
std::string toStr(nvinfer1::Dims d);

at::ScalarType toATenDType(nvinfer1::DataType t);
nvinfer1::DataType toTRTDataType(at::ScalarType t);
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);

const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map();

} // namespace util
} // namespace core
} // namespace trtorch
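
A minimal usage sketch of the type-mapping helpers declared above, exercising the new TensorRT-to-ATen direction added in this commit alongside the existing ATen-to-TensorRT one; the include path and the wrapper function are assumptions for illustration only.

#include "ATen/ATen.h"
#include "NvInfer.h"
#include "core/util/trt_util.h" // path assumed

// Round-trip a dtype through both maps: toTRTDataType (existing) and
// toATenDType (new in this commit). Per get_trt_at_type_map, INT8 comes back as at::kChar.
void TypeRoundTrip() {
  nvinfer1::DataType trt_half = trtorch::core::util::toTRTDataType(at::kHalf); // kHALF
  at::ScalarType aten_half = trtorch::core::util::toATenDType(trt_half);       // at::kHalf
  at::ScalarType aten_int8 = trtorch::core::util::toATenDType(nvinfer1::DataType::kINT8); // at::kChar
  (void)aten_half;
  (void)aten_int8;
}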
22 changes: 15 additions & 7 deletions cpp/ptq/main.cpp
@@ -13,7 +13,7 @@
#include <sys/stat.h>

int main(int argc, const char* argv[]) {
trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO);
trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
if (argc < 3) {
std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
return -1;
@@ -50,11 +50,13 @@ int main(int argc, const char* argv[]) {
// Configure settings for compilation
auto extra_info = trtorch::ExtraInfo({input_shape});
// Set operating precision to INT8
extra_info.op_precision = torch::kFI8;
extra_info.op_precision = torch::kI8;
// Use the TensorRT Entropy Calibrator
extra_info.ptq_calibrator = calibrator;
// Set max batch size for the engine
extra_info.max_batch_size = 32;
// Set a larger workspace
extra_info.workspace_size = 1 << 28;

mod.eval();

@@ -82,6 +84,7 @@ int main(int argc, const char* argv[]) {
std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;

// Compile Graph
std::cout << "Compiling and quantizing module" << std::endl;
auto trt_mod = trtorch::CompileGraph(mod, extra_info);

// Check the INT8 accuracy in TRT
@@ -91,22 +94,27 @@
auto images = batch.data.to(torch::kCUDA);
auto targets = batch.target.to(torch::kCUDA);

if (images.sizes()[0] < 32) {
// To handle smaller batches until Optimization profiles work with Int8
auto diff = 32 - images.sizes()[0];
auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
auto target_padding = torch::zeros({diff}, {torch::kCUDA});
images = torch::cat({images, img_padding}, 0);
targets = torch::cat({targets, target_padding}, 0);
}

auto outputs = trt_mod.forward({images});
auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
predictions = predictions.reshape(predictions.sizes()[0]);

if (predictions.sizes()[0] != targets.sizes()[0]) {
// To handle smaller batches until Optimization profiles work
// To handle smaller batches until Optimization profiles work with Int8
predictions = predictions.slice(0, 0, targets.sizes()[0]);
}

std::cout << predictions << targets << std::endl;

total += targets.sizes()[0];
correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
std::cout << total << " " << correct << std::endl;
}
std::cout << total << " " << correct << std::endl;
std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;

// Time execution in INT8
