Commit 6ade99e

[V1] [Hybrid] Support Minimax-Text-01 in V1 (#22151)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 3157aeb commit 6ade99e

5 files changed: +234 additions, -42 deletions


vllm/model_executor/layers/lightning_attn.py

Lines changed: 1 addition & 1 deletion
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
     pid_d = tl.program_id(2)  # dimension block index
 
     # Load slot index for the current batch
-    slot_id = tl.load(slot_idx + pid_b)
+    slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
 
     # Skip if slot_id is -1 (padding)
     if slot_id == -1:
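
The only change in this kernel is casting the loaded slot index to int64. The commit does not state the motivation, but the usual reason for such a cast is that the slot index feeds pointer-offset arithmetic, where a 32-bit index can overflow once multiplied by a large per-slot stride. A plain-Python sketch of the magnitudes involved (sizes are illustrative, not the MiniMax-Text-01 dimensions):

# Illustrative sizes only; not taken from the commit.
slot_id = 3000                    # cache slot index stored as int32
per_slot_stride = 64 * 128 * 128  # elements per cache slot (heads * d * d)

offset = slot_id * per_slot_stride
print(offset)                     # 3145728000
print(offset > 2**31 - 1)         # True: too large for a signed 32-bit index
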

vllm/model_executor/layers/mamba/mamba_utils.py

Lines changed: 11 additions & 0 deletions
@@ -5,6 +5,17 @@
 
 class MambaStateShapeCalculator:
 
+    @classmethod
+    def linear_attention_state_shape(
+        cls,
+        num_heads: int,
+        tp_size: int,
+        head_dim: int,
+    ) -> tuple[tuple[int, int, int], ...]:
+
+        state_shape = (num_heads // tp_size, head_dim, head_dim)
+        return (state_shape, )
+
     @classmethod
     def mamba1_state_shape(
         cls,
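
For reference, the new helper returns a one-element tuple holding the per-layer state shape (heads per TP rank, head_dim, head_dim). A standalone sketch of its behaviour with illustrative numbers (not the real MiniMax-Text-01 values):

def linear_attention_state_shape(num_heads: int, tp_size: int, head_dim: int):
    # Same computation as the classmethod added above.
    state_shape = (num_heads // tp_size, head_dim, head_dim)
    return (state_shape, )

print(linear_attention_state_shape(num_heads=64, tp_size=8, head_dim=128))
# ((8, 128, 128),)
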

vllm/model_executor/models/minimax_text_01.py

Lines changed: 152 additions & 40 deletions
@@ -14,8 +14,9 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
@@ -33,6 +34,9 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,8 +45,9 @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
-from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+from .interfaces import HasInnerState, IsHybrid
 from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
@@ -327,7 +332,17 @@ def jit_linear_forward_prefix(q: torch.Tensor,
         return rearrange(output.squeeze(0), "h n d -> n (h d)")
 
 
-class MiniMaxText01LinearAttention(nn.Module):
+class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+    @property
+    def mamba_type(self) -> str:
+        return "linear_attention"
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=self.num_heads,
+            tp_size=self.tp_size,
+            head_dim=self.head_dim)
 
     def __init__(
         self,
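
Deriving from MambaBase lets the layer describe its recurrent state through mamba_type and get_state_shape so that V1's hybrid cache handling can size it. The sketch below shows one way such an interface could be consumed; the plan_state_bytes helper is hypothetical and not part of vLLM:

from math import prod

def plan_state_bytes(layers, dtype_bytes: int = 4) -> int:
    """Hypothetical planner: total state bytes per cache slot across layers."""
    total = 0
    for layer in layers:
        if getattr(layer, "mamba_type", None) != "linear_attention":
            continue
        total += sum(prod(shape) for shape in layer.get_state_shape()) * dtype_bytes
    return total
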
@@ -359,6 +374,7 @@ def __init__(
         self.tp_heads = self.total_num_heads // self.tp_size
         self.qkv_size = self.num_heads * self.head_dim
         self.tp_hidden = self.head_dim * self.tp_heads
+        self.prefix = prefix
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -397,6 +413,12 @@ def __init__(
                                         self.tp_heads:(self.tp_rank + 1) *
                                         self.tp_heads].contiguous()
 
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
     @staticmethod
     def weight_direct_load(param: torch.Tensor,
                            loaded_weight: torch.Tensor) -> None:
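
Registering the layer under its prefix in static_forward_context is what later lets forward() look up its own LinearAttentionMetadata in the per-layer metadata dict. A toy sketch of the registration pattern (the Registry class below is a simplified stand-in, not vLLM's CompilationConfig):

class Registry:
    """Simplified stand-in for compilation_config.static_forward_context."""

    def __init__(self) -> None:
        self.static_forward_context: dict[str, object] = {}

    def register(self, prefix: str, layer: object) -> None:
        if prefix in self.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        self.static_forward_context[prefix] = layer

registry = Registry()
registry.register("model.layers.0.self_attn", object())    # ok
# registering the same prefix again would raise ValueError
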
@@ -434,13 +456,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            _start = attn_metadata.query_start_loc[_prefill_idx]
-            _end = attn_metadata.query_start_loc[_prefill_idx + 1]
-            slot_id = state_indices_tensor[_prefill_idx]
+            # prefills are packed at end of batch in V1
+            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+            slot_id = state_indices_tensor[offset + _prefill_idx]
             qs = q[_start:_end].transpose(0, 1).contiguous()
             ks = k[_start:_end].transpose(0, 1).contiguous()
             vs = v[_start:_end].transpose(0, 1).contiguous()
-            slot_id = state_indices_tensor[_prefill_idx]
             slice_layer_cache = kv_cache[slot_id, ...]
 
             out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 layer_idx=self.layer_idx)
             hidden.append(out_slice.contiguous())
         if attn_metadata.num_decode_tokens > 0:
-            hidden.append(
-                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
-                                   attn_metadata))
+            hidden_decode = self._decode_infer(q, k, v, kv_cache,
+                                               state_indices_tensor,
+                                               attn_metadata)
+            if envs.VLLM_USE_V1:
+                hidden.insert(0, hidden_decode)
+            else:
+                hidden.append(hidden_decode)
 
         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -465,11 +492,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
 
     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-        slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
-                                               ):]
+        if not envs.VLLM_USE_V1:
+            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            slot_id = state_indices_tensor[num_prefills:]
+        else:
+            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden
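
Both _prefill_and_mix_infer and _decode_infer pivot on the batch layout: V0 flattens prefill tokens first and decode tokens last, while V1 puts decode tokens first and prefill requests after them. That is why the prefill loop above offsets query_start_loc by num_decode_tokens and why the decode output is inserted at position 0 of hidden. A small indexing sketch with illustrative token counts:

# Illustrative: 2 single-token decode requests and 1 prefill of 5 tokens.
num_decode_tokens, num_prefill_tokens = 2, 5
tokens = list(range(num_decode_tokens + num_prefill_tokens))

# V0 layout: [p p p p p | d d] -> decode slice starts after the prefills.
v0_decode = tokens[num_prefill_tokens:]

# V1 layout: [d d | p p p p p] -> decode slice is the head of the batch,
# and prefill i is found at query_start_loc[num_decode_tokens + i].
v1_decode = tokens[:num_decode_tokens]

print(v0_decode, v1_decode)  # [5, 6] [0, 1]
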
@@ -483,17 +516,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        kv_cache = kv_caches.minimax_cache
-        state_indices_tensor = kv_caches.state_indices_tensor
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                assert isinstance(attn_metadata, LinearAttentionMetadata)
+                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+
+                num_prefills = getattr(attn_metadata, "num_prefills", 0)
+                if num_prefills > 0:
+                    num_decode_tokens = getattr(attn_metadata,
+                                                "num_decode_tokens", 0)
+                    for prefill_idx in range(num_prefills):
+                        q_start = attn_metadata.query_start_loc[
+                            num_decode_tokens + prefill_idx]
+                        q_end = attn_metadata.query_start_loc[num_decode_tokens
+                                                              + prefill_idx +
+                                                              1]
+                        query_len = q_end - q_start
+                        context_len = attn_metadata.seq_lens[
+                            num_decode_tokens + prefill_idx] - query_len
+                        if context_len == 0:
+                            block_to_clear = state_indices_tensor[
+                                num_decode_tokens + prefill_idx]
+                            kv_cache[block_to_clear, ...] = 0
+        else:
+            kv_cache = kv_caches.minimax_cache
+            state_indices_tensor = kv_caches.state_indices_tensor
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        if not decode_only:
-            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                 state_indices_tensor,
-                                                 attn_metadata)
+        if attn_metadata is None:
+            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+                                 device=q.device,
+                                 dtype=q.dtype)
         else:
-            hidden = self._decode_infer(q, k, v, kv_cache,
-                                        state_indices_tensor, attn_metadata)
+            if not decode_only:
+                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                     state_indices_tensor,
+                                                     attn_metadata)
+            else:
+                hidden = self._decode_infer(q, k, v, kv_cache,
+                                            state_indices_tensor,
+                                            attn_metadata)
 
         hidden = self.norm._forward(hidden)
         gate, _ = self.output_gate(hidden_states)
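
The new V1 branch also zeroes the recurrent-state block of any prefill that starts with no prior context (context_len == 0), since in V1 the layer reuses persistent cache slots instead of receiving a freshly cleared buffer from MinimaxCacheManager. A small sketch of that bookkeeping; the tensors mirror the metadata fields used above and the values are illustrative:

import torch

# One decode request followed by one fresh prefill (illustrative values).
query_start_loc = torch.tensor([0, 1, 6])   # decode: tokens [0, 1); prefill: [1, 6)
seq_lens = torch.tensor([10, 5])            # total sequence length per request
state_indices = torch.tensor([3, 7])        # cache slot per request
kv_cache = torch.randn(8, 4, 16, 16)        # (slots, heads, head_dim, head_dim)

num_decode_tokens, prefill_idx = 1, 0
q_start = query_start_loc[num_decode_tokens + prefill_idx]
q_end = query_start_loc[num_decode_tokens + prefill_idx + 1]
context_len = seq_lens[num_decode_tokens + prefill_idx] - (q_end - q_start)
if context_len == 0:                        # fresh request: clear its state slot
    kv_cache[state_indices[num_decode_tokens + prefill_idx], ...] = 0
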
@@ -541,6 +606,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
+        self.prefix = prefix
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -575,7 +641,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = attn_metadata.rotary_emb(positions, q, k)
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
+                    positions, q, k)
+        else:
+            q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -595,6 +666,7 @@ def __init__(
     ) -> None:
         self._ilayer = layer_id
         self._irank = get_tensor_model_parallel_rank()
+        self.prefix = prefix
         super().__init__()
 
         self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ def layer_fn(prefix):
         self._dtype = _dummy.dtype
         del _dummy
 
-        self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
-                                                 cache_shape=self.cache_shape)
+        if not envs.VLLM_USE_V1:
+            self.minimax_cache = MinimaxCacheManager(
+                dtype=torch.float32, cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
         head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ def forward(self,
                 **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
-        if attn_metadata is None:
+        if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
         if "request_ids_to_seq_ids" not in kwargs:
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
 
-        (
-            minimax_cache_tensors,
-            state_indices_tensor,
-        ) = self.minimax_cache.current_run_tensors(**kwargs)
-        if getattr(attn_metadata, "num_prefills", 0) > 0:
-            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
-                                      **kwargs)
+        if not envs.VLLM_USE_V1:
+            (
+                minimax_cache_tensors,
+                state_indices_tensor,
+            ) = self.minimax_cache.current_run_tensors(**kwargs)
+            if getattr(attn_metadata, "num_prefills", 0) > 0:
+                self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
+                                          **kwargs)
+
+            minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
+                                                      state_indices_tensor)
+        else:
+            minimax_cache_params = None
 
-        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
-                                                  state_indices_tensor)
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
                 hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ def forward(self,
             residual = intermediate_tensors["residual"]
 
         minimax_cache_index = 0
-        attn_metadata.rotary_emb = self.rotary_emb
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if attn_metadata is not None:
+                # TODO (tdoublep): this whole thing with the rotary_emb is
+                # weird. we shouldn't be passing it via attn_metadata imo.
+                if envs.VLLM_USE_V1:
+                    if isinstance(layer.self_attn, MiniMaxText01Attention):
+                        attn_metadata[layer.prefix +
+                                      ".attn"].rotary_emb = self.rotary_emb
+                else:
+                    attn_metadata.rotary_emb = self.rotary_emb
+
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
+            if not envs.VLLM_USE_V1 and isinstance(
+                    layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(
                     current_state_layer)
@@ -1002,8 +1090,7 @@ def forward(self,
         return hidden_states
 
 
-class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
-                               SupportsV0Only):
+class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
@@ -1321,3 +1408,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
 
             load_basic_weight(name, loaded_weight, self)
         return loaded_params
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, ...], ...]:
+        """Calculate shape for MiniMaxText01LinearAttention cache.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - state_shape: Shape of the cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=hf_config.num_attention_heads,
+            tp_size=parallel_config.tensor_parallel_size,
+            head_dim=hf_config.head_dim,
+        )
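
The shape returned here determines the per-slot size of the linear-attention state, so a rough footprint estimate follows directly from it. The head count, head size, TP degree, and dtype below are illustrative placeholders, not the actual MiniMax-Text-01 configuration:

num_heads, tp_size, head_dim = 64, 8, 128   # illustrative, not the real config
dtype_bytes = 4                             # float32, matching the V0 cache dtype above

heads_per_rank = num_heads // tp_size
state_bytes = heads_per_rank * head_dim * head_dim * dtype_bytes
print(state_bytes)                          # 524288 bytes (~0.5 MiB) per layer per slot
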
