
Commit b690e34

[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <[email protected]>
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 25373b6 commit b690e34
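
At a glance: before this change, the Mamba2 mixer collected the decode and prefill SSM outputs in a Python list and merged them with `torch.vstack`, paying an extra allocation plus a device-to-device copy on every forward pass. The commit instead preallocates one output tensor, splits it into decode/prefill views, and has the kernels write into those views through a new `out=` argument. A minimal standalone sketch of the idea (made-up sizes; `copy_` stands in for the kernel writes, this is not vLLM code):

```python
import torch

num_decode_tokens, num_prefill_tokens, hidden = 4, 12, 64

# Before: each phase allocates its own output, then vstack does a d2d copy.
out_d = torch.randn(num_decode_tokens, hidden)
out_p = torch.randn(num_prefill_tokens, hidden)
merged = torch.vstack([out_d, out_p])  # extra allocation + copy

# After: one buffer, split into views; kernels write in place via out=.
ssm_out = torch.empty(num_decode_tokens + num_prefill_tokens, hidden)
view_d, view_p = torch.split(ssm_out,
                             [num_decode_tokens, num_prefill_tokens],
                             dim=0)
view_d.copy_(out_d)  # stands in for selective_state_update(..., out=view_d)
view_p.copy_(out_p)  # stands in for mamba_chunk_scan_combined(..., out=view_p)

assert torch.equal(ssm_out, merged)  # same result, no final merge copy
```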

9 files changed: 144 additions, 118 deletions

tests/kernels/mamba/test_mamba_ssm.py

Lines changed: 40 additions & 34 deletions
@@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     batch_size = 1
     state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
     x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     dt = torch.randn(batch_size, dim, device=device, dtype=itype)
     dt_bias = torch.rand(dim, device=device) - 4.0
     A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     D = torch.randn(dim, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state.detach().clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x,
                                          dt,
@@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
         ],
         dim=0)
     x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
     dt_bias = torch.rand(dim, device=device) - 4.0
     A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -590,18 +593,19 @@
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].clone()
     state_before = state.clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True,
-                                 state_batch_indices=padded_state_indices,
-                                 pad_slot_id=PAD_SLOT_ID)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           state_batch_indices=padded_state_indices,
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x[:batch_size],
                                          dt[:batch_size],
@@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
                                  dtype=torch.int32, device=device)

     x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    out = torch.empty_like(x)
     if not tie_hdim:
         dt = torch.randn(batch_size,
                          nheads,
@@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
     C = torch.randn(batch_size, ngroups, dstate, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].detach().clone()
-    out = selective_state_update(state,
-                                 x,
-                                 dt,
-                                 A,
-                                 B,
-                                 C,
-                                 D=D,
-                                 z=z,
-                                 dt_bias=dt_bias,
-                                 dt_softplus=True,
-                                 state_batch_indices=state_indices,
-                                 pad_slot_id=PAD_SLOT_ID)
+    selective_state_update(state,
+                           x,
+                           dt,
+                           A,
+                           B,
+                           C,
+                           D=D,
+                           z=z,
+                           dt_bias=dt_bias,
+                           dt_softplus=True,
+                           state_batch_indices=state_indices,
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=out)
     out_ref = selective_state_update_ref(state_ref,
                                          x,
                                          dt,
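
The convention these tests now exercise, condensed (same tensor names as in the test above; the real kernel needs a CUDA device and Triton, so treat this as a sketch rather than a runnable unit test):

```python
out = torch.empty_like(x)          # caller-owned output, same shape as x
selective_state_update(state, x, dt, A, B, C,
                       D=D, z=z, dt_bias=dt_bias, dt_softplus=True,
                       out=out)    # kernel writes the result into `out`
# `out` now holds what the old API used to return; the function itself
# no longer returns the scan output.
```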

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 13 additions & 10 deletions
@@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,

     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)
-
-    Y, final_state = mamba_chunk_scan_combined(X,
-                                               dt,
-                                               A,
-                                               B,
-                                               C,
-                                               chunk_size,
-                                               D=None,
-                                               return_final_states=True)
+    Y = torch.empty_like(X)
+    final_state = mamba_chunk_scan_combined(X,
+                                            dt,
+                                            A,
+                                            B,
+                                            C,
+                                            chunk_size,
+                                            D=None,
+                                            return_final_states=True,
+                                            out=Y)

     # just test the last in sequence
     torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
@@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         _query_start_loc_to_chunk_indices_offsets(
             cu_seqlens, chunk_size, cu_seqlens[-1])

-    Y, new_states = mamba_chunk_scan_combined(
+    Y = torch.empty_like(X)
+    new_states = mamba_chunk_scan_combined(
         X,
         dt,
         A,
@@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         chunk_offsets=chunk_offsets,
         return_varlen_states=True,
         initial_states=states,
+        out=Y,
     )

     # just test the last in sequence
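
With `out=` in play, `mamba_chunk_scan_combined` no longer returns `Y`; the single return value is the requested state (final states here, varlen states in the continuous-batch test). A condensed sketch of the call shape, using the names from these tests with the argument list abbreviated:

```python
Y = torch.empty_like(X)
final_state = mamba_chunk_scan_combined(
    X, dt, A, B, C, chunk_size,
    D=None, return_final_states=True, out=Y)
# Y is filled in place; with return_varlen_states=True the varlen states
# come back instead, and Y is still written through `out`.
```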

vllm/model_executor/layers/mamba/mamba_mixer.py

Lines changed: 4 additions & 2 deletions
@@ -220,7 +220,8 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
-            scan_outputs = selective_state_update(
+            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
@@ -231,7 +232,8 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor)
+                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection
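
A note on the allocation above: `scan_outputs` is created with `empty_like` of the transposed view, which (assuming PyTorch's default `memory_format=torch.preserve_format` behavior for dense tensors) keeps the transposed strides, so the final `.transpose(0, 1)` back to token-major order is again just a view. A small sketch of that property:

```python
h = torch.randn(5, 7)                        # stand-in for hidden_states (tokens, dim)
buf = torch.empty_like(h.transpose(0, 1))    # shape (7, 5), strides preserved
assert buf.stride() == h.transpose(0, 1).stride()
assert buf.transpose(0, 1).is_contiguous()   # transposing back is a free view
```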

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 32 additions & 22 deletions
@@ -541,7 +541,6 @@ def forward_cuda(
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                 hidden_states_B_C[:num_actual_tokens],
@@ -583,7 +582,28 @@ def forward_cuda(
                                     1]
                                 if has_prefill else None)

-        ssd_output_list = []
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+        preallocated_ssm_out = torch.empty(
+            [
+                num_prefill_tokens + num_decodes,
+                (self.num_heads // self.tp_size) * self.head_dim
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        if envs.VLLM_USE_V1:
+            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+                preallocated_ssm_out,
+                [num_decodes, num_prefill_tokens],
+                dim=0,
+            )
+        else:
+            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+                preallocated_ssm_out,
+                [num_prefill_tokens, num_decodes],
+                dim=0,
+            )

         # Process prefill requests
         if has_prefill:
@@ -623,7 +643,8 @@
                     has_initial_states_p[:num_prefills, None, None, None],
                     ssm_state[state_indices_tensor_p], 0)

-            scan_output, varlen_state = mamba_chunk_scan_combined(
+            # NOTE: final output is an in-place update of out tensor
+            varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
@@ -646,15 +667,14 @@
                 return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                                                self.head_dim),
             )

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
             ssm_state[state_indices_tensor_p] = varlen_state

-            # - reshape
-            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
-
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
@@ -684,8 +704,8 @@
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
-
-            hidden_states_d = selective_state_update(
+            # NOTE: final output is an in-place update of out tensor
+            selective_state_update(
                 ssm_state,
                 hidden_states_d,
                 dt_d,
@@ -697,26 +717,16 @@
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1,
+                                                self.head_dim),
             )

-            if envs.VLLM_USE_V1:
-                ssd_output_list.insert(
-                    0,
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-            else:
-                ssd_output_list.append(
-                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                         self.head_dim))
-
-        # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(ssd_output_list)
-
         # 4. gated MLP
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
+        hidden_states = self.norm(preallocated_ssm_out,
+                                  gate[:num_actual_tokens])

         # 5. Final linear projection
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)
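
The split ordering mirrors how each engine lays out the flattened batch (V1 puts decode tokens first, V0 puts prefill first), so each kernel writes directly into the correct region of the shared buffer and the old `torch.vstack` merge disappears. A standalone sketch of the V1 layout, with illustrative sizes only:

```python
num_decodes, num_prefill_tokens, nheads, head_dim = 3, 8, 4, 16
buf = torch.empty(num_decodes + num_prefill_tokens, nheads * head_dim)

# V1 ordering: decode tokens first; V0 would swap the two split sizes.
out_d, out_p = torch.split(buf, [num_decodes, num_prefill_tokens], dim=0)

# The reshapes handed to the kernels are copy-free views, because each
# split half is a contiguous slice of the parent buffer.
decode_view = out_d.view(num_decodes, -1, head_dim)
prefill_view = out_p.view(1, num_prefill_tokens, -1, head_dim)
assert decode_view.data_ptr() == buf.data_ptr()  # decode half starts at buf's base
```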

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 8 additions & 8 deletions
@@ -205,7 +205,8 @@ def selective_state_update(state,
                            dt_bias=None,
                            dt_softplus=False,
                            state_batch_indices=None,
-                           pad_slot_id=PAD_SLOT_ID):
+                           pad_slot_id=PAD_SLOT_ID,
+                           out=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -223,10 +224,9 @@
                      for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
                      in this case, the kernel will not process entries at
                      indices 0 and 3
-    Return:
-        out: (batch, dim) or (batch, nheads, dim)
+        out: Preallocated ssm output tensor. Assume same shape as x.
+             In-place updated.
     """
-    has_heads = state.dim() > 3
     if state.dim() == 3:
         state = state.unsqueeze(1)
     if x.dim() == 2:
@@ -245,6 +245,8 @@
         z = z.unsqueeze(1)
     if dt_bias is not None and dt_bias.dim() == 1:
         dt_bias = dt_bias.unsqueeze(0)
+    if out.dim() == 2:
+        out = out.unsqueeze(1)

     _, nheads, dim, dstate = state.shape
     batch = x.shape[0]
@@ -264,7 +266,8 @@
         assert dt_bias.shape == (nheads, dim)
     if state_batch_indices is not None:
         assert state_batch_indices.shape == (batch, )
-    out = torch.empty_like(x)
+    assert out.shape == x.shape
+
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                  (0, 0, 0))
@@ -328,9 +331,6 @@
         BLOCK_SIZE_M,
         num_warps=num_warps,
     )
-    if not has_heads:
-        out = out.squeeze(1)
-    return out


 def selective_scan_fn(u,
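
Usage under the new signature, as a sketch (shapes taken from the docstring; the Triton kernel requires a CUDA device). Note that although `out` defaults to `None`, the body now dereferences it unconditionally (`out.dim()`, `assert out.shape == x.shape`), so callers are effectively required to pass a preallocated tensor:

```python
batch, dim, dstate = 2, 64, 16
device = "cuda"
state = torch.zeros(batch, dim, dstate, device=device)
x = torch.randn(batch, dim, device=device)
dt = torch.randn(batch, dim, device=device)
A = -torch.rand(dim, dstate, device=device)
B = torch.randn(batch, dstate, device=device)
C = torch.randn(batch, dstate, device=device)

out = torch.empty_like(x)
selective_state_update(state, x, dt, A, B, C, out=out)
# `state` is updated in place and the scan output lands in `out`.
```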

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 5 additions & 14 deletions
@@ -454,6 +454,7 @@ def _chunk_scan_fwd(
     chunk_indices=None,
     chunk_offsets=None,
     initial_states=None,
+    out=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
     _, _, nchunks, chunk_size = dt.shape
@@ -483,20 +484,10 @@
     else:
         chunk_indices, chunk_offsets = None, None

-    # Allocates output.
-    out = torch.empty(batch,
-                      seqlen,
-                      nheads,
-                      headdim,
-                      device=x.device,
-                      dtype=x.dtype)
+    assert out.shape == x.shape
+
     if z is not None:
-        out_x = torch.empty(batch,
-                            seqlen,
-                            nheads,
-                            headdim,
-                            device=x.device,
-                            dtype=x.dtype)
+        out_x = torch.empty_like(x)
         assert out_x.stride() == out.stride()
     else:
         out_x = None
@@ -579,4 +570,4 @@
         IS_TRITON_22=TRITON_22,
         HAS_INITSTATES=initial_states is not None,
     )
-    return out, out_x
+    return out_x
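
One consequence of dropping the internal allocation: when `z` is given, `out_x` is still allocated internally with `empty_like(x)`, and the pre-existing `assert out_x.stride() == out.stride()` now constrains the caller: the `out` view passed in must have the same layout as a freshly allocated tensor of `x`'s shape. A small sketch of the invariant, with illustrative shapes:

```python
x = torch.randn(1, 128, 8, 64)   # (batch, seqlen, nheads, headdim)
out = torch.empty_like(x)        # caller's buffer, same layout as x
out_x = torch.empty_like(x)      # what _chunk_scan_fwd allocates when z is set
assert out.shape == x.shape
assert out_x.stride() == out.stride()   # holds for matching contiguous layouts
```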
