
Support CUDA Graph #9978

Merged · 20 commits · Mar 7, 2022

Conversation

@feihugis (Contributor) commented Dec 9, 2021

Description

This PR adds support for CUDA Graph. This feature can significantly reduce the CPU overhead of calling CUDA APIs by submitting the entire graph to the GPU with a single call to cudaGraphLaunch.

Motivation and Context

  • Why is this change required? What problem does it solve?
    This feature is very helpful for reducing model latency, especially for online inference, where the CPU overhead described above is a bottleneck. For example, it reduced the 95th-percentile latency of a transformer-based online inference model (with 148 million parameters) from 4.3 ms to 2.1 ms.
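For context, the capture-and-replay flow described above can be sketched with the CUDA runtime stream-capture API. This is a minimal illustrative sketch, not this PR's actual implementation: `launch_model_kernels` is a hypothetical placeholder for enqueuing the model's kernels, error handling is elided, and running it requires a CUDA device.

```cpp
#include <cuda_runtime.h>

// Hypothetical helper that enqueues all of the model's kernels on the stream.
void launch_model_kernels(cudaStream_t stream);

void run_with_cuda_graph(int num_inferences) {
  cudaStream_t stream;
  cudaGraph_t graph;
  cudaGraphExec_t instance;
  cudaStreamCreate(&stream);

  // 1. Capture: record every kernel enqueued on the stream into a graph.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  launch_model_kernels(stream);
  cudaStreamEndCapture(stream, &graph);

  // 2. Instantiate the executable graph once (pre-CUDA-12 signature).
  cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);

  // 3. Replay: one cudaGraphLaunch call per inference instead of one
  //    launch call per kernel, removing most of the CPU launch overhead.
  for (int i = 0; i < num_inferences; ++i) {
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);
  }

  cudaGraphExecDestroy(instance);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
}
```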

@hariharans29 (Member)

Hi @feihugis: Thanks for this contribution.

I had a question about the transformer model you are referring to in the description. From what I have read about CUDA Graphs, they cannot support dynamic control flow, meaning any ONNX model that has control flow ops (If, Loop, and Scan) cannot be captured as a CUDA Graph. The same is called out in PyTorch as well (https://pytorch.org/docs/master/notes/cuda.html#constraints), so this is something that needs to be explicitly disallowed in the capture phase. Does the model you see gains for with this feature have no control flow nodes in it?

@weixingzhang (Contributor)

There are two ways of supporting CUDA Graph: 1) using the graph capture APIs, or 2) building the graph directly with APIs such as cudaGraphCreate/cudaGraphAddNode. Since the input of ORT is an ONNX graph, one thought is that the CUDA graph could be built directly in ORT using way #2, based on the ONNX graph, instead of using the capture APIs.

@hariharans29 (Member) commented Jan 27, 2022

I would have thought the graph capture APIs are just way simpler (this PR is essentially that, and PyTorch seems to be using capture as well). Is there any advantage to using (2)?

@feihugis (Contributor, Author)

Thanks @hariharans29 for your review. The two transformer models I tested do not have control flow nodes in them. The constraints in PyTorch apply here as well. If any constraint is violated during capture, either an error will be raised or the captured graph may not produce correct results. The first scenario (raised errors) seems OK; the second scenario seems hard to identify, so it seems better to let users handle it. Control flow ops fall into the second scenario: since CUDA Graph captures the kernels enqueued in the stream, capture may still work but record only one branch, so the results may not be correct for all inputs (similar to PyTorch tracing). Some docs need to be added to explain these limitations. Explicitly disallowing these cases seems too strict; for example, users could capture different graphs for different branches of the control flow.

@feihugis (Contributor, Author)

Thanks @weixingzhang for your suggestions. There is a limitation in the second way: it is not easy to determine the actually executed CUDA kernels from the ONNX graph. For example, for a MatMul node, the actual kernel selected by CUDA will differ for different input dtypes and shapes.

@hariharans29 (Member)

"Explicitly disallowing these cases seems too strict; for example, users could capture different graphs for different branches of the control flow."

But at least it will reduce some associated maintenance overhead, and we wouldn't have to spend time debugging silent errors from executing kernels of the wrong subgraph of a control flow node. As far as I can tell, your design currently only allows capturing one graph per CUDA EP, and even if we did allow capturing multiple graphs per EP, we still wouldn't know which graph instance to execute for a new input (if the model had dynamic control flow nodes).

@feihugis (Contributor, Author)

@hariharans29 Got your point now! I will add a check to explicitly disable the cases that CUDA Graph cannot fully support. Regarding support for multiple CUDA graphs, at the beginning I thought multiple sessions could be created for different graphs, but I have not evaluated that yet and am not sure whether creating multiple sessions would be allowed. And you are right: multiple graphs still could not handle control flow very well.
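The kind of pre-capture check mentioned above could be sketched as follows. This is a simplified illustration over plain dict-based nodes (a hypothetical representation), not this PR's code; a real check would walk the ONNX GraphProto, including the graph-typed attributes of If/Loop/Scan nodes.

```python
# Ops that CUDA Graph capture cannot handle correctly (dynamic control flow).
CONTROL_FLOW_OPS = {"If", "Loop", "Scan"}

def has_control_flow(nodes):
    """Return True if any node (or nested subgraph node) is a control-flow op.

    Each node is a dict with an 'op_type' string and an optional 'subgraphs'
    list of nested node lists (standing in for ONNX graph attributes).
    """
    for node in nodes:
        if node["op_type"] in CONTROL_FLOW_OPS:
            return True
        for sub in node.get("subgraphs", []):
            if has_control_flow(sub):
                return True
    return False
```

For example, a model whose node list is `[{"op_type": "MatMul"}, {"op_type": "Relu"}]` would pass, while one containing an If or Loop node (even nested) would be rejected before capture begins.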

@feihugis feihugis force-pushed the cuda_graph branch 3 times, most recently from 002559b to fdaebfa Compare February 1, 2022 00:12
@feihugis feihugis force-pushed the cuda_graph branch 2 times, most recently from 6495dfa to 1e1c41c Compare February 4, 2022 22:33
@hariharans29 (Member)

/azp run Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@hariharans29 (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@hariharans29 (Member)

/azp run onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@hariharans29 (Member) commented Mar 4, 2022

@pranavsharma - Could you please take another look at Fei's recent changes based on the offline discussion ?

onnxruntime/core/session/inference_session.h (outdated review thread)
}

Status ReplayGraph() {
if (cached_execution_provider_for_graph_replay_) {
Contributor

It'll be good to check for IsGraphCaptured here as well.

Contributor Author

ORT_ENFORCE(IsGraphCaptured()); is added.

Member

Doesn't the EP's ReplayGraph() already enforce that? This seems redundant.

onnxruntime/core/session/inference_session.h (outdated review thread)
onnxruntime/core/session/inference_session.cc (outdated review thread)
onnxruntime/core/session/inference_session.cc (outdated review thread)
@feihugis (Contributor, Author) commented Mar 4, 2022

@hariharans29 @pranavsharma Thanks for the review! The comments have been addressed.

@pranavsharma (Contributor) left a comment

LGTM 👍 @hariharans29 I believe you're going to have a follow-up PR with some documentation?

@hariharans29 (Member) commented Mar 5, 2022

Yes, I was waiting until the PR was ready to be merged (as the design was in a state of constant flux). I will add the documentation next.

@hariharans29 (Member)

/azp run Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@hariharans29 (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@hariharans29 (Member)

/azp run onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@hariharans29 hariharans29 merged commit 60acfd3 into microsoft:master Mar 7, 2022
lavanyax pushed a commit to intel/onnxruntime that referenced this pull request Mar 29, 2022
chilo-ms added a commit that referenced this pull request Jun 21, 2023
CUDA EP already supports [CUDA graph](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs), and we observed that some models can benefit from using CUDA graph with `trtexec`. Therefore, this PR enables CUDA graph support for the TRT EP.

The implementation is based on #9978, with the same constraints as below:

- Models with control-flow ops (i.e. If, Loop, and Scan ops) are not supported.
- Usage of CUDA graphs is limited to models wherein all the model ops (graph nodes) can be partitioned to the TRT EP.
- The input/output types of the model need to be tensors.
- Shapes of inputs/outputs cannot change across inference calls.
- IOBinding is required.
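The IOBinding requirement above follows from replay semantics: inputs and outputs must live at fixed GPU addresses so the captured graph can be replayed with updated contents. A hedged usage sketch against the ONNX Runtime Python API is below; the model path and the input/output names (`input_ids`, `logits`) and shapes are made-up placeholders, and running it requires a CUDA-capable onnxruntime-gpu build.

```python
import numpy as np
import onnxruntime as ort

# Enable CUDA graph capture/replay via the CUDA EP provider option.
providers = [("CUDAExecutionProvider", {"enable_cuda_graph": "1"})]
sess = ort.InferenceSession("model.onnx", providers=providers)

# Pin inputs/outputs to fixed GPU buffers via IOBinding.
x = ort.OrtValue.ortvalue_from_numpy(
    np.zeros((1, 128), dtype=np.int64), "cuda", 0)
y = ort.OrtValue.ortvalue_from_numpy(
    np.zeros((1, 10), dtype=np.float32), "cuda", 0)
binding = sess.io_binding()
binding.bind_ortvalue_input("input_ids", x)
binding.bind_ortvalue_output("logits", y)

# The first run captures the CUDA graph; subsequent runs replay it.
sess.run_with_iobinding(binding)

# Update input contents in place (same shape, same address) and replay.
x.update_inplace(np.ones((1, 128), dtype=np.int64))
sess.run_with_iobinding(binding)
result = y.numpy()
```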