Skip to content

Commit

Permalink
fix: Error caused by invalid binding name in TRTEngine.to_str() met…
Browse files Browse the repository at this point in the history
…hod (#1846)
  • Loading branch information
gs-olive authored and bowang007 committed Apr 28, 2023
1 parent 77b0d7b commit 47df4cd
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ TRTEngine::TRTEngine(
TORCHTRT_CHECK(
(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
"Binding " << binding_name << " specified as input but found as output in TensorRT engine");
LOG_DEBUG("Input binding name: " << binding_name << "pyt arg idx: " << pyt_idx << ")");
LOG_DEBUG(
"Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
<< ", Torch binding index: " << pyt_idx);
in_binding_map[trt_idx] = pyt_idx;
in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
in_binding_names[pyt_idx] = binding_name;
}

uint64_t outputs = _out_binding_names.size();
Expand Down Expand Up @@ -210,19 +212,21 @@ std::string TRTEngine::to_str() const {
ss << " Inputs: [" << std::endl;
for (uint64_t i = 0; i < num_io.first; i++) {
ss << " id: " << i << std::endl;
ss << " shape: " << exec_ctx->getTensorShape(std::string("input_" + str(i)).c_str()) << std::endl;
ss << " name: " << in_binding_names[i].c_str() << std::endl;
ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl;
ss << " dtype: "
<< util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(std::string("input_" + str(i)).c_str()))
<< util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str()))
<< std::endl;
}
ss << " ]" << std::endl;
ss << " Outputs: [" << std::endl;
for (uint64_t o = 0; o < num_io.second; o++) {
ss << " id: " << o << std::endl;
ss << " shape: " << exec_ctx->getTensorShape(std::string("output_" + str(o)).c_str()) << std::endl;
ss << " name: " << out_binding_names[o].c_str() << std::endl;
ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl;
ss << " dtype: "
<< util::TRTDataTypeToScalarType(
exec_ctx->getEngine().getTensorDataType(std::string("output_" + str(o)).c_str()))
exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
<< std::endl;
}
ss << " }" << std::endl;
Expand Down

0 comments on commit 47df4cd

Please sign in to comment.