
Commit e8224f3

[V1][Spec Decode] Eagle Model loading (#16035)

Signed-off-by: LiuXiaoxuanPKU <[email protected]>

Parent: 9665313

File tree: 9 files changed, +253 -30 lines

examples/offline_inference/eagle.py

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ def main():
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config={
+            "method": "eagle",
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,

tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -374,6 +374,10 @@ def check_available_online(
     "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
                                         speculative_model="luccafong/deepseek_mtp_draft_random",  # noqa: E501
                                         trust_remote_code=True),
+    "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
+                                             trust_remote_code=True,
+                                             speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
+                                             tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"),  # noqa: E501
 }

 _TRANSFORMERS_MODELS = {

tests/v1/e2e/test_ngram_spec_decode.py renamed to tests/v1/e2e/test_spec_decode.py

Lines changed: 49 additions & 0 deletions

@@ -53,6 +53,11 @@ def model_name():
     return "meta-llama/Meta-Llama-3-8B-Instruct"


+@pytest.fixture
+def eagle_model_name():
+    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
+
+
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],

@@ -95,3 +100,47 @@ def test_ngram_correctness(
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.7 * len(ref_outputs))
     del spec_llm
+
+
+def test_eagle_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    model_name: str,
+    eagle_model_name: str,
+):
+    '''
+    Compare the outputs of the original LLM and the speculative LLM.
+    They should be the same when using EAGLE speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(
+            model=model_name,
+            speculative_config={
+                "method": "eagle",
+                "model": eagle_model_name,
+                "num_speculative_tokens": 3,
+            },
+            max_model_len=1024,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 70% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.7 * len(ref_outputs))
+        del spec_llm

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 2 deletions

@@ -414,7 +414,7 @@ def _hpu_weights_iterator(iterator: Generator):
             return ((source.prefix + name, tensor)
                     for (name, tensor) in weights_iterator)

-    def _get_all_weights(
+    def get_all_weights(
         self,
         model_config: ModelConfig,
         model: nn.Module,

@@ -453,7 +453,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:

             weights_to_load = {name for name, _ in model.named_parameters()}
             loaded_weights = model.load_weights(
-                self._get_all_weights(model_config, model))
+                self.get_all_weights(model_config, model))
             self.counter_after_loading_weights = time.perf_counter()
             logger.info(
                 "Loading weights took %.2f seconds",
vllm/model_executor/models/llama_eagle.py (new file)

Lines changed: 151 additions & 0 deletions

@@ -0,0 +1,151 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Iterable, Set, Tuple
+
+import torch
+import torch.nn as nn
+from transformers import LlamaConfig
+
+from vllm.config import ModelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama import (LlamaDecoderLayer,
+                                              LlamaForCausalLM)
+
+from .utils import AutoWeightsLoader, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        disable_input_layernorm: bool,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, prefix=prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if disable_input_layernorm:
+            del self.input_layernorm
+            self.input_layernorm = nn.Identity()
+
+
+class LlamaModel(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        model_config: ModelConfig,
+        start_layer_id: int = 0,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = model_config.hf_config
+        self.vocab_size = self.config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(
+                self.config,
+                i == 0,
+                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+            ) for i in range(self.config.num_hidden_layers)
+        ])
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        input_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.fc(
+            torch.cat((input_embeds, hidden_states), dim=-1))
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        return hidden_states + residual
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class EagleLlamaForCausalLM(LlamaForCausalLM):
+
+    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+        nn.Module.__init__(self)
+        self.config = model_config.hf_config
+        self.model = LlamaModel(model_config=model_config,
+                                start_layer_id=start_layer_id,
+                                prefix="model")
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.config.vocab_size,
+                                                scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions, hidden_states)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+
+        model_weights = {}
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            model_weights[name] = loaded_weight
+
+        loader.load_weights(model_weights.items())
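
The load_weights override above exists because EAGLE checkpoints keep the decoder weights at the top level rather than under a model. prefix. A small illustration of the remapping rule (the checkpoint keys and tensor shapes below are made-up examples, not keys guaranteed to appear in any particular checkpoint):

import torch

# Hypothetical checkpoint entries, for illustration only.
checkpoint = {
    "fc.weight": torch.zeros(8, 16),
    "layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "lm_head.weight": torch.zeros(32, 8),
}

remapped = {}
for name, tensor in checkpoint.items():
    # Same rule as EagleLlamaForCausalLM.load_weights: everything except
    # lm_head moves under the "model." prefix so it lands on self.model.
    if "lm_head" not in name:
        name = "model." + name
    remapped[name] = tensor

assert set(remapped) == {
    "model.fc.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",
}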

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -206,6 +206,7 @@

 _SPECULATIVE_DECODING_MODELS = {
     "EAGLEModel": ("eagle", "EAGLE"),
+    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),

vllm/transformers_utils/configs/eagle.py

Lines changed: 4 additions & 1 deletion

@@ -5,6 +5,7 @@

 from transformers import AutoConfig, PretrainedConfig

+import vllm.envs as envs
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config


@@ -41,8 +42,10 @@ def __init__(self,
         self.truncated_vocab_size = self.model.vocab_size if \
             truncated_vocab_size is None else truncated_vocab_size

-        if "architectures" not in kwargs:
+        if not envs.VLLM_USE_V1:
             kwargs["architectures"] = ["EAGLEModel"]
+        else:
+            kwargs["architectures"] = ["EagleLlamaForCausalLM"]

         super().__init__(**kwargs)
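
The effect of the new branch, expressed as a standalone sketch (select_eagle_architectures is a hypothetical helper that mirrors the logic; it is not a function that exists in vllm):

def select_eagle_architectures(vllm_use_v1: bool) -> list[str]:
    # V0 keeps the existing EAGLEModel wrapper; V1 resolves to the
    # Llama-specific draft head registered by this commit.
    if not vllm_use_v1:
        return ["EAGLEModel"]
    return ["EagleLlamaForCausalLM"]


assert select_eagle_architectures(False) == ["EAGLEModel"]
assert select_eagle_architectures(True) == ["EagleLlamaForCausalLM"]

Note that, unlike the old "architectures" not in kwargs guard, the architecture is now set unconditionally from VLLM_USE_V1, so any architectures value already present in kwargs is overwritten.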

vllm/v1/spec_decode/eagle.py

Lines changed: 36 additions & 25 deletions

@@ -4,8 +4,11 @@
 import triton
 import triton.language as tl

-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader.loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata

@@ -21,8 +24,12 @@ def __init__(
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
         self.block_size = vllm_config.cache_config.block_size
-        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
-                                   device=device)
+        # We need +1 here because the arange is used to set query_start_loc,
+        # which has one more element than batch_size.
+        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
+                                   1,
+                                   device=device,
+                                   dtype=torch.int32)

     def propose(
         self,

@@ -54,7 +61,9 @@ def propose(
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         input_ids[last_token_indices] = next_token_ids

-        seq_lens = target_positions[last_token_indices] + 1
+        # FA requires seq_len to have dtype int32.
+        seq_lens = (target_positions[last_token_indices] + 1).int()
+
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         max_seq_len = seq_lens.max().item()
         max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()

@@ -98,7 +107,7 @@ def propose(
             hidden_states = sample_hidden_states
             attn_metadata.num_actual_tokens = batch_size
             attn_metadata.max_query_len = 1
-            attn_metadata.query_start_loc = self.arange[:batch_size]
+            attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]

@@ -176,26 +185,28 @@ def prepare_inputs(
         return cu_num_tokens, token_indices

     def load_model(self, target_model: nn.Module) -> None:
-        self.model = DummyEagleModel()
-        self.model.get_input_embeddings = target_model.get_input_embeddings
-        self.model.compute_logits = target_model.compute_logits
-
-
-# FIXME(woosuk): This is a dummy model for testing.
-# Remove this once we have a real model.
-class DummyEagleModel(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        positions: torch.Tensor,
-    ) -> torch.Tensor:
-        input_embeddings = self.get_input_embeddings(input_ids)
-        return hidden_states + input_embeddings  # Dummy return.
+        loader = get_model_loader(self.vllm_config.load_config)
+        target_layer_num = self.vllm_config.model_config.get_num_layers(
+            self.vllm_config.parallel_config)
+
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
+        # FIXME(lily): This does not handle distributed inference.
+        target_device = self.vllm_config.device_config.device
+        # We need to set the vllm_config here to register attention
+        # layers in the forward context.
+        with set_default_torch_dtype(
+                draft_model_config.dtype), set_current_vllm_config(
+                    self.vllm_config):
+            self.model = EagleLlamaForCausalLM(
+                model_config=draft_model_config,
+                start_layer_id=target_layer_num).to(target_device)
+
+        self.model.load_weights(
+            loader.get_all_weights(
+                self.vllm_config.speculative_config.draft_model_config,
+                self.model))
+        self.model.lm_head = target_model.lm_head


 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
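
The +1 in the arange and the batch_size + 1 slice go together: query_start_loc is a cumulative-offset tensor, so it always has one more entry than the number of sequences. A tiny numeric sketch of the relationship (the sizes are illustrative):

import torch

max_num_seqs = 8
arange = torch.arange(max_num_seqs + 1, dtype=torch.int32)

# For a batch of 3 single-token queries, FlashAttention-style metadata
# expects cumulative query offsets with batch_size + 1 entries.
batch_size = 3
query_start_loc = arange[:batch_size + 1]
assert query_start_loc.tolist() == [0, 1, 2, 3]
# Slicing only to :batch_size (the old code) would drop the final offset
# and leave the last sequence's query range undefined.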

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 2 deletions

@@ -1191,9 +1191,12 @@ def execute_model(

         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
+            # We need to slice token_ids, positions, and hidden_states
+            # because the eagle head does not use cuda graph and should
+            # not include padding.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
-            target_positions = positions
-            target_hidden_states = hidden_states
+            target_positions = positions[:num_scheduled_tokens]
+            target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
             cu_num_tokens = attn_metadata.query_start_loc
         else:
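
A small sketch of the shape mismatch this hunk fixes (the sizes are made up; the point is that CUDA-graph padding must not leak into the EAGLE proposer's inputs, since the draft head runs eagerly on exactly num_scheduled_tokens tokens):

import torch

num_scheduled_tokens = 5   # tokens actually scheduled this step
padded_num_tokens = 8      # size the CUDA graph was captured with
hidden_size = 16

positions = torch.arange(padded_num_tokens)
hidden_states = torch.randn(padded_num_tokens, hidden_size)

# Matching target_token_ids, which was already sliced before this change.
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
assert target_positions.shape[0] == num_scheduled_tokens
assert target_hidden_states.shape[0] == num_scheduled_tokens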
