Skip to content

Commit

Permalink
Fixes #656: Stream capture functionality work:
Browse files Browse the repository at this point in the history
* Now offering global capture mode as a default (so the user doesn't need to always specify it explicitly)
* Can now start and stop capturing using stand-alone functions in addition to stream_t methods.
  • Loading branch information
eyalroz committed Jun 21, 2024
1 parent 6089dee commit cf1f1df
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ void cudaGraphsUsingStreamCapture(
auto reduce_output_memset_event = cuda::event::create(device);
auto final_result_memset_event = cuda::event::create(device);

stream_1.begin_capture(cuda::stream::capture::mode_t::global);
stream_1.begin_capture();

stream_1.enqueue.event(fork_stream_event);
stream_2.enqueue.wait(fork_stream_event);
Expand Down
13 changes: 10 additions & 3 deletions src/cuda/api/multi_wrapper_impls/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,21 @@ inline ::std::string describe(graph::instance::update_status_t update_status, op
graph::instance::detail_::describe(update_status, node.value().handle(), node.value().containing_graph_handle());
}

inline graph::template_t stream_t::end_capture() const
namespace stream {
namespace capture {

/**
 * @brief Conclude the capturing of operations on a stream, obtaining the
 * captured operations as an execution graph template.
 *
 * @note Defined `inline`, since this definition resides in a header which may be
 * included in multiple translation units.
 *
 * @param stream the stream on which capturing is to be ended
 * @return a graph template wrapping the graph handle produced by the CUDA driver;
 * the wrapper is created with @ref do_take_ownership
 */
inline graph::template_t end(const cuda::stream_t& stream)
{
	graph::template_::handle_t new_graph;
	auto status = cuStreamEndCapture(stream.handle(), &new_graph);
	// Failure is reported via throw_if_error_lazy, with a message identifying the stream
	throw_if_error_lazy(status,
		"Completing the capture of operations into a graph on " + stream::detail_::identify(stream));
	return graph::template_::wrap(new_graph, do_take_ownership);
}

} // namespace capture
} // namespace stream

inline void stream_t::enqueue_t::graph_launch(const graph::instance_t& graph_instance) const
{
graph::launch(graph_instance, associated_stream);
Expand Down
28 changes: 23 additions & 5 deletions src/cuda/api/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ namespace capture {

inline state_t state(const stream_t& stream);

/**
 * @brief Have a stream capture operations enqueued from now on, for later generation
 * of an execution graph.
 *
 * @param stream the stream on which to begin capturing
 * @param mode the capture mode passed on to the CUDA driver; the global mode
 * is used when none is specified
 *
 * @note See also @ref graph::template_t
 */
inline void begin(const cuda::stream_t& stream, stream::capture::mode_t mode = cuda::stream::capture::mode_t::global);

/**
 * @brief Conclude the capturing of operations on a stream.
 *
 * @param stream the stream on which capturing is to be ended
 * @return A CUDA execution graph template, holding the operations captured on
 * the stream
 *
 * @note See also @ref graph::template_t
 */
inline graph::template_t end(const cuda::stream_t& stream);

} // namespace capture

inline bool is_capturing(const stream_t& stream)
Expand Down Expand Up @@ -855,10 +864,9 @@ class stream_t {
*
* @note See also @ref graph::template_t
*/
void begin_capture(stream::capture::mode_t mode) const {
context::current::detail_::scoped_override_t set_context_for_this_scope(context_handle_);
auto status = cuStreamBeginCapture(handle_, static_cast<CUstreamCaptureMode>(mode));
throw_if_error_lazy(status, "Failed beginning to capture on " + stream::detail_::identify(*this));
void begin_capture(stream::capture::mode_t mode = cuda::stream::capture::mode_t::global) const
{
	// Thin forwarder: the stand-alone function holds the actual implementation
	const stream_t& this_stream = *this;
	stream::capture::begin(this_stream, mode);
}

/**
Expand All @@ -872,7 +880,10 @@ class stream_t {
* @return A CUDA execution graph template, comprising all operations enqueued on this stream
* between the last invocation of @ref begin_capture and the invocation of this one.
*/
graph::template_t end_capture() const;
graph::template_t end_capture() const
{
	// Thin forwarder: the stand-alone function holds the actual implementation
	const stream_t& this_stream = *this;
	return stream::capture::end(this_stream);
}
#endif // CUDA_VERSION >= 10000

protected: // constructor
Expand Down Expand Up @@ -1130,6 +1141,13 @@ state_t state(const stream_t& stream)
return static_cast<state_t>(capture_status);
}

/**
 * Begin capturing the operations enqueued on a stream.
 *
 * @note Defined `inline`, since this definition resides in a header which may be
 * included in multiple translation units.
 *
 * @param stream the stream on which to begin capturing
 * @param mode the capture mode, converted to a `CUstreamCaptureMode` and passed to
 * the CUDA driver
 */
inline void begin(const cuda::stream_t& stream, stream::capture::mode_t mode)
{
	// A scoped override of the current context, using the stream's own context handle
	context::current::detail_::scoped_override_t set_context_for_this_scope(stream.context_handle());
	auto status = cuStreamBeginCapture(stream.handle(), static_cast<CUstreamCaptureMode>(mode));
	// Failure is reported via throw_if_error_lazy, with a message identifying the stream
	throw_if_error_lazy(status, "Failed beginning to capture on " + stream::detail_::identify(stream));
}

} // namespace capture
#endif // CUDA_VERSION >= 10000

Expand Down

0 comments on commit cf1f1df

Please sign in to comment.