@@ -6,10 +6,6 @@
 from torch import nn
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.placeholder_attn import (
-    PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -18,6 +14,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
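For orientation, here is a rough sketch of what the newly imported Mamba2Metadata container appears to carry, inferred purely from the attributes this diff accesses on it; the real definition in vllm/model_executor/layers/mamba/mamba2_metadata.py may differ.

# Hypothetical sketch of the Mamba2Metadata container, inferred from the
# attribute accesses in this diff (has_prefill, has_initial_states,
# prep_initial_states, chunk_size, seq_idx, chunk_indices, chunk_offsets).
# Not the actual vLLM definition.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2MetadataSketch:
    has_prefill: bool                           # batch contains prefill tokens
    has_initial_states: Optional[torch.Tensor]  # per-request bool mask of cached context
    prep_initial_states: bool                   # any request carries prior state
    chunk_size: int                             # scan chunk size shared by all layers
    seq_idx: Optional[torch.Tensor]             # sequence index per packed prefill token
    chunk_indices: Optional[torch.Tensor]       # chunk bookkeeping for the varlen scan
    chunk_offsets: Optional[torch.Tensor]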
@@ -221,7 +218,6 @@ def __init__(self,
                  head_dim: int = 64,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 chunk_size: int = 256,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
@@ -257,7 +253,6 @@ def __init__(self,
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
-        self.chunk_size = chunk_size
         self.intermediate_size = intermediate_size
         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -388,25 +383,17 @@ def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
     ):
+        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # kernels to operate in continuous batching and in chunked prefill
+        # modes; they are computed at top-level model forward since they
+        # are the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # detect if there are prefills
-        has_prefill = attn_metadata.num_prefills > 0
-
-        # - also need flags to indicate if there are initial states
-        # - currently we really only support the FlashAttention backend
-        has_initial_states = None
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
-                and attn_metadata.context_lens_tensor is not None):
-            has_initial_states = attn_metadata.context_lens_tensor > 0
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
         gate, hidden_states_B_C, dt = torch.split(
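The new comment states that the metadata is computed once in the top-level model forward and shared by every Mamba layer in the iteration. Below is a runnable toy illustration, not the PR's actual helper, of how the per-batch flags that moved out of this layer can be derived from typical attention-metadata fields, mirroring the removed logic; the tensor values are made up.

# Toy illustration of deriving the shared per-batch metadata once per step.
import torch

num_prefills = 2                                   # e.g. two prefill requests in the batch
context_lens = torch.tensor([0, 5, 0])             # tokens already cached per request
query_start_loc = torch.tensor([0, 3, 7, 8])       # cumulative token offsets of the packed batch

has_prefill = num_prefills > 0                     # was: attn_metadata.num_prefills > 0
has_initial_states = context_lens > 0              # was: attn_metadata.context_lens_tensor > 0
prep_initial_states = bool(torch.any(has_initial_states))

# seq_idx labels every packed prefill token with the index of its request,
# e.g. [0, 0, 0, 1, 1, 1, 1, 2] for the offsets above.
seq_lens = query_start_loc[1:] - query_start_loc[:-1]
seq_idx = torch.repeat_interleave(
    torch.arange(len(seq_lens), dtype=torch.int32), seq_lens)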
@@ -423,7 +410,7 @@ def forward_cuda(
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if has_prefill:
+        if mamba2_metadata.has_prefill:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -439,7 +426,7 @@ def forward_cuda(
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=mamba_cache_params.conv_state,
-                has_initial_state=has_initial_states,
+                has_initial_state=mamba2_metadata.has_initial_states,
                 cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc).transpose(
                     0, 1)[:seq_len]
@@ -467,16 +454,15 @@ def forward_cuda(
         )
 
         # 3. State Space Model sequence transformation
-        if has_prefill:
-
+        if mamba2_metadata.has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                zero_init_indices = mamba_cache_params.state_indices_tensor[
-                    ~has_initial_states]
-                mamba_cache_params.ssm_state[zero_init_indices] = 0
-                initial_states = mamba_cache_params.ssm_state[
-                    mamba_cache_params.state_indices_tensor]
+            if (mamba2_metadata.has_initial_states is not None
+                    and mamba2_metadata.prep_initial_states):
+                # making a copy of the states
+                initial_states = torch.where(
+                    mamba2_metadata.has_initial_states[:, None, None, None],
+                    mamba_cache_params.ssm_state[
+                        mamba_cache_params.state_indices_tensor], 0)
 
         scan_output, varlen_state = mamba_chunk_scan_combined(
             hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
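A small runnable example of the masking semantics used above: torch.where with the broadcast [batch] mask keeps the gathered cache entries for requests that have prior context and substitutes zeros for the rest, producing a copy rather than zeroing the cache in place as the removed code did. The shapes below are made up for the example.

# Illustration of the torch.where-based copy of initial states.
import torch

ssm_state = torch.randn(4, 2, 3, 5)              # [cache_slots, heads, head_dim, dstate]
state_indices = torch.tensor([2, 0, 3])          # cache slot assigned to each request
has_initial_states = torch.tensor([True, False, True])

initial_states = torch.where(
    has_initial_states[:, None, None, None],     # broadcast [batch] -> [batch, 1, 1, 1]
    ssm_state[state_indices],                    # gathered per-request cached states
    0)                                           # zeros where no prior context exists

assert torch.equal(initial_states[0], ssm_state[2])            # kept
assert torch.equal(initial_states[1], torch.zeros(2, 3, 5))    # zeroed in the copy only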
@@ -485,11 +471,13 @@ def forward_cuda(
             self.A,
             B.view(1, seq_len, self.n_groups // self.tp_size, -1),
             C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-            chunk_size=self.chunk_size,
+            chunk_size=mamba2_metadata.chunk_size,
             D=self.D,
             z=None,
             dt_bias=self.dt_bias,
-            seq_idx=sequence_idx,
+            seq_idx=mamba2_metadata.seq_idx,
+            chunk_indices=mamba2_metadata.chunk_indices,
+            chunk_offsets=mamba2_metadata.chunk_offsets,
             cu_seqlens=attn_metadata.query_start_loc,
             initial_states=initial_states,
             return_varlen_states=True,
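The extra chunk_indices / chunk_offsets arguments presumably carry bookkeeping for the case where, under continuous batching with chunked prefill, a sequence boundary falls in the middle of a fixed-size scan chunk. A purely illustrative toy computation of that misalignment (not the kernel's actual bookkeeping) is shown below.

# Toy illustration: where packed sequence starts land relative to chunk boundaries.
import torch

chunk_size = 8
query_start_loc = torch.tensor([0, 5, 12, 20])   # three packed prefill sequences
seq_starts = query_start_loc[:-1]

containing_chunk = seq_starts // chunk_size      # chunk each sequence starts in
offset_in_chunk = seq_starts % chunk_size        # how far into that chunk it starts
print(containing_chunk.tolist(), offset_in_chunk.tolist())   # [0, 0, 1] [0, 5, 4]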