@@ -86,9 +86,9 @@ def __init__(
 
         self.use_cuda_graph = False
 
-        compilation_config = self.vllm_config.compilation_config
-        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
-            cudagraph_mode = compilation_config.cudagraph_mode
+        self.compilation_config = self.vllm_config.compilation_config
+        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
+            cudagraph_mode = self.compilation_config.cudagraph_mode
             if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                 CUDAGraphMode.PIECEWISE
             ):
@@ -103,13 +103,6 @@ def __init__(
             and not self.speculative_config.enforce_eager
         )
 
-        self.cudagraph_batch_sizes = (
-            (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
-            if self.use_cuda_graph
-            else []
-        )
-
-        self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=device
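Note on the `__init__` hunks above: the deleted block only ever sorted `cudagraph_capture_sizes` so that later code could read the last (largest) element. With `compilation_config` now cached on `self`, that bound is available directly as `max_cudagraph_capture_size`. A minimal sketch of the equivalence, using hypothetical standalone names rather than vLLM's actual config objects:

```python
# Hypothetical values standing in for CompilationConfig fields.
cudagraph_capture_sizes = [4, 1, 16, 8, 2]
max_cudagraph_capture_size = max(cudagraph_capture_sizes)

# The removed code kept a sorted copy just to read its last element:
cudagraph_batch_sizes = sorted(cudagraph_capture_sizes)
assert cudagraph_batch_sizes[-1] == max_cudagraph_capture_size
```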
@@ -276,7 +269,10 @@ def propose(
                 per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
         cudagraph_runtime_mode = CUDAGraphMode.NONE
-        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -366,7 +362,10 @@ def propose(
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
-        if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and batch_size <= self.compilation_config.max_cudagraph_capture_size
+        ):
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -777,7 +776,10 @@ def propose_tree(
         self.positions[:num_tokens] = tree_positions.view(-1)
         self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
 
-        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -1114,7 +1116,10 @@ def dummy_run(
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
-        if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            cudagraphs_enabled
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
 
         with set_forward_context(
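The `propose`, `propose_tree`, and `dummy_run` hunks all replace the same guard: `x <= self.cudagraph_batch_sizes[-1]` becomes `x <= self.compilation_config.max_cudagraph_capture_size`, and an in-range token count is then padded up to a captured size. A self-contained sketch of that gate, assuming `pad_for_cudagraph` rounds up to the nearest capture size (`select_runtime_mode` and the module-level names here are illustrative stand-ins, not vLLM's implementation):

```python
import bisect

# Hypothetical stand-ins for fields read from vllm_config in the real code.
cudagraph_capture_sizes = sorted([1, 2, 4, 8, 16, 32, 64])
max_cudagraph_capture_size = cudagraph_capture_sizes[-1]


def pad_for_cudagraph(num_tokens: int) -> int:
    # Round up to the smallest captured size that fits num_tokens.
    idx = bisect.bisect_left(cudagraph_capture_sizes, num_tokens)
    return cudagraph_capture_sizes[idx]


def select_runtime_mode(num_tokens: int, use_cuda_graph: bool) -> tuple[str, int]:
    # Mirrors the new guard: replay a piecewise CUDA graph only when the
    # token count fits under the largest captured size; otherwise run eager.
    if use_cuda_graph and num_tokens <= max_cudagraph_capture_size:
        return "PIECEWISE", pad_for_cudagraph(num_tokens)
    return "NONE", num_tokens


assert select_runtime_mode(5, True) == ("PIECEWISE", 8)   # padded 5 -> 8
assert select_runtime_mode(100, True) == ("NONE", 100)    # exceeds 64: eager
```

Padding is needed because a captured graph replays fixed shapes: a 5-token batch reuses the graph captured at size 8, and the extra slots are computed and ignored.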