CUDA graph support for TRT EP #16081

Merged · 18 commits · Jun 21, 2023
2 changes: 2 additions & 0 deletions cmake/onnxruntime_providers.cmake
@@ -729,6 +729,8 @@ if (onnxruntime_USE_TENSORRT)
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc"
)

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs})
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -44,4 +44,5 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT
};
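For context (not part of the diff): a minimal sketch of how a caller might flip this new flag through the public C API. It assumes the standard ORT C API entry points for V2 TensorRT provider options; the option key string matches the trt_cuda_graph_enable key registered in the provider-options parser later in this PR, and error handling is elided.

#include <onnxruntime_c_api.h>

// Sketch: enable CUDA graph for the TRT EP via the V2 provider options.
void EnableTrtCudaGraph(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);

  // "trt_cuda_graph_enable" is the provider-option key added in this PR.
  const char* keys[] = {"trt_cuda_graph_enable"};
  const char* values[] = {"1"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 1);

  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options);
  api->ReleaseTensorRTProviderOptions(trt_options);
}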
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -703,6 +703,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_min_shapes = info.profile_min_shapes;
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -842,6 +843,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes);
profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes);
profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes);

const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable);
if (!cuda_graph_enable_env.empty()) {
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) != 0);
}
} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
@@ -895,6 +901,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty();
}

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
if (cuda_graph_enable_) {
cuda_graph_ = std::make_unique<CUDAGraph>();
}
#endif

/*
* Parse explicit min/max/opt profile shapes from provider options.
*
@@ -968,7 +980,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_tactic_sources: " << tactic_sources_
<< ", trt_profile_min_shapes: " << profile_min_shapes
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes;
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -982,6 +995,45 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list);
}

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const {
return cuda_graph_enable_;
}

bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
}

void TensorrtExecutionProvider::CaptureBegin() {
cuda_graph_->Reset();
cuda_graph_->CaptureBegin();
}

void TensorrtExecutionProvider::CaptureEnd() {
cuda_graph_->CaptureEnd();
is_graph_captured_ = true;
}

bool TensorrtExecutionProvider::IsGraphCaptured() const {
return is_graph_captured_;
}

Status TensorrtExecutionProvider::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
// Please note that CUDAGraph::Replay() is not thread safe.
// ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced by the lock_guard(),
// therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe.
return cuda_graph_->Replay();
}

void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
// Please note that this function is not thread safe.
// ORT TRT calls this function in compute_func() where synchronization is enforced by the lock_guard(),
// therefore the following increment is guaranteed to be thread safe.
++regular_run_count_before_graph_capture_;
}
#endif
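For readers who haven't seen the CUDAGraph helper this PR reuses from the CUDA EP: below is a minimal sketch, in terms of the plain CUDA runtime graph API, of what the capture/replay calls above presumably map to. The real class in core/providers/cuda/cuda_graph.h adds error handling and ownership logic beyond this illustration, and the sketch assumes the pre-CUDA-12 cudaGraphInstantiate signature.

#include <cuda_runtime.h>

// Sketch only: capture work issued to a stream, then replay it as one graph launch.
class CudaGraphSketch {
 public:
  void SetStream(cudaStream_t stream) { stream_ = stream; }

  void CaptureBegin() {
    // Work issued to stream_ after this point is recorded, not executed.
    cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal);
  }

  void CaptureEnd() {
    cudaStreamEndCapture(stream_, &graph_);
    cudaGraphInstantiate(&graph_exec_, graph_, nullptr, nullptr, 0);
  }

  void Replay() {
    // Launch the entire recorded graph with a single call.
    cudaGraphLaunch(graph_exec_, stream_);
  }

  void Reset() {
    if (graph_exec_) cudaGraphExecDestroy(graph_exec_);
    if (graph_) cudaGraphDestroy(graph_);
    graph_exec_ = nullptr;
    graph_ = nullptr;
  }

 private:
  cudaStream_t stream_ = nullptr;
  cudaGraph_t graph_ = nullptr;
  cudaGraphExec_t graph_exec_ = nullptr;
};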

AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeDefault) {
return allocator_;
@@ -1063,6 +1115,10 @@ std::unique_ptr<IDataTransfer> TensorrtExecutionProvider::GetDataTransfer() cons
return onnxruntime::CreateGPUDataTransfer();
}

Status TensorrtExecutionProvider::OnRunStart() {
return Status::OK();
}

Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
if (sync_stream && external_stream_) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
@@ -2802,6 +2858,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_context->setDeviceMemory((*context_memory).get());
}

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
// Start CUDA graph capture.
// Note: The reason we don't start graph capture in OnRunStart() like the CUDA EP does is that
// ORT TRT currently doesn't get the CUDA stream until compute time, and graph capture requires the stream.
if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
cuda_graph_->SetStream(stream);
CaptureBegin();
}
#endif

// Run TRT inference
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
@@ -2829,6 +2896,25 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}
}
}

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
// End CUDA graph capture.
// Note: One reason we don't end graph capture in OnRunEnd() like the CUDA EP does is the CUDA stream
// issue mentioned above; another is that OnRunEnd() is not synchronized with OnRunStart() and
// ExecuteGraph() per inference_session.cc, so many CUDA graphs could end up being captured by
// multiple threads when running with multithreading.
// It's safe to start/end CUDA graph capture in compute_func() here since the whole function is
// protected by the lock_guard().
if (cuda_graph_enable_ && !IsGraphCaptured()) {
if (IsGraphCaptureAllowed()) {
CaptureEnd();
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(ReplayGraph());
} else {
IncrementRegularRunCountBeforeGraphCapture();
}
}
#endif

return Status::OK();
};
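Taken together with min_num_runs_before_cuda_graph_capture_ = 1 (declared in the header below), the per-run schedule is: the first Run() executes normally as a warmup, the second Run() is captured and immediately replayed, and every later Run() only replays. A hedged caller-side sketch of that schedule follows; the model path and I/O names are placeholders, not taken from this PR.

#include <onnxruntime_cxx_api.h>
#include <vector>

void RunThreeTimes(Ort::SessionOptions& session_options) {
  Ort::Env env;
  // Assumes session_options already has the TRT EP appended with trt_cuda_graph_enable set.
  Ort::Session session(env, "model.onnx", session_options);

  std::vector<float> input_data(1 * 3 * 224 * 224, 0.f);
  std::vector<int64_t> shape{1, 3, 224, 224};
  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(
      mem_info, input_data.data(), input_data.size(), shape.data(), shape.size());

  const char* input_names[] = {"input"};    // placeholder I/O names
  const char* output_names[] = {"output"};

  // Run 1: regular run; IncrementRegularRunCountBeforeGraphCapture() is called.
  // Run 2: CaptureBegin() -> enqueueV2() recorded -> CaptureEnd() -> ReplayGraph().
  // Run 3+: the session replays the captured graph.
  // Note: input shapes (and ideally buffer addresses, e.g. via IoBinding)
  // must stay fixed across runs for graph replay to be valid.
  for (int i = 0; i < 3; ++i) {
    auto outputs = session.Run(Ort::RunOptions{nullptr},
                               input_names, &input, 1,
                               output_names, 1);
  }
}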

26 changes: 26 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -8,6 +8,7 @@
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/cuda/cuda_graph.h"
#include "tensorrt_execution_provider_info.h"

namespace onnxruntime {
@@ -42,6 +43,7 @@ static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_P
static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES";
static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES";
static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES";
static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
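As with the other switches in this namespace, the new flag can also be set through the legacy environment-variable path that the TRT EP constructor parses above. A short sketch under stated assumptions (POSIX setenv; on Windows one would use _putenv_s instead), set before the EP is constructed:

#include <cstdlib>

// Sketch: enable CUDA graph capture via ORT_TENSORRT_CUDA_GRAPH_ENABLE.
// Per the std::stoi parse in the constructor, any non-zero integer enables it.
void EnableTrtCudaGraphViaEnv() {
  setenv("ORT_TENSORRT_CUDA_GRAPH_ENABLE", "1", /*overwrite=*/1);
}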
@@ -135,6 +137,7 @@ struct TensorrtFuncState {
int auxiliary_streams = -1;
bool filter_tactic_sources = false;
nvinfer1::TacticSources tactic_sources;
bool cuda_graph_enable = false;
};

// Logical device representation.
@@ -159,6 +162,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {

void RegisterAllocator(AllocatorManager& allocator_manager) override;

Status OnRunStart() override;

Status OnRunEnd(bool sync_stream) override;

ProviderOptions GetProviderOptions() const override {
@@ -171,6 +176,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {

OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
#endif

private:
TensorrtExecutionProviderInfo info_;
bool external_stream_ = false;
@@ -209,6 +220,14 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool timing_cache_enable_ = false;
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
std::unique_ptr<CUDAGraph> cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph)
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1;  // Minimum number of regular runs required before graph capture, so the necessary memory allocations happen first.
#endif

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
@@ -259,5 +278,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {

/**Check whether all the nodes of subgraph are supported*/
bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const;

#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
void IncrementRegularRunCountBeforeGraphCapture();
#endif
};
} // namespace onnxruntime
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -43,6 +43,7 @@ constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths";
constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes";
constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes";
constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable";
} // namespace provider_option_names
} // namespace tensorrt

@@ -91,6 +92,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
.Parse(options)); // add new provider option here.

return info;
@@ -129,6 +131,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)},
{tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
{tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
};
return options;
}
@@ -167,6 +170,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)},
{tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)},
{tensorrt::provider_option_names::kTacticSources, kTacticSources_},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)},
};
return options;
}
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -49,6 +49,7 @@ struct TensorrtExecutionProviderInfo {
std::string profile_min_shapes{""};
std::string profile_max_shapes{""};
std::string profile_opt_shapes{""};
bool cuda_graph_enable{false};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
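A small sketch (for illustration only; it assumes visibility of these internal onnxruntime types) of how the new key round-trips through the string-keyed ProviderOptions map via the two functions declared above:

// Sketch: round-tripping the new option through ProviderOptions.
onnxruntime::ProviderOptions opts{{"trt_cuda_graph_enable", "1"}};
auto info = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(opts);
// info.cuda_graph_enable is now true.
auto round_trip = onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(info);
// round_trip["trt_cuda_graph_enable"] == "1"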
onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -94,6 +94,7 @@ struct Tensorrt_Provider : Provider {
info.profile_min_shapes = options.trt_profile_min_shapes == nullptr ? "" : options.trt_profile_min_shapes;
info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes;
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
@@ -229,6 +230,8 @@ struct Tensorrt_Provider : Provider {
dest[str_size] = '\0';
trt_options.trt_profile_opt_shapes = (const char*)dest;
}

trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
26 changes: 16 additions & 10 deletions onnxruntime/core/session/inference_session.cc
@@ -1521,38 +1521,44 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());

// Currently only the CUDA EP is considered.
// Currently CUDA graph is only considered by CUDA EP and TRT EP.
// If the CUDA EP is part of the providers list for this session AND
// The CUDA EP is configured to do a graph capture AND
// All the graph nodes have been assigned to the CUDA EP,
// Then the CUDA EP is cached for triggering a ReplayGraph() in Run().
auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider);
if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) {
if (cuda_ep->IsGraphCaptureEnabled()) {
// Same logic is applied to TRT EP.
std::vector<const char*> cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider};

for (auto& it : cuda_graph_support_ep_list) {
auto* target_ep = execution_providers_.Get(it);

if (target_ep && target_ep->IsGraphCaptureEnabled()) {
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";
<< "as the model has control flow nodes which can't be supported by CUDA Graphs.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";
<< "as all the graph nodes have not been assigned to "
<< target_ep->Type();

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));

"as all the graph nodes have not been assigned to " +
target_ep->Type()));
} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep);
break;  // Make sure only one EP can run with CUDA graph.
}
}
}
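For completeness, the cached EP is presumably consulted later on the Run() side; a simplified sketch of that consumer (paraphrased from inference_session.cc, not part of this diff) follows.

// Sketch: the replay fast path in InferenceSession::Run().
if (cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
    cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
  // Skip the regular executor and replay the captured CUDA graph.
  ORT_RETURN_IF_ERROR(cached_execution_provider_for_graph_replay_.ReplayGraph());
}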
2 changes: 2 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -1312,6 +1312,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
trt_options_converted.trt_profile_min_shapes = "";
trt_options_converted.trt_profile_max_shapes = "";
trt_options_converted.trt_profile_opt_shapes = "";
trt_options_converted.trt_cuda_graph_enable = 0;

return trt_options_converted;
}
@@ -1668,6 +1669,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
options->trt_profile_min_shapes = nullptr;
options->trt_profile_max_shapes = nullptr;
options->trt_profile_opt_shapes = nullptr;
options->trt_cuda_graph_enable = false;
*out = options.release();
return nullptr;
#else
10 changes: 9 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
@@ -380,7 +380,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
nullptr,
nullptr,
nullptr,
nullptr};
0};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
@@ -598,6 +598,14 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n");
}
} else if (option.first == "trt_cuda_graph_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_cuda_graph_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_cuda_graph_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}