docs/design/torch_compile.md (5 additions, 4 deletions)
@@ -4,6 +4,9 @@ In vLLM's V1 architecture, `torch.compile` is enabled by default and is a critical …

  Throughout the example, we will run a common Llama model and turn on debug-level logging to show all the details. The command to be used is `VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B`.

+ !!! note
+     For more information and the latest progress of the `torch.compile` integration, see this [Blog Post](https://blog.vllm.ai/2025/08/20/torch-compile.html).
+
  ## Compilation Cache

  In the very verbose logs, we can see:
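For reference, a minimal way to reproduce the debug run referenced in this hunk is sketched below; the model name and log level come from the doc text above, and any additional serving flags are intentionally omitted.

```bash
# Serve the example model with debug logging enabled so that the
# torch.compile and compilation-cache messages discussed in this
# document appear in the server output.
VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B
```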
@@ -133,7 +136,7 @@ Unfortunately, because auto-tuning takes quite a long time (from seconds to minutes), …

  ## Cudagraph Capture

- vLLM's V1 architecture uses piecewise cudagraph. The full computation graph is split as mentioned above, and we only capture the cudagraph for the pieces of the graph between attention operations (including the first piece before any attention operation and the last piece after all attention operations). This is based on a common observation: computation between attention operations is usually token-wise and easy to handle with cudagraph, while the attention operation itself is non-trivial to make cudagraph-compatible. Thus, by running the attention operation in eager mode and the rest of the operations in cudagraph, we keep the flexibility of the attention operation.
+ vLLM's V1 architecture uses piecewise cudagraph that aligns with the piecewise compilation. The full computation graph is split as mentioned above, and we only capture the cudagraph for the pieces of the graph between attention operations (including the first piece before any attention operation and the last piece after all attention operations). This is based on a common observation: computation between attention operations is usually token-wise and easy to handle with cudagraph, while the attention operation itself is non-trivial to make cudagraph-compatible. Thus, by running the attention operation in eager mode and the rest of the operations in cudagraph, we keep the flexibility of the attention operation.

  The piecewise cudagraph also has fine-grained memory management. The purpose is to exclude only the attention kernel from the cudagraph, while keeping all the remaining modules and the memory allocation operations in the cudagraph. This is why the attention operation in V1 takes the output tensor as an input of the attention.
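To make the capture step concrete, the sketch below shows how cudagraph capture is typically narrowed to a few batch sizes from the command line. The `cudagraph_capture_sizes` key is an assumption based on the surrounding doc text about capturing "the specified sizes" and may differ between vLLM versions.

```bash
# Hypothetical example: restrict piecewise cudagraph capture to a small
# set of batch sizes so that capture time and memory stay bounded.
# The exact config key may vary across vLLM versions.
vllm serve meta-llama/Llama-3.2-1B \
  --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'
```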
@@ -150,6 +153,4 @@ Then it will only capture cudagraph for the specified sizes. It can be useful to …

  ### Full Cudagraph capture

- It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases, such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`.
-
- Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
+ It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases, such as decode speed for smaller models or MoEs. See [CUDA Graphs](cuda_graphs.md) for more details.
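For completeness, a sketch of enabling full cudagraph capture using the flag shown in the removed lines above; newer versions may expose this option differently, so check [CUDA Graphs](cuda_graphs.md) for the current form.

```bash
# Also include the attention kernels in the captured cudagraph.
# Requires a cudagraph-compatible attention backend; see cuda_graphs.md
# for the supported backends and current configuration.
vllm serve meta-llama/Llama-3.2-1B \
  --compilation-config '{"full_cuda_graph": true}'
```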