
Commit 7cd17e2

[Model][V1] Support Ernie MTP (#22169)
Signed-off-by: zhouchong <[email protected]>
Co-authored-by: zhouchong <[email protected]>
1 parent 50df09f commit 7cd17e2

File tree

tests/models/registry.py
vllm/config/__init__.py
vllm/model_executor/models/ernie_mtp.py
vllm/model_executor/models/registry.py
vllm/v1/spec_decode/eagle.py
vllm/worker/worker.py

6 files changed (+320, -7 lines)
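Not part of the commit, but for orientation: a minimal usage sketch of what this change enables, assuming the speculative_config dict interface of recent vLLM releases and the checkpoint referenced in the test registry below; exact keys and defaults may differ between versions.

    from vllm import LLM, SamplingParams

    # Hypothetical invocation: the target ERNIE checkpoint supplies its own MTP
    # draft weights, so no separate draft model is passed. num_speculative_tokens
    # is 1 because Ernie MTP models ship a single MTP layer (see the warning
    # added in vllm/config/__init__.py below).
    llm = LLM(
        model="baidu/ERNIE-4.5-21B-A3B-PT",
        trust_remote_code=True,
        speculative_config={
            "method": "ernie_mtp",
            "num_speculative_tokens": 1,
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)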

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -556,6 +556,9 @@ def check_available_online(
                                        is_available_online=False,
                                        speculative_model="openbmb/MiniCPM-2B-sft-bf16",
                                        tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
+    "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
+                                     trust_remote_code=True,
+                                     speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"),
     "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5",
                                        speculative_model="zai-org/GLM-4.5",
                                        min_transformers_version="4.54",

vllm/config/__init__.py

Lines changed: 26 additions & 5 deletions
@@ -1463,7 +1463,8 @@ def get_layers_start_end_indices(
         from vllm.distributed.utils import get_pp_indices
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"
-                or self.hf_config.model_type == "glm4_moe_mtp"):
+                or self.hf_config.model_type == "glm4_moe_mtp"
+                or self.hf_config.model_type == "ernie_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:

@@ -1911,7 +1912,8 @@ def __post_init__(self):


 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
-                            "mlp_speculator", "draft_model", "deepseek_mtp"]
+                            "mlp_speculator", "draft_model", "deepseek_mtp",
+                            "ernie_mtp"]


 @config

@@ -2044,6 +2046,16 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "architectures": ["Glm4MoeMTPModel"]
             })

+        if hf_config.model_type == "ernie4_5_moe":
+            hf_config.model_type = "ernie_mtp"
+        if hf_config.model_type == "ernie_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["ErnieMTPModel"]
+            })
+            return hf_config
+
         return hf_config

     def __post_init__(self):

@@ -2062,8 +2074,8 @@ def __post_init__(self):
             if self.target_model_config and \
                 (self.target_model_config.hf_text_config.model_type \
                     == "deepseek_v3" or
-                    self.target_model_config.hf_text_config.model_type \
-                        == "mimo"):
+                    self.target_model_config.hf_text_config.model_type in
+                    ("mimo","ernie4_5_moe")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
             elif self.method in ("ngram", "[ngram]"):

@@ -2161,6 +2173,15 @@ def __post_init__(self):
                         "one layer. Might need some code changes " \
                         "to support multiple layers."
                     )
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "ernie_mtp"):
+                self.method = "ernie_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "All Ernie MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(

@@ -2376,7 +2397,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens

     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")

     def __repr__(self) -> str:
         method = self.method
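A quick way to see what the hf_config_override hunk above does: a toy sketch (not vLLM code) that applies the same rewrite to a bare PretrainedConfig, assuming a single MTP layer as in the ERNIE 4.5 checkpoints.

    from transformers import PretrainedConfig

    # Stand-in for an ERNIE 4.5 MoE HF config used as an MTP draft model.
    cfg = PretrainedConfig()
    cfg.model_type = "ernie4_5_moe"
    cfg.num_nextn_predict_layers = 1

    # Same logic as the branch added in hf_config_override.
    if cfg.model_type == "ernie4_5_moe":
        cfg.model_type = "ernie_mtp"
    if cfg.model_type == "ernie_mtp":
        cfg.update({
            "n_predict": getattr(cfg, "num_nextn_predict_layers", None),
            "architectures": ["ErnieMTPModel"],
        })

    print(cfg.model_type, cfg.n_predict, cfg.architectures)
    # ernie_mtp 1 ['ErnieMTPModel']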
vllm/model_executor/models/ernie_mtp.py

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@ (new file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Ernie-MTP model."""
from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .llama import LlamaDecoderLayer
from .utils import is_pp_missing_parameter, maybe_prefix


class ErnieMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.mtp_emb_norm = RMSNorm(config.hidden_size,
                                    eps=config.rms_norm_eps)
        self.mtp_hidden_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
                                         config.hidden_size,
                                         bias=False)
        self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
                                           prefix)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        inputs_embeds[positions == 0] = 0

        inputs_embeds = self.mtp_emb_norm(inputs_embeds)
        previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states)

        hidden_states = self.mtp_linear_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 residual=None)
        hidden_states = residual + hidden_states

        return hidden_states


class ErnieMultiTokenPredictor(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx):
            ErnieMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
            inputs_embeds,
            positions,
            previous_hidden_states,
            spec_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: ParallelLMHead,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits


class ErnieMTP(nn.Module, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.config = vllm_config.model_config.hf_config
        self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config,
                                              prefix=maybe_prefix(
                                                  prefix, "model"))
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size)
        self.sampler = get_sampler()

        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        assert spec_step_idx == 0, "ernie_mtp only support predict one token"
        hidden_states = self.model(input_ids, positions, hidden_states,
                                   inputs_embeds, spec_step_idx)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        return self.model.compute_logits(hidden_states, self.lm_head,
                                         sampling_metadata, spec_step_idx)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:

            if self.config.tie_word_embeddings and name.endswith(
                    "lm_head.weight"):
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            if "mtp" in name:
                name = self._rewrite_spec_layer_name(self.config, name)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                if "mtp" not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                # According to DeepSeek-V3 Technical Report, MTP modules
                # shares embedding layer. We only load the first weights.
                if "mtp_" not in name and ("embed_tokens" not in name
                                           and "lm_head" not in name):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _rewrite_spec_layer_name(self, config: PretrainedConfig,
                                 name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        """
        spec_layer_weight_names = [
            "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm",
            "mtp_linear_proj"
        ]
        layer_idx = config.num_hidden_layers
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                name = name.replace(
                    f"model.{weight_name}.0.",
                    f"model.layers.{layer_idx}.{weight_name}.")
                return name
        name = name.replace("model.mtp_block.0.",
                            f"model.layers.{layer_idx}.mtp_block.")
        return name
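To illustrate the checkpoint-name remapping that _rewrite_spec_layer_name performs, here is a standalone sketch with a hypothetical num_hidden_layers of 28 (the real value comes from the target model's config): MTP weights stored under model.mtp_*.0.* are moved onto the one extra layer index appended after the decoder stack.

    num_hidden_layers = 28  # hypothetical value, for illustration only

    def rewrite(name: str) -> str:
        # Same rule as ErnieMTP._rewrite_spec_layer_name above.
        for w in ("embed_tokens", "mtp_emb_norm", "mtp_hidden_norm",
                  "mtp_linear_proj"):
            if w in name:
                return name.replace(f"model.{w}.0.",
                                    f"model.layers.{num_hidden_layers}.{w}.")
        return name.replace("model.mtp_block.0.",
                            f"model.layers.{num_hidden_layers}.mtp_block.")

    print(rewrite("model.mtp_emb_norm.0.weight"))
    # model.layers.28.mtp_emb_norm.weight
    print(rewrite("model.mtp_block.0.self_attn.q_proj.weight"))
    # model.layers.28.mtp_block.self_attn.q_proj.weight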

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -266,6 +266,7 @@
     # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
+    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     # Temporarily disabled.

vllm/v1/spec_decode/eagle.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def propose(
             hidden_states=self.hidden_states[:num_input_tokens],
             inputs_embeds=inputs_embeds,
         )
-        if self.method == "deepseek_mtp":
+        if self.method in ("deepseek_mtp", "ernie_mtp"):
             last_hidden_states = ret_hidden_states
         else:
             last_hidden_states, hidden_states = ret_hidden_states

vllm/worker/worker.py

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ def __init__(
                              "eagle",
                              "deepseek_mtp",
                              "glm4_moe_mtp",
-                             "mimo_mtp")) \
+                             "mimo_mtp",
+                             "ernie_mtp")) \
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
