diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 4b167db09b9..16badfbde0d 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -137,6 +137,38 @@ # Remove this patch if upstream provides an official NPU graph-capture # guidance / auto-configuration path for HCCL. # +# 3. `vllm.config.speculative.SpeculativeConfig._verify_args` +# Why: +# Upstream vLLM's eagle3/extract_hidden_states restricts target model types +# via a whitelist. MiniMax-M2 should be allowed once the worker-side model +# can emit auxiliary hidden states. +# How: +# Monkey-patch `_verify_args` to bypass only the whitelist ValueError for +# MiniMax model_type when method is eagle3/extract_hidden_states. +# SpeculativeConfig is a Pydantic dataclass (`@config`); init validation calls +# `__pydantic_decorators__.model_validators["_verify_args"].func`, so that +# `Decorator.func` must be replaced (not only `SpeculativeConfig._verify_args`), +# then `rebuild_dataclass(SpeculativeConfig, force=True)`. +# If `VllmConfig` was imported earlier, also `rebuild_dataclass(VllmConfig, ...)` +# so nested `speculative_config` validation does not use a stale schema. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/37512 +# Future Plan: +# Remove this patch once upstream whitelist includes MiniMax. +# +# 4. `vllm.model_executor.models.registry` (spec decode aliases) +# Why: +# Some Eagle3 draft checkpoints may declare a MiniMax-specific architecture +# string while reusing the shared Eagle3 implementation. +# How: +# Register `Eagle3MiniMaxM2ForCausalLM` as an alias pointing to the +# existing Eagle3 implementation in the speculative decoding registry. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/37512 +# Future Plan: +# Drop the alias once upstream registry includes it or the checkpoint +# standardizes architecture strings. +# # ** 8. File: platform/patch_kv_cache_interface.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec` @@ -436,6 +468,31 @@ # Future Plan: # Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU. # +# 4. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.forward` +# Why: +# Eagle3 speculative decoding needs auxiliary hidden states from specific +# transformer layers of the target model. +# How: +# Extend `MiniMaxM2Model.forward` to optionally collect and return +# `(final_hidden_states, aux_hidden_states)` when `aux_hidden_state_layers` +# is set by the runtime. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/37512 +# Future Plan: +# Remove this patch once upstream MiniMax-M2 integrates Eagle3 support. +# +# 5. `vllm.model_executor.models.minimax_m2.MiniMaxM2ForCausalLM` +# Why: +# vLLM core uses SupportsEagle3-style methods to configure which layers +# should emit auxiliary hidden states. +# How: +# Inject `set_aux_hidden_state_layers` and default-layer getters onto +# `MiniMaxM2ForCausalLM` so vLLM can configure the target model. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/37512 +# Future Plan: +# Remove this patch once upstream provides these methods on the model. +# # ** 18. File: worker/patch_minimax_m2_linear_attn.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__` diff --git a/vllm_ascend/patch/platform/patch_minimax_m2_config.py b/vllm_ascend/patch/platform/patch_minimax_m2_config.py index 22ead420dca..ce046be348c 100644 --- a/vllm_ascend/patch/platform/patch_minimax_m2_config.py +++ b/vllm_ascend/patch/platform/patch_minimax_m2_config.py @@ -134,3 +134,149 @@ def _patched_verify_cuda_graph(self: ModelConfig) -> None: if _original_verify_cuda_graph is not None: ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph + + +# --------------------------------------------------------------------------- +# Speculative decoding (Eagle3): allow MiniMax targets and registry alias. +# --------------------------------------------------------------------------- +def _patch_speculative_minimax_whitelist() -> None: + """Allow MiniMax target models for eagle3/extract_hidden_states checks. + + Upstream vLLM validates that the target model_type is in a whitelist for + methods that rely on auxiliary hidden states. Older upstream versions may + not include MiniMax yet. + """ + try: + from vllm.config.speculative import SpeculativeConfig # type: ignore + except Exception: + logger.warning( + "SpeculativeConfig is not found, skip patching eagle3/extract_hidden_states checks for MiniMax-M2 on NPU." + ) + return + + original_verify_args = getattr(SpeculativeConfig, "_verify_args", None) + if original_verify_args is None: + logger.warning( + "SpeculativeConfig._verify_args is not found, skip patching " + "eagle3/extract_hidden_states checks for MiniMax-M2 on NPU." + ) + return + if getattr(original_verify_args, "_vllm_ascend_minimax_eagle3_patched", False): + logger.warning("eagle3/extract_hidden_states checks for MiniMax-M2 on NPU have already been patched.") + return + + # Pydantic dataclass validation invokes `model_validators["_verify_args"].func`, not + # necessarily the current `SpeculativeConfig._verify_args` attribute. + decorators = getattr(SpeculativeConfig, "__pydantic_decorators__", None) + mv = None + if decorators is not None: + model_validators = getattr(decorators, "model_validators", None) + if isinstance(model_validators, dict): + mv = model_validators.get("_verify_args") + inner_verify = mv.func if mv is not None and getattr(mv, "func", None) is not None else original_verify_args + + def _patched_verify_args(self, *args, **kwargs): # type: ignore[no-untyped-def] + try: + return inner_verify(self, *args, **kwargs) + except ValueError as e: + method = getattr(self, "method", None) + if method not in ("eagle3", "extract_hidden_states"): + raise + + target_cfg = getattr(self, "target_model_config", None) + model_type = getattr(getattr(target_cfg, "hf_text_config", None), "model_type", "") + if "minimax" not in str(model_type).lower(): + logger.warning( + "Model type %s is not a MiniMax-M2 model, skip eagle3/extract_hidden_states checks.", + model_type, + ) + raise + + msg = str(e).lower() + if "only supported for" in msg and "models" in msg: + # Upstream `_verify_args` calls `verify_equal_vocab_size_if_draft_model` after + # the aux-hidden allowlist; returning here would skip it. + verify_vocab = getattr(self, "verify_equal_vocab_size_if_draft_model", None) + if callable(verify_vocab): + verify_vocab() + return self + raise + + _patched_verify_args._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined] + SpeculativeConfig._verify_args = _patched_verify_args # type: ignore[assignment] + + if mv is not None: + try: + mv.func = _patched_verify_args # type: ignore[misc] + except (TypeError, AttributeError): + object.__setattr__(mv, "func", _patched_verify_args) + else: + logger.warning( + "Could not find SpeculativeConfig.__pydantic_decorators__.model_validators[" + "'_verify_args']; eagle3 whitelist patch may not run at init validation." + ) + + try: + from pydantic.dataclasses import rebuild_dataclass # type: ignore + except Exception as e: + logger.warning( + "Cannot import rebuild_dataclass (%s); SpeculativeConfig eagle3 whitelist " + "patch may not apply at instance construction time.", + e, + ) + else: + try: + rebuild_dataclass(SpeculativeConfig, force=True) # type: ignore[arg-type] + except Exception as e: + logger.warning( + "rebuild_dataclass(SpeculativeConfig) failed (%s); eagle3 whitelist patch may not apply.", + e, + ) + # If `VllmConfig` was imported before this patch ran, its pydantic-core schema + # for the nested `speculative_config` field may still embed the *pre-patch* + # SpeculativeConfig validators. `create_speculative_config()` calls + # `SpeculativeConfig(...)` directly (uses updated class validator), but + # `VllmConfig(..., speculative_config=...)` validates via the parent's cached + # nested schema and can still raise the whitelist error unless we rebuild. + try: + from vllm.config.vllm import VllmConfig # type: ignore + except Exception: + pass + else: + try: + rebuild_dataclass(VllmConfig, force=True) # type: ignore[arg-type] + except Exception as e: + logger.warning( + "rebuild_dataclass(VllmConfig) failed (%s); VllmConfig(...) may " + "still use stale nested SpeculativeConfig validation.", + e, + ) + + +def _patch_eagle3_registry_alias() -> None: + """Register Eagle3MiniMaxM2ForCausalLM architecture alias if missing.""" + try: + import vllm.model_executor.models.registry as registry # type: ignore + except Exception: + return + + # Prefer patching the underlying dicts when available. + if hasattr(registry, "_SPECULATIVE_DECODING_MODELS"): + models = registry._SPECULATIVE_DECODING_MODELS + if isinstance(models, dict): + models.setdefault("Eagle3MiniMaxM2ForCausalLM", ("llama_eagle3", "Eagle3LlamaForCausalLM")) + + # Fallback: patch resolved registry instance if present. + model_registry = getattr(registry, "ModelRegistry", None) + if model_registry is not None and hasattr(model_registry, "models"): + try: + model_registry.models.setdefault( # type: ignore[attr-defined] + "Eagle3MiniMaxM2ForCausalLM", + ("llama_eagle3", "Eagle3LlamaForCausalLM"), + ) + except Exception: + return + + +_patch_speculative_minimax_whitelist() +_patch_eagle3_registry_alias() diff --git a/vllm_ascend/patch/worker/patch_minimax_m2.py b/vllm_ascend/patch/worker/patch_minimax_m2.py index 544184019bd..a2794437bf7 100644 --- a/vllm_ascend/patch/worker/patch_minimax_m2.py +++ b/vllm_ascend/patch/worker/patch_minimax_m2.py @@ -21,12 +21,19 @@ import torch from vllm.distributed import ( + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP -from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE +from vllm.model_executor.models.minimax_m2 import ( + MiniMaxM2Attention, + MiniMaxM2ForCausalLM, + MiniMaxM2Model, + MiniMaxM2MoE, +) from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice @@ -87,6 +94,34 @@ def _patched_attention_init(self, *args, **kwargs) -> None: MiniMaxM2Attention.__init__ = _patched_attention_init +def _patch_forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, +) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + cos, sin = get_cos_and_sin_slice() + q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope( + input=qkv, + q_weight=self.q_norm.weight, + k_weight=self.k_norm.weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim), + eps=self.q_norm.variance_epsilon, + tp_world=self.q_norm.tp_world, + cos=cos, + sin=sin, + ) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +MiniMaxM2Attention.forward = _patch_forward + + # --------------------------------------------------------------------------- # MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper # --------------------------------------------------------------------------- @@ -176,29 +211,79 @@ def _patched_load_weights( MiniMaxM2Model.load_weights = _patched_load_weights -def _patch_forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, -) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - cos, sin = get_cos_and_sin_slice() - q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope( - input=qkv, - q_weight=self.q_norm.weight, - k_weight=self.k_norm.weight, - q_hidden_size=self.q_size, - kv_hidden_size=self.kv_size, - head_dim=self.head_dim, - rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim), - eps=self.q_norm.variance_epsilon, - tp_world=self.q_norm.tp_world, - cos=cos, - sin=sin, - ) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output +# --------------------------------------------------------------------------- +# MiniMaxM2Model / MiniMaxM2ForCausalLM: Eagle3 aux hidden states support +# --------------------------------------------------------------------------- +_original_minimax_m2_forward = MiniMaxM2Model.forward -MiniMaxM2Attention.forward = _patch_forward +def _patched_minimax_m2_forward( + self: "MiniMaxM2Model", + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, +) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + aux_layers: tuple[int, ...] = getattr(self, "aux_hidden_state_layers", ()) or () + if not aux_layers: + return _original_minimax_m2_forward(self, input_ids, positions, intermediate_tensors, inputs_embeds) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states: list[torch.Tensor] = [] + for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + layer_idx = self.start_layer + idx + if layer_idx in aux_layers: + aux_hidden_states.append(hidden_states + residual if residual is not None else hidden_states) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + + hidden_states, _ = self.norm(hidden_states, residual) + if aux_hidden_states: + return hidden_states, aux_hidden_states + return hidden_states + + +if not getattr(_original_minimax_m2_forward, "_vllm_ascend_minimax_eagle3_patched", False): + MiniMaxM2Model.forward = _patched_minimax_m2_forward # type: ignore[assignment] + MiniMaxM2Model.forward._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined] + + +def _set_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM", layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = tuple(int(x) for x in layers) + + +def _get_eagle3_default_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + +def _get_eagle3_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]: + return _get_eagle3_default_aux_hidden_state_layers(self) + + +# vLLM 0.18+: `supports_eagle3(model)` is `isinstance(model, SupportsEagle3)` (see +# `vllm.model_executor.models.interfaces`). `SupportsEagle3` extends `SupportsEagleBase`; +# runtime protocol checks require class attributes below (not only Eagle3 methods), or +# isinstance fails and model_runner_v1 raises: +# "Model does not support EAGLE3 interface but aux_hidden_state_outputs was requested". +MiniMaxM2ForCausalLM.has_own_lm_head = False # type: ignore[misc] +MiniMaxM2ForCausalLM.has_own_embed_tokens = False # type: ignore[misc] +MiniMaxM2ForCausalLM.supports_eagle3 = True # type: ignore[misc] + +MiniMaxM2ForCausalLM.set_aux_hidden_state_layers = _set_aux_hidden_state_layers # type: ignore[attr-defined] +MiniMaxM2ForCausalLM.get_eagle3_default_aux_hidden_state_layers = ( # type: ignore[attr-defined] + _get_eagle3_default_aux_hidden_state_layers +) +MiniMaxM2ForCausalLM.get_eagle3_aux_hidden_state_layers = _get_eagle3_aux_hidden_state_layers # type: ignore[attr-defined]