Commit d94e302

[V1] Add tree drafting tests for eagle spec decoding (#22705)
Signed-off-by: Giancarlo Delfin <[email protected]>
1 parent 3f52738 commit d94e302

File tree: 4 files changed, +178 −55 lines


tests/v1/spec_decode/test_eagle.py

Lines changed: 157 additions & 3 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -23,20 +24,30 @@
 eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 
 
-def _create_proposer(method: str, k: int) -> EagleProposer:
+def _create_proposer(
+    method: str,
+    num_speculative_tokens: int,
+    speculative_token_tree: Optional[list[tuple[int]]] = None,
+) -> EagleProposer:
     model_config = ModelConfig(model=model_dir,
                                runner="generate",
                                max_model_len=100)
 
     # Choose model directory based on method
     draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
 
+    spec_token_tree_str = None
+    if speculative_token_tree is not None:
+        assert num_speculative_tokens == len(speculative_token_tree)
+        spec_token_tree_str = str(speculative_token_tree)
+
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
         model=draft_model_dir,
         method=method,
-        num_speculative_tokens=k,
+        num_speculative_tokens=num_speculative_tokens,
+        speculative_token_tree=spec_token_tree_str,
     )
 
     vllm_config = VllmConfig(
@@ -189,7 +200,7 @@ class _TargetModelStub(LlamaForCausalLM):
     target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
-    proposer = _create_proposer(method, k=8)
+    proposer = _create_proposer(method, num_speculative_tokens=8)
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                     "multi-token eagle spec decode on current platform")
 
+    if attn_backend == "TREE_ATTN":
+        pytest.skip("TREE_ATTN is tested separately in test_propose_tree "
+                    "because it requires special input mocking.")
+
     if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
@@ -378,3 +393,142 @@ def create_deterministic_logits(token_ids):
 
     # Verify all tokens match our expectations
     assert torch.equal(result, expected_tokens)
+
+
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_propose_tree(spec_token_tree):
+    # Get GPU device.
+    device = torch.device(current_platform.device_type)
+
+    # Setup test parameters.
+    batch_size = 2
+    seq_len_1 = 5
+    seq_len_2 = 3
+    total_tokens = seq_len_1 + seq_len_2
+    vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
+    num_speculative_tokens = len(spec_token_tree)
+
+    # Create proposer first so we can use its actual hidden_size.
+    proposer = _create_proposer("eagle",
+                                num_speculative_tokens,
+                                speculative_token_tree=spec_token_tree)
+    # Get the hidden_size from the proposer to ensure consistency.
+    hidden_size = proposer.hidden_size
+
+    # Helper to create deterministic logits that will produce specific tokens
+    def create_deterministic_logits(token_ids, k: int):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        for i, token_id in enumerate(token_ids):
+            # Assign decreasing values to the k consecutive tokens.
+            for j in range(k):
+                logits[i, token_id + j] = 100.0 - j
+        return logits
+
+    # Mock a model that returns deterministic logits.
+    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
+
+    # Skip loading the model and replace it with a mock that returns
+    # deterministic outputs.
+    model_mock = mock.MagicMock()
+
+    # Mock the model forward calls.
+    forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
+                        torch.zeros(total_tokens, hidden_size, device=device))]
+    for cu_num_drafts in proposer.cu_drafts_per_level:
+        h_logits = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        h_states = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        forward_returns.append((h_logits, h_states))
+    model_mock.side_effect = forward_returns
+
+    # Mock the compute_logits calls.
+    cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
+                                        dtype=torch.int32,
+                                        device=device)
+    logits_returns = []
+    for level, num_children in enumerate(proposer.child_drafts_per_level):
+        token_ids = base_token_ids + cu_num_drafts_tensor[level]
+        level_num_drafts = cu_num_drafts_tensor[
+            level + 1] - cu_num_drafts_tensor[level]
+        level_logits = []
+        for i in range(level_num_drafts // num_children):
+            level_logits.append(
+                create_deterministic_logits(token_ids + i * num_children,
+                                            num_children))
+        logits_returns.append(torch.stack(level_logits, dim=1))
+    model_mock.compute_logits.side_effect = logits_returns
+
+    # Assign the mock to the proposer
+    proposer.model = model_mock
+
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Get the tree attention metadata builder.
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    # Mock runner for attention metadata building.
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+
+    # Setup inputs for the proposer.
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_len_1, device=device),
+        torch.arange(seq_len_2, device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+    sampling_metadata = mock.MagicMock()
+
+    # Propose draft tokens.
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+    assert result.shape == (batch_size, num_speculative_tokens)
+
+    # The tokens are expected to be consecutive integers starting
+    # from the base token IDs.
+    expected_tokens = base_token_ids[:, None] + torch.arange(
+        num_speculative_tokens, dtype=torch.int64, device=device)
+
+    # Verify that the draft tokens match our expectations.
+    assert torch.equal(result, expected_tokens)
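Why the new test can assert exact draft tokens: the mocked `compute_logits` outputs are built by `create_deterministic_logits`, which places strictly decreasing large values at consecutive token IDs, so greedy or top-k selection must return `base_token_id, base_token_id + 1, ...`. The following is a minimal, CPU-only sketch of that trick (small vocab, illustrative values only, not part of the test file):

import torch

batch_size, vocab_size = 2, 100
base_token_ids = torch.tensor([42, 60])

def create_deterministic_logits(token_ids, k):
    # Same trick as in test_propose_tree: the j-th of k consecutive token IDs
    # gets logit 100.0 - j; every other token gets -100.0.
    logits = torch.full((batch_size, vocab_size), -100.0)
    for i, token_id in enumerate(token_ids):
        for j in range(k):
            logits[i, token_id + j] = 100.0 - j
    return logits

logits = create_deterministic_logits(base_token_ids, k=3)
# topk returns indices in descending logit order, i.e. consecutive token IDs
# starting at the base IDs; this is exactly the pattern the test asserts on.
assert torch.equal(logits.topk(3).indices,
                   base_token_ids[:, None] + torch.arange(3))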

tests/v1/spec_decode/test_max_len.py

Lines changed: 0 additions & 6 deletions

@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
                        num_speculative_tokens: int, attn_backend: str):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-
-        if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
-            # TREE_ATTN fails the test with multi-token spec decode
-            # TODO: Investigate why
-            pytest.skip("TREE_ATTN fails the test")
-
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN_VLLM_V1"

vllm/v1/attention/backends/tree_attn.py

Lines changed: 3 additions & 3 deletions

@@ -236,9 +236,9 @@ def build_for_drafting(
             # Use prefill for drafting at the root level.
             self.tree_attn_bias = torch.empty(0)
         else:
-            # Slice the tree attention bias for drafting.
-            query_len = common_attn_metadata.max_query_len
-            start, end = draft_index, draft_index + query_len
+            # Slice the tree attention bias for drafting. Exclude
+            # the root level.
+            start, end = 1, 1 + common_attn_metadata.max_query_len
             self.tree_attn_bias = self.tree_attn_bias[start:end,
                                                       start:end].contiguous()
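The effect of the tree_attn.py change: once drafting has started (draft_index > 0), the cached tree attention bias is always sliced from row/column 1, so the root level is excluded, instead of slicing from `draft_index`. A toy example of that slicing follows; the 4x4 bias values are made up purely to show the shapes and are not vLLM's actual mask:

import torch

neg_inf = float("-inf")
# Hypothetical precomputed bias: row/col 0 is the root, rows 1-3 are drafts.
tree_attn_bias = torch.tensor([
    [0.0, neg_inf, neg_inf, neg_inf],
    [0.0, 0.0,     neg_inf, neg_inf],
    [0.0, neg_inf, 0.0,     neg_inf],
    [0.0, 0.0,     neg_inf, 0.0],
])
max_query_len = 3                      # draft positions scored this step
start, end = 1, 1 + max_query_len      # exclude the root level
drafting_bias = tree_attn_bias[start:end, start:end].contiguous()
assert drafting_bias.shape == (max_query_len, max_query_len)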

vllm/v1/spec_decode/eagle.py

Lines changed: 18 additions & 43 deletions

@@ -113,13 +113,6 @@ def __init__(
                                              num_drafts_per_level[level])
             self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                                num_drafts_per_level[level - 1])
-        # Find the first level where the tree branches off into one or more
-        # children.
-        self.first_branching_level = None
-        for level in range(tree_depth):
-            if self.cu_drafts_per_level[level] > level + 1:
-                self.first_branching_level = level
-                break
         # Precompute draft position offsets in flattened tree.
         self.tree_draft_pos_offsets = torch.arange(
             1,
@@ -209,11 +202,10 @@ def propose(
         logits = self.model.compute_logits(sample_hidden_states, None)
         positions = target_positions[last_token_indices]
         hidden_states = hidden_states[last_token_indices]
-        if self.first_branching_level == 0:
-            # Branching has occurred at the root level. Draft using tree
-            # attention.
+
+        if isinstance(attn_metadata, TreeAttentionMetadata):
+            # Draft using tree attention.
             draft_token_ids_list = self.propose_tree(
-                tree_root_level=0,
                 batch_size=batch_size,
                 logits=logits,
                 positions=positions,
@@ -242,11 +234,10 @@ def propose(
                 (TritonAttentionMetadata, AiterFlashAttentionMetadata,
                  FlashAttentionMetadata))
         else:
-            # Currently, only FlashAttention and TreeAttention support
-            # multi-token eagle spec decode. This is because the code below
-            # makes assumptions about attn_metadata attributes available.
-            assert isinstance(attn_metadata,
-                              (FlashAttentionMetadata, TreeAttentionMetadata))
+            # Currently, only FlashAttention supports multi-token eagle spec
+            # decode. This is because the code below makes assumptions about
+            # attn_metadata attributes available.
+            assert isinstance(attn_metadata, FlashAttentionMetadata)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
@@ -259,7 +250,7 @@ def propose(
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for token_index in range(self.num_speculative_tokens - 1):
+        for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
@@ -327,21 +318,6 @@ def propose(
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
-
-            if self.first_branching_level == token_index + 1:
-                # Branching has occurred. The remaining tokens are drafted
-                # using tree attention.
-                draft_token_ids_list += self.propose_tree(
-                    tree_root_level=token_index + 1,
-                    batch_size=batch_size,
-                    logits=logits,
-                    positions=positions,
-                    hidden_states=hidden_states,
-                    common_attn_metadata=common_attn_metadata,
-                )
-                # [batch_size, num_tree_tokens]
-                return torch.cat(draft_token_ids_list, dim=1)
-
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
 
@@ -351,7 +327,6 @@ def propose(
 
     def propose_tree(
         self,
-        tree_root_level: int,
         batch_size: int,
         # [num_tokens, vocab_size]
         logits: torch.Tensor,
@@ -366,10 +341,10 @@ def propose_tree(
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
 
-        total_num_drafts = self.cu_drafts_per_level[tree_root_level]
+        total_num_drafts = self.cu_drafts_per_level[0]
         level_num_drafts = total_num_drafts
         # Sample a draft token for each child at the tree root level.
-        num_children = self.child_drafts_per_level[tree_root_level]
+        num_children = self.child_drafts_per_level[0]
         if num_children == 1:
             draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
         else:
@@ -393,22 +368,23 @@ def propose_tree(
             positions.view(batch_size, -1) +
             self.tree_draft_pos_offsets[:batch_size, :])
         tree_depth = len(self.cu_drafts_per_level)
-        for level in range(tree_root_level, tree_depth - 1):
+        for level in range(tree_depth - 1):
             # Get draft positions for RoPE.
             draft_positions = positions + (level + 1)
             exceeds_max_model_len = (positions +
                                      total_num_drafts) >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_draft_positions = torch.where(
+            draft_positions = torch.where(
                 exceeds_max_model_len,
                 0,
                 draft_positions,
-            )
+            ).view(batch_size, -1)
+
             if level_num_drafts > 1:
                 # Repeat the positions for each draft at this level.
-                draft_positions = clamped_draft_positions.repeat_interleave(
-                    level_num_drafts).reshape(batch_size, -1)
+                draft_positions = draft_positions.repeat_interleave(
+                    level_num_drafts, dim=1)
 
             if num_children > 1:
                 # Repeat draft hidden states for each child.
@@ -425,7 +401,7 @@ def propose_tree(
 
             # Build new attention metadata for the next level of drafts.
             # This is necessary to support tree attention.
-            query_len = total_num_drafts - tree_root_level
+            query_len = total_num_drafts
             common_attn_metadata = replace(
                 common_attn_metadata,
                 query_start_loc=query_len * self.arange[:batch_size + 1],
@@ -435,7 +411,7 @@ def propose_tree(
             )
             attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                 common_attn_metadata=common_attn_metadata,
-                draft_index=tree_root_level + 1,
+                draft_index=level + 1,
             )
 
             # Apply new attention metadata to all layers.
@@ -516,7 +492,6 @@ def propose_tree(
             level_num_drafts = self.cu_drafts_per_level[level +
                                                         1] - total_num_drafts
             total_num_drafts = self.cu_drafts_per_level[level + 1]
-
         return draft_token_ids_list
 
     def prepare_inputs(
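With `first_branching_level` removed, tree drafting is entered solely when the attention metadata is `TreeAttentionMetadata`, and `propose_tree` always starts at level 0, driven by `cu_drafts_per_level` and `child_drafts_per_level`. The helper below is an illustrative reconstruction, not vLLM's code, of how those two per-level arrays could follow from a `speculative_token_tree` like the ones parametrized in the new test; the level-0 handling of `child_drafts_per_level` is an assumption inferred from the diff:

from collections import Counter
from itertools import accumulate

def summarize_tree(spec_token_tree):
    # Each tuple is a root-to-node path, so its length is the node's depth.
    depth_counts = Counter(len(path) for path in spec_token_tree)
    num_drafts_per_level = [depth_counts[d]
                            for d in range(1, max(depth_counts) + 1)]
    # Cumulative drafts per level, and children drafted per node at each level.
    cu_drafts_per_level = list(accumulate(num_drafts_per_level))
    child_drafts_per_level = [num_drafts_per_level[0]] + [
        num_drafts_per_level[i] // num_drafts_per_level[i - 1]
        for i in range(1, len(num_drafts_per_level))
    ]
    return cu_drafts_per_level, child_drafts_per_level

# The "Tree" case from test_propose_tree: 3 root children, each with 2 children.
tree = [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
assert summarize_tree(tree) == ([3, 9], [3, 2])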
