Skip to content

Commit c7515da

Browse files
galvpytorchmergebot
authored and committed
Implement cuda graphs implementation of torch.cond and torch.while_loop (pytorch#140979)
This is a new PR for pytorch#130386, which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to isaacs/github#361. I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534 Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that. Pull Request resolved: pytorch#140979 Approved by: https://github.com/eqy, https://github.com/eellison
1 parent e3839bd commit c7515da

File tree

22 files changed

+1145
-29
lines changed

22 files changed

+1145
-29
lines changed

aten/src/ATen/cuda/CUDAGeneratorImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
347347
*/
348348
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
349349
at::cuda::assertNotCapturing(
350-
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
350+
"Please ensure to utilize the CUDAGeneratorImpl::graphsafe_set_state method during capturing.");
351351
static const size_t seed_size = sizeof(uint64_t);
352352
static const size_t offset_size = sizeof(int64_t);
353353
static const size_t total_size = seed_size + offset_size;

aten/src/ATen/cuda/CUDAGraph.cpp

Lines changed: 317 additions & 8 deletions
Large diffs are not rendered by default.

aten/src/ATen/cuda/CUDAGraph.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <ATen/cuda/CUDAGraph.h>
2+
#include <ATen/cuda/Exceptions.h>
3+
4+
namespace at::cuda {
5+
6+
namespace {
7+
8+
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
9+
// Device kernel: loads the boolean predicate that `value` points to and
// stores it into the conditional `handle` through cudaGraphSetConditional.
// The host wrapper CUDAGraph::set_conditional_handle launches it with a
// single block of a single thread.
__global__ void set_conditional_handle_kernel(
    cudaGraphConditionalHandle handle,
    const bool* value) {
  const bool predicate = *value;
  // The runtime API takes the condition as an unsigned int; a bool converts
  // to exactly 0 or 1, which is what the implicit conversion produced too.
  cudaGraphSetConditional(handle, static_cast<unsigned int>(predicate));
}
14+
#endif
15+
}
16+
17+
// Enqueues, on the current CUDA stream, a one-thread kernel that reads the
// boolean stored in `scalar_cuda_pred_tensor` on the device and writes it
// into the conditional `handle`.
//
// The tensor's data is accessed through const_data_ptr<bool>() and is
// dereferenced inside the kernel, so it presumably must be a bool tensor
// whose storage is device-accessible -- callers should confirm.
//
// When built for ROCm, or against a CUDA toolkit older than 12.4, no kernel
// is available and this function always raises an error instead.
void CUDAGraph::set_conditional_handle(
    cudaGraphConditionalHandle handle,
    const Tensor& scalar_cuda_pred_tensor) {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  set_conditional_handle_kernel<<<1, 1, 0, getCurrentCUDAStream()>>>(
      handle, scalar_cuda_pred_tensor.const_data_ptr<bool>());
  // Surface launch-configuration errors immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  // AT_ERROR throws, so nothing may follow it; the message states what the
  // build is missing rather than a bare "not allowed".
  AT_ERROR(
      "CUDAGraph::set_conditional_handle requires PyTorch to be built against "
      "CUDA 12.4 or newer; CUDA graph conditional nodes are not available on "
      "ROCm or older CUDA versions.");
#endif
}
29+
30+
} // namespace at::cuda

aten/src/ATen/cuda/CUDAGraph.h

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@
33
#include <ATen/Tensor.h>
44
#include <c10/core/Device.h>
55
#include <c10/cuda/CUDAGraphsC10Utils.h>
6+
#include <c10/cuda/CUDAGuard.h>
67
#include <c10/cuda/CUDAStream.h>
78
#include <c10/util/flat_hash_map.h>
89

10+
#include <limits>
11+
#include <stack>
12+
13+
#if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
14+
// this type is not defined until CUDA 12.4, but we use it as a
15+
// parameter type and return type in some below functions, so we give
16+
// it the same definition as in CUDA 12.4.
17+
typedef unsigned long long cudaGraphConditionalHandle;
18+
#endif // defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
19+
920
namespace at {
1021

1122
struct Generator;
@@ -14,6 +25,9 @@ struct CUDAGeneratorState;
1425

1526
namespace cuda {
1627

28+
using UniquePtrExternalCudaStream =
29+
std::unique_ptr<cudaStream_t, void (*)(cudaStream_t*)>;
30+
1731
// Standalone way to get a unique mempool id usable as a pool=... argument
1832
// to CUDAGraph::capture_begin
1933
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
@@ -22,6 +36,26 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
2236
CUDAGraph();
2337
~CUDAGraph();
2438

39+
// Copy and move constructors and assignments are disabled. These
40+
// were disabled because pybind11 believed that CUDAGraph was copy
41+
// constructible because
42+
// pybind11::is_copy_constructible<CUDAGraph>::value originally
43+
// evaluated to true. However, it cannot generate a copy constructor
44+
// because CUDAGeneratorState, one of CUDAGraph's members, is an
45+
// incomplete type unless CUDAGeneratorImpl.h is included. However,
46+
// that would create a circular dependency between
47+
// CUDAGeneratorImpl.h and CUDAGraph.h. Disabling the copy and move
48+
// constructors is the most straightforward way to prevent pybind11
49+
// from trying to generate default implementations of them.
50+
//
51+
// We needed pybind11 to return a reference to a CUDAGraph as part
52+
// of wrapping CUDAGraph::get_currently_capturing_graph, which
53+
// unearthed the above problem.
54+
CUDAGraph(const CUDAGraph&) = delete;
55+
CUDAGraph& operator=(const CUDAGraph&) = delete;
56+
CUDAGraph(CUDAGraph&& other) = delete;
57+
CUDAGraph& operator=(CUDAGraph&& other) = delete;
58+
2559
static void inc_pending_event_queries();
2660
static void dec_pending_event_queries();
2761
static int num_pending_event_queries();
@@ -38,6 +72,19 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
3872
void enable_debug_mode();
3973
void debug_dump(const std::string& debug_path);
4074

75+
static CUDAGraph* get_currently_capturing_graph();
76+
void begin_capture_to_if_node(const Tensor& scalar_cuda_pred_tensor);
77+
cudaGraphConditionalHandle begin_capture_to_while_loop_node(
78+
const Tensor& scalar_cuda_pred_tensor);
79+
void end_capture_to_conditional_node();
80+
static void set_conditional_handle(
81+
cudaGraphConditionalHandle handle,
82+
const Tensor& scalar_cuda_pred_tensor);
83+
84+
private:
85+
std::function<bool(cudaStream_t)> create_allocate_filter();
86+
std::function<bool(cudaStream_t)> create_child_allocate_filter();
87+
4188
protected:
4289
cudaGraph_t graph_ = nullptr;
4390
cudaGraphExec_t graph_exec_ = nullptr;
@@ -54,7 +101,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
54101

55102
// the ID assigned by cuda during graph capture,
56103
// used to identify when a stream is participating in capture
57-
CaptureId_t capture_id_ = -1;
104+
CaptureId_t capture_id_ = std::numeric_limits<CaptureId_t>::max();
58105

59106
// uuid used to request a particular private mempool from CUDACachingAllocator.
60107
// By default, this will be set to {id_, 0}.
@@ -85,6 +132,15 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
85132
// init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor
86133
static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1;
87134
c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};
135+
136+
cudaStreamCaptureMode capture_mode_{};
137+
138+
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
139+
std::stack<std::pair<at::cuda::CUDAStreamGuard, UniquePtrExternalCudaStream>>
140+
conditional_node_streams_;
141+
std::stack<CaptureId_t> conditional_graph_capture_streams_ids_;
142+
std::vector<cudaGraph_t> descendent_graphs_;
143+
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
88144
};
89145

90146
} // namespace cuda

docs/source/notes/cuda.rst

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,10 @@ and you suspect its runtime is at least somewhat CPU-limited.
929929
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
930930
.. _cudaGraphLaunch:
931931
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
932+
.. _issue 144787:
933+
https://github.com/pytorch/pytorch/issues/144787#issuecomment-2606480564
934+
.. _conditional nodes:
935+
https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
932936

933937
PyTorch API
934938
^^^^^^^^^^^
@@ -1017,6 +1021,9 @@ Violating any of these will likely cause a runtime error:
10171021
Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
10181022
instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
10191023
for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
1024+
* Dynamic control flow (based on CPU or GPU data) is prohibited, unless it is based on GPU data and implemented via higher order operators
1025+
torch.cond() and torch.while_loop(). See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.
1026+
10201027

10211028

10221029
Violating any of these will likely cause silent numerical errors or undefined behavior:
@@ -1025,7 +1032,6 @@ Violating any of these will likely cause silent numerical errors or undefined be
10251032
* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
10261033
* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
10271034
* Every replay reads from and writes to the same (virtual) memory addresses.
1028-
* Dynamic control flow (based on CPU or GPU data) is prohibited.
10291035
* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
10301036
has the same size and layout in every replay.
10311037
* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
@@ -1334,3 +1340,45 @@ If, in the live workload, your callables will run in an order that occasionally
13341340
or if they'll run concurrently, passing them as a tuple to a single invocation of
13351341
:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
13361342
:func:`~torch.cuda.make_graphed_callables` separately for each one.
1343+
1344+
.. _graph-data-dependent-control-flow:
1345+
1346+
Data Dependent Control Flow
1347+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
1348+
1349+
Data-dependent control flow can work with cuda graphs in limited cases if
1350+
the control flow is implemented using torch.cond() or
1351+
torch.while_loop(). If your function uses these functions, compiling
1352+
it with the "cudagraphs" backend will enable control flow in the
1353+
resulting cuda graph via `conditional nodes`_.
1354+
1355+
Unfortunately, eager mode execution does not work due to reasons
1356+
described in `issue 144787`_.
1357+
Support for the inductor backend of torch.compile is not available yet, but there is no fundamental blocker.
1358+
1359+
An example of using the cudagraphs backend to torch.compile on code
1360+
using torch.cond is demonstrated below::
1361+
1362+
import torch
1363+
1364+
def true_fn(x):
1365+
return x.sin()
1366+
1367+
def false_fn(x):
1368+
return x.cos()
1369+
1370+
x = torch.randn(4, device="cuda", requires_grad=False)
1371+
pred = torch.tensor(False, device="cuda", requires_grad=False)
1372+
def foo(pred, x):
1373+
with torch.inference_mode():
1374+
return torch.cond(pred, true_fn, false_fn, [x])
1375+
1376+
# First call will run eager for warmup, second call will do graph
1377+
# capture followed by graph replay, third call and beyond will do
1378+
# just graph replay.
1379+
compiled_foo = torch.compile(foo, backend="cudagraphs")
1380+
for i in range(3):
1381+
y = compiled_foo(pred, x)
1382+
1383+
# will output x.sin()
1384+
y = compiled_foo(~pred, x)

docs/source/torch.compiler_cudagraph_trees.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraph
1313

1414
CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations from requiring the same kernels to be run with the same arguments and dependencies, and memory addresses.
1515

16-
- Control Flow is not possible
16+
- Arbitrary Control Flow is not possible (However, control flow expressed via torch.cond() and torch.while_loop() can be captured in a CUDA Graph. See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.)
1717
- Kernels which trigger host to device syncs (such as .item()) will error
1818
- All input arguments to kernels are fixed to what they were recorded
1919
- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change

0 commit comments

Comments
 (0)