 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from typing import Optional
 from unittest import mock

 import pytest

...

 eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


-def _create_proposer(method: str, k: int) -> EagleProposer:
+def _create_proposer(
+    method: str,
+    num_speculative_tokens: int,
+    speculative_token_tree: Optional[list[tuple[int]]] = None,
+) -> EagleProposer:
     model_config = ModelConfig(model=model_dir,
                                runner="generate",
                                max_model_len=100)

     # Choose model directory based on method
     draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir

+    spec_token_tree_str = None
+    if speculative_token_tree is not None:
+        assert num_speculative_tokens == len(speculative_token_tree)
+        spec_token_tree_str = str(speculative_token_tree)
+
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
         model=draft_model_dir,
         method=method,
-        num_speculative_tokens=k,
+        num_speculative_tokens=num_speculative_tokens,
+        speculative_token_tree=spec_token_tree_str,
     )

     vllm_config = VllmConfig(
@@ -189,7 +200,7 @@ class _TargetModelStub(LlamaForCausalLM):
     target_model.lm_head = mock.MagicMock()

     # Create proposer using the helper function
-    proposer = _create_proposer(method, k=8)
+    proposer = _create_proposer(method, num_speculative_tokens=8)

     # Call the method under test
     proposer.load_model(target_model)
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                     "multi-token eagle spec decode on current platform")

+    if attn_backend == "TREE_ATTN":
+        pytest.skip("TREE_ATTN is tested separately in test_propose_tree "
+                    "because it requires special input mocking.")
+
     if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

@@ -378,3 +393,142 @@ def create_deterministic_logits(token_ids):

     # Verify all tokens match our expectations
     assert torch.equal(result, expected_tokens)
+
+
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_propose_tree(spec_token_tree):
+    # Get GPU device.
+    device = torch.device(current_platform.device_type)
+
+    # Setup test parameters.
+    batch_size = 2
+    seq_len_1 = 5
+    seq_len_2 = 3
+    total_tokens = seq_len_1 + seq_len_2
+    vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
+    num_speculative_tokens = len(spec_token_tree)
+
+    # Create proposer first so we can use its actual hidden_size.
+    proposer = _create_proposer("eagle",
+                                num_speculative_tokens,
+                                speculative_token_tree=spec_token_tree)
+    # Get the hidden_size from the proposer to ensure consistency.
+    hidden_size = proposer.hidden_size
+
+    # Helper to create deterministic logits that will produce specific tokens.
+    def create_deterministic_logits(token_ids, k: int):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        for i, token_id in enumerate(token_ids):
+            # Assign decreasing values to the k consecutive tokens.
+            for j in range(k):
+                logits[i, token_id + j] = 100.0 - j
+        return logits
+
+    # Mock a model that returns deterministic logits.
+    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
+
+    # Skip loading the model and replace it with a mock that returns
+    # deterministic outputs.
+    model_mock = mock.MagicMock()
+
+    # Mock the model forward calls.
+    forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
+                        torch.zeros(total_tokens, hidden_size, device=device))]
+    for cu_num_drafts in proposer.cu_drafts_per_level:
+        h_logits = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        h_states = torch.zeros(batch_size * cu_num_drafts,
+                               hidden_size,
+                               device=device)
+        forward_returns.append((h_logits, h_states))
+    model_mock.side_effect = forward_returns
+
+    # Mock the compute_logits calls.
+    cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
+                                        dtype=torch.int32,
+                                        device=device)
+    logits_returns = []
+    for level, num_children in enumerate(proposer.child_drafts_per_level):
+        token_ids = base_token_ids + cu_num_drafts_tensor[level]
+        level_num_drafts = cu_num_drafts_tensor[
+            level + 1] - cu_num_drafts_tensor[level]
+        level_logits = []
+        for i in range(level_num_drafts // num_children):
+            level_logits.append(
+                create_deterministic_logits(token_ids + i * num_children,
+                                            num_children))
+        logits_returns.append(torch.stack(level_logits, dim=1))
+    model_mock.compute_logits.side_effect = logits_returns
+
+    # Assign the mock to the proposer
+    proposer.model = model_mock
+
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Get the tree attention metadata builder.
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    # Mock runner for attention metadata building.
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+
+    # Setup inputs for the proposer.
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_len_1, device=device),
+        torch.arange(seq_len_2, device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+    sampling_metadata = mock.MagicMock()
+
+    # Propose draft tokens.
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+    assert result.shape == (batch_size, num_speculative_tokens)
+
+    # The tokens are expected to be consecutive integers starting
+    # from the base token IDs.
+    expected_tokens = base_token_ids[:, None] + torch.arange(
+        num_speculative_tokens, dtype=torch.int64, device=device)
+
+    # Verify that the draft tokens match our expectations.
+    assert torch.equal(result, expected_tokens)
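
Note (not part of the diff above): a minimal sketch of how the parametrized spec_token_tree values can be read, assuming each tuple encodes a draft node as its path of child indices from the root, which is consistent with the "Chain", "Parallel", and "Tree" cases above. Grouping paths by length gives the number of drafts per tree level, and the total node count is what _create_proposer asserts against num_speculative_tokens.

from collections import Counter

# Hypothetical standalone example; "tree" mirrors the "Tree" parametrize case.
tree = [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]

# A node's level is the length of its path from the root.
drafts_per_level = Counter(len(path) for path in tree)
print(dict(drafts_per_level))  # {1: 3, 2: 6} -> 3 first-level drafts, 6 second-level drafts

# One draft token is proposed per node, so the tree size must equal
# num_speculative_tokens (the assertion made inside _create_proposer).
assert sum(drafts_per_level.values()) == len(tree) == 9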