|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +# Copyright 2025 The Baidu team. |
| 5 | +# Copyright 2023 The vLLM team. |
| 6 | +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. |
| 7 | +# |
| 8 | +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX |
| 9 | +# and OPT implementations in this library. It has been modified from its |
| 10 | +# original forms to accommodate minor architectural differences compared |
| 11 | +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
| 12 | +# |
| 13 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 14 | +# you may not use this file except in compliance with the License. |
| 15 | +# You may obtain a copy of the License at |
| 16 | +# |
| 17 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 18 | +# |
| 19 | +# Unless required by applicable law or agreed to in writing, software |
| 20 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 21 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 22 | +# See the License for the specific language governing permissions and |
| 23 | +# limitations under the License. |
| 24 | +"""Inference-only Ernie-MTP model.""" |
| 25 | +from collections.abc import Iterable |
| 26 | +from typing import Optional |
| 27 | + |
| 28 | +import torch |
| 29 | +import torch.nn as nn |
| 30 | +from transformers import PretrainedConfig |
| 31 | + |
| 32 | +from vllm.config import CacheConfig, ModelConfig, VllmConfig |
| 33 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 34 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 35 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 36 | +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler |
| 37 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 38 | + ParallelLMHead, VocabParallelEmbedding) |
| 39 | +from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 40 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 41 | +from vllm.sequence import IntermediateTensors |
| 42 | + |
| 43 | +from .interfaces import SupportsPP |
| 44 | +from .llama import LlamaDecoderLayer |
| 45 | +from .utils import is_pp_missing_parameter, maybe_prefix |
| 46 | + |
| 47 | + |
| 48 | +class ErnieMultiTokenPredictorLayer(nn.Module): |
| 49 | + |
| 50 | + def __init__( |
| 51 | + self, |
| 52 | + config: PretrainedConfig, |
| 53 | + prefix: str, |
| 54 | + model_config: ModelConfig, |
| 55 | + cache_config: Optional[CacheConfig] = None, |
| 56 | + quant_config: Optional[QuantizationConfig] = None, |
| 57 | + ) -> None: |
| 58 | + super().__init__() |
| 59 | + |
| 60 | + self.mtp_emb_norm = RMSNorm(config.hidden_size, |
| 61 | + eps=config.rms_norm_eps) |
| 62 | + self.mtp_hidden_norm = RMSNorm(config.hidden_size, |
| 63 | + eps=config.rms_norm_eps) |
| 64 | + self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, |
| 65 | + config.hidden_size, |
| 66 | + bias=False) |
| 67 | + self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, |
| 68 | + prefix) |
| 69 | + |
| 70 | + def forward( |
| 71 | + self, |
| 72 | + inputs_embeds: torch.Tensor, |
| 73 | + positions: torch.Tensor, |
| 74 | + previous_hidden_states: torch.Tensor, |
| 75 | + spec_step_index: int = 0, |
| 76 | + ) -> torch.Tensor: |
| 77 | + assert inputs_embeds is not None |
| 78 | + # masking inputs at position 0, as not needed by MTP |
| 79 | + inputs_embeds[positions == 0] = 0 |
| 80 | + |
| 81 | + inputs_embeds = self.mtp_emb_norm(inputs_embeds) |
| 82 | + previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) |
| 83 | + |
| 84 | + hidden_states = self.mtp_linear_proj( |
| 85 | + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) |
| 86 | + |
| 87 | + hidden_states, residual = self.mtp_block(positions=positions, |
| 88 | + hidden_states=hidden_states, |
| 89 | + residual=None) |
| 90 | + hidden_states = residual + hidden_states |
| 91 | + |
| 92 | + return hidden_states |
| 93 | + |
| 94 | + |
| 95 | +class ErnieMultiTokenPredictor(nn.Module): |
| 96 | + |
| 97 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 98 | + super().__init__() |
| 99 | + |
| 100 | + config = vllm_config.model_config.hf_config |
| 101 | + self.mtp_start_layer_idx = config.num_hidden_layers |
| 102 | + self.num_mtp_layers = config.num_nextn_predict_layers |
| 103 | + # to map the exact layer index from weights |
| 104 | + self.layers = torch.nn.ModuleDict({ |
| 105 | + str(idx): |
| 106 | + ErnieMultiTokenPredictorLayer( |
| 107 | + config, |
| 108 | + f"{prefix}.layers.{idx}", |
| 109 | + model_config=vllm_config.model_config, |
| 110 | + cache_config=vllm_config.cache_config, |
| 111 | + ) |
| 112 | + for idx in range(self.mtp_start_layer_idx, |
| 113 | + self.mtp_start_layer_idx + self.num_mtp_layers) |
| 114 | + }) |
| 115 | + self.embed_tokens = VocabParallelEmbedding( |
| 116 | + config.vocab_size, |
| 117 | + config.hidden_size, |
| 118 | + ) |
| 119 | + self.logits_processor = LogitsProcessor(config.vocab_size) |
| 120 | + |
| 121 | + def forward( |
| 122 | + self, |
| 123 | + input_ids: torch.Tensor, |
| 124 | + positions: torch.Tensor, |
| 125 | + previous_hidden_states: torch.Tensor, |
| 126 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 127 | + spec_step_idx: int = 0, |
| 128 | + ) -> torch.Tensor: |
| 129 | + if inputs_embeds is None: |
| 130 | + inputs_embeds = self.embed_tokens(input_ids) |
| 131 | + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( |
| 132 | + inputs_embeds, |
| 133 | + positions, |
| 134 | + previous_hidden_states, |
| 135 | + spec_step_idx, |
| 136 | + ) |
| 137 | + |
| 138 | + def compute_logits( |
| 139 | + self, |
| 140 | + hidden_states: torch.Tensor, |
| 141 | + lm_head: ParallelLMHead, |
| 142 | + sampling_metadata: SamplingMetadata, |
| 143 | + spec_step_idx: int = 0, |
| 144 | + ) -> torch.Tensor: |
| 145 | + self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] |
| 146 | + logits = self.logits_processor(lm_head, hidden_states, |
| 147 | + sampling_metadata) |
| 148 | + return logits |
| 149 | + |
| 150 | + |
| 151 | +class ErnieMTP(nn.Module, SupportsPP): |
| 152 | + |
| 153 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 154 | + super().__init__() |
| 155 | + |
| 156 | + self.config = vllm_config.model_config.hf_config |
| 157 | + self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, |
| 158 | + prefix=maybe_prefix( |
| 159 | + prefix, "model")) |
| 160 | + self.lm_head = ParallelLMHead(self.config.vocab_size, |
| 161 | + self.config.hidden_size) |
| 162 | + self.sampler = get_sampler() |
| 163 | + |
| 164 | + if self.config.tie_word_embeddings: |
| 165 | + self.lm_head.weight = self.model.embed_tokens.weight |
| 166 | + |
| 167 | + def forward( |
| 168 | + self, |
| 169 | + input_ids: torch.Tensor, |
| 170 | + positions: torch.Tensor, |
| 171 | + hidden_states: torch.Tensor, |
| 172 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 173 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 174 | + spec_step_idx: int = 0, |
| 175 | + ) -> torch.Tensor: |
| 176 | + assert spec_step_idx == 0, "ernie_mtp only support predict one token" |
| 177 | + hidden_states = self.model(input_ids, positions, hidden_states, |
| 178 | + inputs_embeds, spec_step_idx) |
| 179 | + return hidden_states |
| 180 | + |
| 181 | + def compute_logits( |
| 182 | + self, |
| 183 | + hidden_states: torch.Tensor, |
| 184 | + sampling_metadata: SamplingMetadata, |
| 185 | + spec_step_idx: int = 0, |
| 186 | + ) -> Optional[torch.Tensor]: |
| 187 | + return self.model.compute_logits(hidden_states, self.lm_head, |
| 188 | + sampling_metadata, spec_step_idx) |
| 189 | + |
| 190 | + def sample( |
| 191 | + self, |
| 192 | + logits: torch.Tensor, |
| 193 | + sampling_metadata: SamplingMetadata, |
| 194 | + ) -> Optional[SamplerOutput]: |
| 195 | + next_tokens = self.sampler(logits, sampling_metadata) |
| 196 | + return next_tokens |
| 197 | + |
| 198 | + def load_weights(self, weights: Iterable[tuple[str, |
| 199 | + torch.Tensor]]) -> set[str]: |
| 200 | + stacked_params_mapping = [ |
| 201 | + ("qkv_proj", "q_proj", "q"), |
| 202 | + ("qkv_proj", "k_proj", "k"), |
| 203 | + ("qkv_proj", "v_proj", "v"), |
| 204 | + ("gate_up_proj", "gate_proj", 0), |
| 205 | + ("gate_up_proj", "up_proj", 1), |
| 206 | + ] |
| 207 | + |
| 208 | + params_dict = dict(self.named_parameters()) |
| 209 | + loaded_params: set[str] = set() |
| 210 | + for name, loaded_weight in weights: |
| 211 | + |
| 212 | + if self.config.tie_word_embeddings and name.endswith( |
| 213 | + "lm_head.weight"): |
| 214 | + continue |
| 215 | + if "rotary_emb.inv_freq" in name: |
| 216 | + continue |
| 217 | + if "mtp" in name: |
| 218 | + name = self._rewrite_spec_layer_name(self.config, name) |
| 219 | + |
| 220 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 221 | + # Skip non-stacked layers and experts (experts handled below). |
| 222 | + if weight_name not in name: |
| 223 | + continue |
| 224 | + if "mtp" not in name: |
| 225 | + continue |
| 226 | + # We have mlp.experts[0].gate_proj in the checkpoint. |
| 227 | + # Since we handle the experts below in expert_params_mapping, |
| 228 | + # we need to skip here BEFORE we update the name, otherwise |
| 229 | + # name will be updated to mlp.experts[0].gate_up_proj, which |
| 230 | + # will then be updated below in expert_params_mapping |
| 231 | + # for mlp.experts[0].gate_gate_up_proj, which breaks load. |
| 232 | + if (("mlp.experts." in name) and name not in params_dict): |
| 233 | + continue |
| 234 | + name = name.replace(weight_name, param_name) |
| 235 | + # Skip loading extra bias for GPTQ models. |
| 236 | + if ((name.endswith(".bias") or name.endswith("_bias")) |
| 237 | + and name not in params_dict): |
| 238 | + continue |
| 239 | + # Skip layers on other devices. |
| 240 | + if is_pp_missing_parameter(name, self): |
| 241 | + continue |
| 242 | + |
| 243 | + param = params_dict[name] |
| 244 | + weight_loader = param.weight_loader |
| 245 | + weight_loader(param, loaded_weight, shard_id) |
| 246 | + break |
| 247 | + else: |
| 248 | + # Skip loading extra bias for GPTQ models. |
| 249 | + if ((name.endswith(".bias") or name.endswith("_bias")) |
| 250 | + and name not in params_dict): |
| 251 | + continue |
| 252 | + # Skip layers on other devices. |
| 253 | + if is_pp_missing_parameter(name, self): |
| 254 | + continue |
| 255 | + |
| 256 | + # According to DeepSeek-V3 Technical Report, MTP modules |
| 257 | + # shares embedding layer. We only load the first weights. |
| 258 | + if "mtp_" not in name and ("embed_tokens" not in name |
| 259 | + and "lm_head" not in name): |
| 260 | + continue |
| 261 | + |
| 262 | + param = params_dict[name] |
| 263 | + weight_loader = getattr(param, "weight_loader", |
| 264 | + default_weight_loader) |
| 265 | + weight_loader(param, loaded_weight) |
| 266 | + loaded_params.add(name) |
| 267 | + return loaded_params |
| 268 | + |
| 269 | + def _rewrite_spec_layer_name(self, config: PretrainedConfig, |
| 270 | + name: str) -> str: |
| 271 | + """ |
| 272 | + Rewrite the weight name to match the format of the original model. |
| 273 | + """ |
| 274 | + spec_layer_weight_names = [ |
| 275 | + "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", |
| 276 | + "mtp_linear_proj" |
| 277 | + ] |
| 278 | + layer_idx = config.num_hidden_layers |
| 279 | + for weight_name in spec_layer_weight_names: |
| 280 | + if weight_name in name: |
| 281 | + name = name.replace( |
| 282 | + f"model.{weight_name}.0.", |
| 283 | + f"model.layers.{layer_idx}.{weight_name}.") |
| 284 | + return name |
| 285 | + name = name.replace("model.mtp_block.0.", |
| 286 | + f"model.layers.{layer_idx}.mtp_block.") |
| 287 | + return name |
0 commit comments