Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,38 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# 3. `vllm.config.speculative.SpeculativeConfig._verify_args`
# Why:
# Upstream vLLM's eagle3/extract_hidden_states restricts target model types
# via a whitelist. MiniMax-M2 should be allowed once the worker-side model
# can emit auxiliary hidden states.
# How:
# Monkey-patch `_verify_args` to bypass only the whitelist ValueError for
# MiniMax model_type when method is eagle3/extract_hidden_states.
# SpeculativeConfig is a Pydantic dataclass (`@config`); init validation calls
# `__pydantic_decorators__.model_validators["_verify_args"].func`, so that
# `Decorator.func` must be replaced (not only `SpeculativeConfig._verify_args`),
# then `rebuild_dataclass(SpeculativeConfig, force=True)`.
# If `VllmConfig` was imported earlier, also `rebuild_dataclass(VllmConfig, ...)`
# so nested `speculative_config` validation does not use a stale schema.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream whitelist includes MiniMax.
#
# 4. `vllm.model_executor.models.registry` (spec decode aliases)
# Why:
# Some Eagle3 draft checkpoints may declare a MiniMax-specific architecture
# string while reusing the shared Eagle3 implementation.
# How:
# Register `Eagle3MiniMaxM2ForCausalLM` as an alias pointing to the
# existing Eagle3 implementation in the speculative decoding registry.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Drop the alias once upstream registry includes it or the checkpoint
# standardizes architecture strings.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
Expand Down Expand Up @@ -436,6 +468,31 @@
# Future Plan:
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
#
# 4. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.forward`
# Why:
# Eagle3 speculative decoding needs auxiliary hidden states from specific
# transformer layers of the target model.
# How:
# Extend `MiniMaxM2Model.forward` to optionally collect and return
# `(final_hidden_states, aux_hidden_states)` when `aux_hidden_state_layers`
# is set by the runtime.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream MiniMax-M2 integrates Eagle3 support.
#
# 5. `vllm.model_executor.models.minimax_m2.MiniMaxM2ForCausalLM`
# Why:
# vLLM core uses SupportsEagle3-style methods to configure which layers
# should emit auxiliary hidden states.
# How:
# Inject `set_aux_hidden_state_layers` and default-layer getters onto
# `MiniMaxM2ForCausalLM` so vLLM can configure the target model.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream provides these methods on the model.
#
Comment on lines +471 to +495
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The patch for vllm.model_executor.models.minimax_m2.MiniMaxM2Attention.forward in worker/patch_minimax_m2.py is not documented here. This file serves as a manifest for all patches, and for maintainability, it's important to keep it complete and up-to-date. Please add documentation for this new patch. It appears to be a performance optimization using the fused kernel torch.ops.vllm.split_qkv_tp_rmsnorm_rope.

You could add it as item 3 and renumber the subsequent items under ** 17. File: worker/patch_minimax_m2.py**.

# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
Expand Down
146 changes: 146 additions & 0 deletions vllm_ascend/patch/platform/patch_minimax_m2_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,149 @@ def _patched_verify_cuda_graph(self: ModelConfig) -> None:

if _original_verify_cuda_graph is not None:
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph


# ---------------------------------------------------------------------------
# Speculative decoding (Eagle3): allow MiniMax targets and registry alias.
# ---------------------------------------------------------------------------
def _patch_speculative_minimax_whitelist() -> None:
    """Allow MiniMax target models for eagle3/extract_hidden_states checks.

    Upstream vLLM validates that the target model_type is in a whitelist for
    methods that rely on auxiliary hidden states. Older upstream versions may
    not include MiniMax yet, so we wrap `_verify_args` to swallow only that
    specific whitelist ValueError for MiniMax targets.
    """
    try:
        from vllm.config.speculative import SpeculativeConfig  # type: ignore
    except Exception:
        logger.warning(
            "SpeculativeConfig is not found, skip patching eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
        )
        return

    original_verify_args = getattr(SpeculativeConfig, "_verify_args", None)
    if original_verify_args is None:
        logger.warning(
            "SpeculativeConfig._verify_args is not found, skip patching "
            "eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
        )
        return
    # After a successful patch, `_verify_args` is our wrapper carrying this
    # marker, so a repeated import of this module becomes a no-op.
    if getattr(original_verify_args, "_vllm_ascend_minimax_eagle3_patched", False):
        logger.warning("eagle3/extract_hidden_states checks for MiniMax-M2 on NPU have already been patched.")
        return

    # Pydantic dataclass validation invokes `model_validators["_verify_args"].func`, not
    # necessarily the current `SpeculativeConfig._verify_args` attribute.
    decorators = getattr(SpeculativeConfig, "__pydantic_decorators__", None)
    mv = None
    if decorators is not None:
        model_validators = getattr(decorators, "model_validators", None)
        if isinstance(model_validators, dict):
            mv = model_validators.get("_verify_args")
    # Wrap whichever callable pydantic will actually run at init validation.
    inner_verify = mv.func if mv is not None and getattr(mv, "func", None) is not None else original_verify_args

    def _patched_verify_args(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        try:
            return inner_verify(self, *args, **kwargs)
        except ValueError as e:
            # Only bypass the whitelist for methods that need aux hidden states.
            method = getattr(self, "method", None)
            if method not in ("eagle3", "extract_hidden_states"):
                raise

            target_cfg = getattr(self, "target_model_config", None)
            model_type = getattr(getattr(target_cfg, "hf_text_config", None), "model_type", "")
            if "minimax" not in str(model_type).lower():
                # Fix: the previous message claimed the checks were skipped,
                # but the original error is propagated to the caller here.
                logger.warning(
                    "Model type %s is not a MiniMax model; re-raising the "
                    "eagle3/extract_hidden_states whitelist error.",
                    model_type,
                )
                raise

            # Heuristic match on the upstream whitelist error text; any other
            # ValueError (e.g. vocab mismatch) is still propagated below.
            msg = str(e).lower()
            if "only supported for" in msg and "models" in msg:
                # Upstream `_verify_args` calls `verify_equal_vocab_size_if_draft_model` after
                # the aux-hidden allowlist; returning here would skip it.
                verify_vocab = getattr(self, "verify_equal_vocab_size_if_draft_model", None)
                if callable(verify_vocab):
                    verify_vocab()
                return self
            raise

    _patched_verify_args._vllm_ascend_minimax_eagle3_patched = True  # type: ignore[attr-defined]
    SpeculativeConfig._verify_args = _patched_verify_args  # type: ignore[assignment]

    if mv is not None:
        try:
            mv.func = _patched_verify_args  # type: ignore[misc]
        except (TypeError, AttributeError):
            # Decorator objects may be frozen dataclasses; force the slot.
            object.__setattr__(mv, "func", _patched_verify_args)
    else:
        logger.warning(
            "Could not find SpeculativeConfig.__pydantic_decorators__.model_validators["
            "'_verify_args']; eagle3 whitelist patch may not run at init validation."
        )

    try:
        from pydantic.dataclasses import rebuild_dataclass  # type: ignore
    except Exception as e:
        logger.warning(
            "Cannot import rebuild_dataclass (%s); SpeculativeConfig eagle3 whitelist "
            "patch may not apply at instance construction time.",
            e,
        )
    else:
        try:
            rebuild_dataclass(SpeculativeConfig, force=True)  # type: ignore[arg-type]
        except Exception as e:
            logger.warning(
                "rebuild_dataclass(SpeculativeConfig) failed (%s); eagle3 whitelist patch may not apply.",
                e,
            )
        # If `VllmConfig` was imported before this patch ran, its pydantic-core schema
        # for the nested `speculative_config` field may still embed the *pre-patch*
        # SpeculativeConfig validators. `create_speculative_config()` calls
        # `SpeculativeConfig(...)` directly (uses updated class validator), but
        # `VllmConfig(..., speculative_config=...)` validates via the parent's cached
        # nested schema and can still raise the whitelist error unless we rebuild.
        try:
            from vllm.config.vllm import VllmConfig  # type: ignore
        except Exception:
            pass
        else:
            try:
                rebuild_dataclass(VllmConfig, force=True)  # type: ignore[arg-type]
            except Exception as e:
                logger.warning(
                    "rebuild_dataclass(VllmConfig) failed (%s); VllmConfig(...) may "
                    "still use stale nested SpeculativeConfig validation.",
                    e,
                )


def _patch_eagle3_registry_alias() -> None:
"""Register Eagle3MiniMaxM2ForCausalLM architecture alias if missing."""
try:
import vllm.model_executor.models.registry as registry # type: ignore
except Exception:
return

# Prefer patching the underlying dicts when available.
if hasattr(registry, "_SPECULATIVE_DECODING_MODELS"):
models = registry._SPECULATIVE_DECODING_MODELS
if isinstance(models, dict):
models.setdefault("Eagle3MiniMaxM2ForCausalLM", ("llama_eagle3", "Eagle3LlamaForCausalLM"))

# Fallback: patch resolved registry instance if present.
model_registry = getattr(registry, "ModelRegistry", None)
if model_registry is not None and hasattr(model_registry, "models"):
try:
model_registry.models.setdefault( # type: ignore[attr-defined]
"Eagle3MiniMaxM2ForCausalLM",
("llama_eagle3", "Eagle3LlamaForCausalLM"),
)
except Exception:
return


# Apply both patches at import time: first relax the Eagle3 target-model
# whitelist, then register the MiniMax Eagle3 architecture alias.
_patch_speculative_minimax_whitelist()
_patch_eagle3_registry_alias()
135 changes: 110 additions & 25 deletions vllm_ascend/patch/worker/patch_minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@

import torch
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
from vllm.model_executor.models.minimax_m2 import (
MiniMaxM2Attention,
MiniMaxM2ForCausalLM,
MiniMaxM2Model,
MiniMaxM2MoE,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice

Expand Down Expand Up @@ -87,6 +94,34 @@ def _patched_attention_init(self, *args, **kwargs) -> None:
MiniMaxM2Attention.__init__ = _patched_attention_init


def _patch_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Attention forward using the fused NPU split-QKV + RMSNorm + RoPE kernel.

    Replaces the separate q/k norm and rotary-embedding calls with a single
    `torch.ops.vllm.split_qkv_tp_rmsnorm_rope` invocation. `positions` is kept
    for signature compatibility; the rotary cos/sin slices come from
    `get_cos_and_sin_slice()` instead.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    rope_cos, rope_sin = get_cos_and_sin_slice()
    query, key, value = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
        input=fused_qkv,
        q_weight=self.q_norm.weight,
        k_weight=self.k_norm.weight,
        q_hidden_size=self.q_size,
        kv_hidden_size=self.kv_size,
        head_dim=self.head_dim,
        # Fall back to head_dim when the rotary embedding exposes no rotary_dim.
        rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
        eps=self.q_norm.variance_epsilon,
        tp_world=self.q_norm.tp_world,
        cos=rope_cos,
        sin=rope_sin,
    )
    attn_context = self.attn(query, key, value)
    projected, _ = self.o_proj(attn_context)
    return projected


MiniMaxM2Attention.forward = _patch_forward


# ---------------------------------------------------------------------------
# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -176,29 +211,79 @@ def _patched_load_weights(
MiniMaxM2Model.load_weights = _patched_load_weights


def _patch_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
cos, sin = get_cos_and_sin_slice()
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
input=qkv,
q_weight=self.q_norm.weight,
k_weight=self.k_norm.weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
eps=self.q_norm.variance_epsilon,
tp_world=self.q_norm.tp_world,
cos=cos,
sin=sin,
)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
# ---------------------------------------------------------------------------
# MiniMaxM2Model / MiniMaxM2ForCausalLM: Eagle3 aux hidden states support
# ---------------------------------------------------------------------------
# Keep a handle on the unpatched forward so the no-aux path delegates to it.
_original_minimax_m2_forward = MiniMaxM2Model.forward


def _patched_minimax_m2_forward(
    self: "MiniMaxM2Model",
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
    """Model forward that optionally collects Eagle3 auxiliary hidden states.

    When `self.aux_hidden_state_layers` is unset or empty, this is a pure
    pass-through to the original forward. Otherwise it re-implements the layer
    loop so the pre-layer hidden state of each configured layer index can be
    captured, and returns `(final_hidden_states, aux_hidden_states)` on the
    last pipeline-parallel rank.
    """
    aux_layers: tuple[int, ...] = getattr(self, "aux_hidden_state_layers", ()) or ()
    if not aux_layers:
        return _original_minimax_m2_forward(self, input_ids, positions, intermediate_tensors, inputs_embeds)

    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_input_ids(input_ids)
        residual = None
    else:
        # Non-first PP ranks receive activations from the previous stage.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    aux_hidden_states: list[torch.Tensor] = []
    for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
        layer_idx = self.start_layer + idx
        if layer_idx in aux_layers:
            # Capture the *input* to the tapped layer; add the pending residual
            # so the captured value matches the fully-formed hidden state.
            aux_hidden_states.append(hidden_states + residual if residual is not None else hidden_states)
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})

    hidden_states, _ = self.norm(hidden_states, residual)
    if aux_hidden_states:
        return hidden_states, aux_hidden_states
    return hidden_states


# Guard against double-wrapping: after patching, MiniMaxM2Model.forward carries
# the marker, so a re-import captures the marked function and skips this block.
if not getattr(_original_minimax_m2_forward, "_vllm_ascend_minimax_eagle3_patched", False):
    MiniMaxM2Model.forward = _patched_minimax_m2_forward  # type: ignore[assignment]
    MiniMaxM2Model.forward._vllm_ascend_minimax_eagle3_patched = True  # type: ignore[attr-defined]


def _set_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM", layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = tuple(int(x) for x in layers)


def _get_eagle3_default_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)


def _get_eagle3_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
return _get_eagle3_default_aux_hidden_state_layers(self)


# vLLM 0.18+: `supports_eagle3(model)` is `isinstance(model, SupportsEagle3)` (see
# `vllm.model_executor.models.interfaces`). `SupportsEagle3` extends `SupportsEagleBase`;
# runtime protocol checks require class attributes below (not only Eagle3 methods), or
# isinstance fails and model_runner_v1 raises:
# "Model does not support EAGLE3 interface but aux_hidden_state_outputs was requested".
MiniMaxM2ForCausalLM.has_own_lm_head = False  # type: ignore[misc]
MiniMaxM2ForCausalLM.has_own_embed_tokens = False  # type: ignore[misc]
MiniMaxM2ForCausalLM.supports_eagle3 = True  # type: ignore[misc]

# Inject the Eagle3 configuration hooks so vLLM can select which layers of the
# target model should emit auxiliary hidden states.
MiniMaxM2ForCausalLM.set_aux_hidden_state_layers = _set_aux_hidden_state_layers  # type: ignore[attr-defined]
MiniMaxM2ForCausalLM.get_eagle3_default_aux_hidden_state_layers = (  # type: ignore[attr-defined]
    _get_eagle3_default_aux_hidden_state_layers
)
MiniMaxM2ForCausalLM.get_eagle3_aux_hidden_state_layers = _get_eagle3_aux_hidden_state_layers  # type: ignore[attr-defined]
Loading