from torch import nn
from transformers.configuration_utils import PretrainedConfig

+ from vllm import envs
from vllm.attention import Attention, AttentionMetadata
- from vllm.config import CacheConfig, VllmConfig
+ from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank,
⋮
    ReplicatedLinear,
    RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+ from vllm.model_executor.layers.mamba.abstract import MambaBase
+ from vllm.model_executor.layers.mamba.mamba_utils import (
+     MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
⋮
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+ from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

- from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+ from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

@@ -327,7 +332,17 @@ def jit_linear_forward_prefix(q: torch.Tensor,
        return rearrange(output.squeeze(0), "h n d -> n (h d)")


- class MiniMaxText01LinearAttention(nn.Module):
+ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+     @property
+     def mamba_type(self) -> str:
+         return "linear_attention"
+
+     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+         return MambaStateShapeCalculator.linear_attention_state_shape(
+             num_heads=self.num_heads,
+             tp_size=self.tp_size,
+             head_dim=self.head_dim)

    def __init__(
        self,
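
Subclassing MambaBase is what lets the V1 engine treat this layer's recurrent state like any other mamba-style state: `mamba_type` tags the backend and `get_state_shape` reports the per-layer buffer to allocate. A rough sketch of what that shape computation presumably amounts to, assuming heads are sharded across tensor-parallel ranks and each local head keeps one head_dim x head_dim recurrent KV state (the function name below is made up; the real helper is `MambaStateShapeCalculator.linear_attention_state_shape`):

# Illustrative sketch only, not part of the diff: assumed shape math for the
# linear-attention state; the real logic lives in
# MambaStateShapeCalculator.linear_attention_state_shape.
def sketch_linear_attention_state_shape(num_heads: int, tp_size: int,
                                        head_dim: int) -> tuple[int, ...]:
    local_heads = num_heads // tp_size          # heads handled by this TP rank
    return (local_heads, head_dim, head_dim)    # one KV state matrix per local head
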
@@ -359,6 +374,7 @@ def __init__(
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads
+         self.prefix = prefix

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
@@ -397,6 +413,12 @@ def __init__(
                                        self.tp_heads:(self.tp_rank + 1) *
                                        self.tp_heads].contiguous()

+         if envs.VLLM_USE_V1:
+             compilation_config = get_current_vllm_config().compilation_config
+             if prefix in compilation_config.static_forward_context:
+                 raise ValueError(f"Duplicate layer name: {prefix}")
+             compilation_config.static_forward_context[prefix] = self
+
    @staticmethod
    def weight_direct_load(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
@@ -434,13 +456,14 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
-             _start = attn_metadata.query_start_loc[_prefill_idx]
-             _end = attn_metadata.query_start_loc[_prefill_idx + 1]
-             slot_id = state_indices_tensor[_prefill_idx]
+             # prefills are packed at end of batch in V1
+             offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+             _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+             _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+             slot_id = state_indices_tensor[offset + _prefill_idx]
            qs = q[_start:_end].transpose(0, 1).contiguous()
            ks = k[_start:_end].transpose(0, 1).contiguous()
            vs = v[_start:_end].transpose(0, 1).contiguous()
-             slot_id = state_indices_tensor[_prefill_idx]
            slice_layer_cache = kv_cache[slot_id, ...]

            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                layer_idx=self.layer_idx)
            hidden.append(out_slice.contiguous())
        if attn_metadata.num_decode_tokens > 0:
-             hidden.append(
-                 self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
-                                    attn_metadata))
+             hidden_decode = self._decode_infer(q, k, v, kv_cache,
+                                                state_indices_tensor,
+                                                attn_metadata)
+             if envs.VLLM_USE_V1:
+                 hidden.insert(0, hidden_decode)
+             else:
+                 hidden.append(hidden_decode)

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
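
The reshuffling above follows the V1 batch layout: decode tokens come first and prefill requests are packed at the end of the batch, so prefill lookups into `query_start_loc` are shifted by `num_decode_tokens` and the decode output is inserted at the front to keep the concatenated hidden states in token order. A small worked sketch of that indexing, with made-up sizes:

# Illustrative sketch only, not part of the diff: a hypothetical V1 batch with
# 3 decode requests (one token each) followed by prefills of length 4 and 2.
import torch

num_decode_tokens = 3                               # also 3 decode requests here
query_start_loc = torch.tensor([0, 1, 2, 3, 7, 9])  # cumulative starts per request

offset = num_decode_tokens    # prefill entries sit after the decode entries
for prefill_idx in range(2):
    start = int(query_start_loc[offset + prefill_idx])
    end = int(query_start_loc[offset + prefill_idx + 1])
    print(f"prefill {prefill_idx}: tokens [{start}:{end}]")
# prefill 0: tokens [3:7]
# prefill 1: tokens [7:9]
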
@@ -465,11 +492,17 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
-         q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-         k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-         v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-         slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
-                                                ):]
+         if not envs.VLLM_USE_V1:
+             q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+             k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+             v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+             num_prefills = getattr(attn_metadata, "num_prefills", 0)
+             slot_id = state_indices_tensor[num_prefills:]
+         else:
+             q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+             k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+             v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+             slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden
@@ -483,17 +516,49 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
-         kv_cache = kv_caches.minimax_cache
-         state_indices_tensor = kv_caches.state_indices_tensor
+         if envs.VLLM_USE_V1:
+             if attn_metadata is not None:
+                 assert isinstance(attn_metadata, dict)
+                 attn_metadata = attn_metadata[self.prefix]
+                 assert isinstance(attn_metadata, LinearAttentionMetadata)
+                 kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+                 state_indices_tensor = attn_metadata.state_indices_tensor
+
+                 num_prefills = getattr(attn_metadata, "num_prefills", 0)
+                 if num_prefills > 0:
+                     num_decode_tokens = getattr(attn_metadata,
+                                                 "num_decode_tokens", 0)
+                     for prefill_idx in range(num_prefills):
+                         q_start = attn_metadata.query_start_loc[
+                             num_decode_tokens + prefill_idx]
+                         q_end = attn_metadata.query_start_loc[num_decode_tokens
+                                                               + prefill_idx +
+                                                               1]
+                         query_len = q_end - q_start
+                         context_len = attn_metadata.seq_lens[
+                             num_decode_tokens + prefill_idx] - query_len
+                         if context_len == 0:
+                             block_to_clear = state_indices_tensor[
+                                 num_decode_tokens + prefill_idx]
+                             kv_cache[block_to_clear, ...] = 0
+         else:
+             kv_cache = kv_caches.minimax_cache
+             state_indices_tensor = kv_caches.state_indices_tensor

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-         if not decode_only:
-             hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                  state_indices_tensor,
-                                                  attn_metadata)
+         if attn_metadata is None:
+             hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+                                  device=q.device,
+                                  dtype=q.dtype)
        else:
-             hidden = self._decode_infer(q, k, v, kv_cache,
-                                         state_indices_tensor, attn_metadata)
+             if not decode_only:
+                 hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                      state_indices_tensor,
+                                                      attn_metadata)
+             else:
+                 hidden = self._decode_infer(q, k, v, kv_cache,
+                                             state_indices_tensor,
+                                             attn_metadata)

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states)
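
Under V1, `forward_context.attn_metadata` is a dict keyed by layer prefix, which is why the layer registered itself in `static_forward_context` under that same prefix during `__init__`: each linear-attention layer fetches its own `LinearAttentionMetadata` and its slice of the engine-managed cache, and zeroes the state block for any sequence starting a fresh prefill (`context_len == 0`). A toy sketch of the prefix-keyed registration and lookup pattern (names are illustrative, not the vLLM API):

# Illustrative sketch only, not part of the diff: toy prefix-keyed dispatch.
class ToyLayer:
    def __init__(self, prefix: str, registry: dict):
        if prefix in registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry[prefix] = self        # mirrors static_forward_context[prefix]
        self.prefix = prefix

    def forward(self, metadata_by_prefix: dict):
        # Each layer sees only the metadata prepared for it.
        return metadata_by_prefix[self.prefix]

registry: dict = {}
layer = ToyLayer("model.layers.0.self_attn", registry)
print(layer.forward({"model.layers.0.self_attn": "linear-attn metadata"}))
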
@@ -541,6 +606,7 @@ def __init__(
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
+         self.prefix = prefix

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
@@ -575,7 +641,12 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
        attn_metadata = forward_context.attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-         q, k = attn_metadata.rotary_emb(positions, q, k)
+         if envs.VLLM_USE_V1:
+             if attn_metadata is not None:
+                 q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
+                     positions, q, k)
+         else:
+             q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
@@ -595,6 +666,7 @@ def __init__(
    ) -> None:
        self._ilayer = layer_id
        self._irank = get_tensor_model_parallel_rank()
+         self.prefix = prefix
        super().__init__()

        self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ def layer_fn(prefix):
        self._dtype = _dummy.dtype
        del _dummy

-         self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
-                                                  cache_shape=self.cache_shape)
+         if not envs.VLLM_USE_V1:
+             self.minimax_cache = MinimaxCacheManager(
+                 dtype=torch.float32, cache_shape=self.cache_shape)

        rope_theta = getattr(config, "rope_theta", 10000)
        head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ def forward(self,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
-         if attn_metadata is None:
+         if not envs.VLLM_USE_V1 and attn_metadata is None:
            return None
        if "request_ids_to_seq_ids" not in kwargs:
            kwargs["request_ids_to_seq_ids"] = {}
        if "finished_requests_ids" not in kwargs:
            kwargs["finished_requests_ids"] = []

-         (
-             minimax_cache_tensors,
-             state_indices_tensor,
-         ) = self.minimax_cache.current_run_tensors(**kwargs)
-         if getattr(attn_metadata, "num_prefills", 0) > 0:
-             self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
-                                       **kwargs)
+         if not envs.VLLM_USE_V1:
+             (
+                 minimax_cache_tensors,
+                 state_indices_tensor,
+             ) = self.minimax_cache.current_run_tensors(**kwargs)
+             if getattr(attn_metadata, "num_prefills", 0) > 0:
+                 self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
+                                           **kwargs)
+
+             minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
+                                                       state_indices_tensor)
+         else:
+             minimax_cache_params = None

-         minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
-                                                   state_indices_tensor)
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ def forward(self,
            residual = intermediate_tensors["residual"]

        minimax_cache_index = 0
-         attn_metadata.rotary_emb = self.rotary_emb
+
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
+             if attn_metadata is not None:
+                 # TODO (tdoublep): this whole thing with the rotary_emb is
+                 # weird. we shouldn't be passing it via attn_metadata imo.
+                 if envs.VLLM_USE_V1:
+                     if isinstance(layer.self_attn, MiniMaxText01Attention):
+                         attn_metadata[layer.prefix +
+                                       ".attn"].rotary_emb = self.rotary_emb
+                 else:
+                     attn_metadata.rotary_emb = self.rotary_emb
+
            _caches = None
-             if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
+             if not envs.VLLM_USE_V1 and isinstance(
+                     layer.self_attn, MiniMaxText01LinearAttention):
                current_state_layer = minimax_cache_index
                _caches = minimax_cache_params.at_layer_idx(
                    current_state_layer)
@@ -1002,8 +1090,7 @@ def forward(self,
        return hidden_states


- class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
-                                SupportsV0Only):
+ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

@@ -1321,3 +1408,28 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,

            load_basic_weight(name, loaded_weight, self)
        return loaded_params
+
+     @classmethod
+     def get_mamba_state_shape_from_config(
+         cls,
+         vllm_config: "VllmConfig",
+         use_v1: bool = True,
+     ) -> tuple[tuple[int, ...], ...]:
+         """Calculate shape for MiniMaxText01LinearAttention cache.
+
+         Args:
+             vllm_config: vLLM config
+             use_v1: Get shapes for V1 (or V0)
+
+         Returns:
+             Tuple containing:
+             - state_shape: Shape of the cache
+         """
+         parallel_config = vllm_config.parallel_config
+         hf_config = vllm_config.model_config.hf_config
+
+         return MambaStateShapeCalculator.linear_attention_state_shape(
+             num_heads=hf_config.num_attention_heads,
+             tp_size=parallel_config.tensor_parallel_size,
+             head_dim=hf_config.head_dim,
+         )
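
The new classmethod lets the state shape be computed straight from the model config, before any layer is built. A hedged usage sketch with duck-typed stand-ins for the config objects (the sizes are placeholders, not MiniMax-Text-01's real values):

# Illustrative sketch only, not part of the diff: stand-in config objects with
# made-up sizes; a real caller would pass the engine's actual VllmConfig.
from types import SimpleNamespace

fake_vllm_config = SimpleNamespace(
    parallel_config=SimpleNamespace(tensor_parallel_size=8),
    model_config=SimpleNamespace(
        hf_config=SimpleNamespace(num_attention_heads=64, head_dim=128)))

# shapes = MiniMaxText01ForCausalLM.get_mamba_state_shape_from_config(
#     fake_vllm_config)
# Each returned tuple describes one per-layer state buffer to allocate for the
# linear-attention layers.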