Commit 30d6466

[BugFix] Fix Eagle IndexError: list index out of range for even num_speculative_tokens (#29102)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent: e9af6ba
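In brief: with an even num_speculative_tokens, a uniform decode step needs num_speculative_tokens + 1 tokens per request, and rounding the configured cudagraph capture sizes to a multiple of that value could leave an empty list; the old code only logged a warning, so a later sizes[-1] lookup raised the IndexError named in the title. A minimal sketch of that failure mode (toy values, not vLLM's actual call path):

# Minimal sketch of the failure mode (toy values, not vLLM's call path).
num_speculative_tokens = 4                # even, so multiple_of below is odd
multiple_of = num_speculative_tokens + 1  # tokens per request in one uniform decode step

capture_sizes = [4]                       # e.g. the test config in tests/conftest.py
rounded = sorted(s - s % multiple_of for s in capture_sizes if s >= multiple_of)

print(rounded)   # [] -- every configured size rounds away
# rounded[-1]    # IndexError: list index out of range (the crash this commit fixes)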

File tree (3 files changed: +37 additions, -20 deletions):

tests/conftest.py
vllm/config/compilation.py
vllm/v1/spec_decode/eagle.py


tests/conftest.py

Lines changed: 8 additions & 0 deletions

@@ -748,6 +748,14 @@ def __init__(
         # being captured which can trigger edge cases that we don't handle yet.
         kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
 
+        # Make sure we have at least one cudagraph large enough for a single decode.
+        if (speculative_config := kwargs.get("speculative_config")) and (
+            num_speculative_tokens := speculative_config["num_speculative_tokens"]
+        ):
+            kwargs["compilation_config"]["cudagraph_capture_sizes"].append(
+                num_speculative_tokens + 1
+            )
+
         with init_ctx:
             self.llm = LLM(
                 model=model_name,
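Taken on its own, the new conftest logic just appends one extra capture size. A standalone rerun of it for illustration (this kwargs dict is a stand-in for the test fixture's arguments):

# Standalone rerun of the conftest change (kwargs here is a stand-in).
kwargs = {
    "compilation_config": {"cudagraph_capture_sizes": [4]},
    "speculative_config": {"num_speculative_tokens": 4},
}

if (speculative_config := kwargs.get("speculative_config")) and (
    num_speculative_tokens := speculative_config["num_speculative_tokens"]
):
    kwargs["compilation_config"]["cudagraph_capture_sizes"].append(
        num_speculative_tokens + 1
    )

print(kwargs["compilation_config"]["cudagraph_capture_sizes"])  # [4, 5]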

vllm/config/compilation.py

Lines changed: 10 additions & 6 deletions

@@ -950,14 +950,18 @@ def adjust_cudagraph_sizes_for_spec_decode(
             )
         )
 
+        if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size:
+            # if multiple_of itself is a valid size but every configured size
+            # would round down to zero, use it
+            rounded_sizes = [multiple_of]
+
         if len(rounded_sizes) == 0:
-            logger.warning(
-                "No valid cudagraph sizes after rounding to multiple of "
-                " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
-                " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
-                multiple_of,
+            raise ValueError(
+                f"No valid cudagraph sizes after rounding to multiple of {multiple_of} "
+                f"(num_speculative_tokens + 1, or tp if sequence parallelism is "
+                f"enabled); please adjust num_speculative_tokens "
+                f"({uniform_decode_query_len - 1}), max_cudagraph_capture_size "
+                f"({self.max_cudagraph_capture_size}), or cudagraph_capture_sizes "
+                f"({self.cudagraph_capture_sizes})"
             )
-            return
 
         self.max_cudagraph_capture_size = rounded_sizes[-1]
         self.cudagraph_capture_sizes = rounded_sizes
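The adjusted logic, condensed into a free function for illustration (a simplified rounding rule; the real method lives on vLLM's CompilationConfig and also accounts for sequence parallelism):

# Sketch of the adjusted rounding logic, under simplified assumptions.
def adjust_sizes_for_spec_decode(
    capture_sizes: list[int], multiple_of: int, max_capture_size: int
) -> list[int]:
    # Round each configured size down to a multiple of multiple_of,
    # dropping anything that rounds to zero.
    rounded_sizes = sorted(
        size - size % multiple_of
        for size in capture_sizes
        if size - size % multiple_of > 0
    )
    # New fallback: if everything rounded away but multiple_of itself fits,
    # keep one size large enough for a single uniform decode step.
    if len(rounded_sizes) == 0 and multiple_of <= max_capture_size:
        rounded_sizes = [multiple_of]
    if len(rounded_sizes) == 0:
        # Now a hard error instead of a warning, so the misconfiguration
        # surfaces at config time rather than as an IndexError mid-run.
        raise ValueError(
            f"No valid cudagraph sizes after rounding to multiple of {multiple_of}"
        )
    return rounded_sizes

print(adjust_sizes_for_spec_decode([4], 5, 512))  # [5]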

vllm/v1/spec_decode/eagle.py

Lines changed: 19 additions & 14 deletions

@@ -86,9 +86,9 @@ def __init__(
 
         self.use_cuda_graph = False
 
-        compilation_config = self.vllm_config.compilation_config
-        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
-            cudagraph_mode = compilation_config.cudagraph_mode
+        self.compilation_config = self.vllm_config.compilation_config
+        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
+            cudagraph_mode = self.compilation_config.cudagraph_mode
             if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                 CUDAGraphMode.PIECEWISE
             ):
@@ -103,13 +103,6 @@ def __init__(
             and not self.speculative_config.enforce_eager
         )
 
-        self.cudagraph_batch_sizes = (
-            (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
-            if self.use_cuda_graph
-            else []
-        )
-
-        self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=device
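With the drafter's own cudagraph_batch_sizes list removed above, each gate in the hunks below compares against the single integer max_cudagraph_capture_size instead of indexing a possibly-empty list. A toy before/after of the two gates (values are made up):

# Toy before/after of the gating change (made-up values).
cudagraph_batch_sizes: list[int] = []  # the old per-drafter list could be empty
max_cudagraph_capture_size = 5         # the config-level bound the fix relies on
num_tokens = 3

# Old gate -- raises on an empty list:
#   if num_tokens <= cudagraph_batch_sizes[-1]:  # IndexError: list index out of range

# New gate -- a plain integer comparison, safe regardless of the list's contents:
if num_tokens <= max_cudagraph_capture_size:
    print("padded run can replay a captured cudagraph")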
@@ -276,7 +269,10 @@ def propose(
             per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
         cudagraph_runtime_mode = CUDAGraphMode.NONE
-        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -366,7 +362,10 @@ def propose(
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
-        if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and batch_size <= self.compilation_config.max_cudagraph_capture_size
+        ):
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -777,7 +776,10 @@ def propose_tree(
         self.positions[:num_tokens] = tree_positions.view(-1)
         self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
 
-        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            self.use_cuda_graph
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
@@ -1114,7 +1116,10 @@ def dummy_run(
     ) -> None:
         # Determine if CUDA graphs should be used for this run.
         cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
-        if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        if (
+            cudagraphs_enabled
+            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+        ):
             num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
 
         with set_forward_context(
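Once a gate passes, every call site pads with self.vllm_config.pad_for_cudagraph(...). A hedged sketch of what such padding has to do (this helper is an illustration, not vLLM's implementation): round the token count up to the smallest captured size that fits, so the run can replay an existing graph.

# Illustration only -- not vLLM's pad_for_cudagraph implementation.
def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    """Round num_tokens up to the smallest captured size that fits."""
    return min(size for size in capture_sizes if size >= num_tokens)

print(pad_for_cudagraph(3, [4, 5]))  # 4 -- replays the graph captured at size 4
print(pad_for_cudagraph(5, [4, 5]))  # 5 -- exactly the size appended by conftest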
