
Commit c75c2e7

heheda12345 authored and simon-mo committed
[Deepseek v3.2] Support indexer prefill chunking (#25999)
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: simon-mo <[email protected]>
1 parent 9d9a2b7 commit c75c2e7

File tree

3 files changed: +149 −79 lines changed

tests/v1/attention/test_sparse_mla_backends.py
vllm/model_executor/models/deepseek_v2.py
vllm/v1/attention/backends/mla/indexer.py

tests/v1/attention/test_sparse_mla_backends.py

Lines changed: 22 additions & 0 deletions
@@ -22,6 +22,7 @@
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
     FlashMLASparseImpl, FlashMLASparseMetadata)
+from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
 
 SPARSE_BACKEND_BATCH_SPECS = {
     name: BATCH_SPECS[name]
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
                         sdpa_reference,
                         rtol=0.5,
                         atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "seq_lens,max_buf,start,expected",
+    [
+        # Basic split: totals per chunk ≤ max_buf
+        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
+        # Non-zero start index
+        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
+        # Exact fits should split between items when adding the next would
+        # overflow
+        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
+        # All requests fit in a single chunk
+        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
+        # Large buffer with non-zero start
+        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
+    ],
+)
+def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
+    out = split_prefill_chunks(seq_lens, max_buf, start)
+    assert out == expected
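
Note: the parametrized cases above exercise the greedy packing in split_prefill_chunks: requests accumulate into a chunk until adding the next sequence would overflow the buffer, at which point the chunk is closed at the current index and that request opens the next chunk. A minimal plain-Python sketch of that behavior (an illustrative stand-in, not the vLLM helper itself, which takes a CPU torch tensor):

def split_greedily(seq_lens, max_buf, start=0):
    # Greedily pack request indices [start, len) into (begin, end) chunks
    # whose total sequence length never exceeds max_buf.
    chunks, total = [], 0
    for i in range(start, len(seq_lens)):
        cur = seq_lens[i]
        assert cur <= max_buf, "a single request must fit in the buffer"
        total += cur
        if total > max_buf:
            chunks.append((start, i))   # close the chunk before request i
            start, total = i, cur       # request i opens the next chunk
    if total > 0:
        chunks.append((start, len(seq_lens)))
    return chunks

# Mirrors the first parametrized case: [2, 3, 4, 2] with a 5-token buffer.
assert split_greedily([2, 3, 4, 2], 5) == [(0, 2), (2, 3), (3, 4)]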

vllm/model_executor/models/deepseek_v2.py

Lines changed: 37 additions & 38 deletions
@@ -583,44 +583,43 @@ def sparse_attn_indexer(
     topk_indices_buffer[:hidden_states.shape[0]] = -1
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
-        num_prefills = attn_metadata.num_prefills
-        k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
-                            device=k.device,
-                            dtype=torch.float8_e4m3fn)
-        k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
-                              device=k.device,
-                              dtype=torch.float32)
-        cp_gather_indexer_k_quant_cache(
-            kv_cache,
-            k_fp8,
-            k_scale,
-            prefill_metadata.block_table,
-            prefill_metadata.cu_seq_lens,
-            num_prefills,
-        )
-        cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
-        cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
-        num_tokens = attn_metadata.num_actual_tokens
-        logits = fp8_mqa_logits(
-            q_fp8[num_decode_tokens:num_tokens],
-            (k_fp8, k_scale),
-            weights[num_decode_tokens:num_tokens],
-            cu_seqlen_ks,
-            cu_seqlen_ke,
-        )
-        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
-                                   dim=-1)[1]
-        topk_indices -= cu_seqlen_ks[:, None]
-        mask_lo = topk_indices >= 0
-        mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
-        mask = torch.full_like(topk_indices,
-                               False,
-                               dtype=torch.bool,
-                               device=topk_indices.device)
-        mask = mask_lo & mask_hi
-        topk_indices = topk_indices.masked_fill(~mask, -1)
-        topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
-                            shape[-1]] = topk_indices.to(dtype=torch.int32)
+        for chunk in prefill_metadata.chunks:
+            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
+                                device=k.device,
+                                dtype=torch.float8_e4m3fn)
+            k_scale = torch.empty([chunk.total_seq_lens, 1],
+                                  device=k.device,
+                                  dtype=torch.float32)
+            cp_gather_indexer_k_quant_cache(
+                kv_cache,
+                k_fp8,
+                k_scale,
+                chunk.block_table,
+                chunk.cu_seq_lens,
+                chunk.num_reqs,
+            )
+            logits = fp8_mqa_logits(
+                q_fp8[chunk.token_start:chunk.token_end],
+                (k_fp8, k_scale),
+                weights[chunk.token_start:chunk.token_end],
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+            )
+            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+                                       dim=-1)[1]
+            topk_indices -= chunk.cu_seqlen_ks[:, None]
+            mask_lo = topk_indices >= 0
+            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
+                                      chunk.cu_seqlen_ks)[:, None] < 0
+            mask = torch.full_like(topk_indices,
+                                   False,
+                                   dtype=torch.bool,
+                                   device=topk_indices.device)
+            mask = mask_lo & mask_hi
+            topk_indices = topk_indices.masked_fill(~mask, -1)
+            topk_indices_buffer[
+                chunk.token_start:chunk.token_end, :topk_indices.
+                shape[-1]] = topk_indices.to(dtype=torch.int32)
 
     if has_decode:
         decode_metadata = attn_metadata.decode
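
Note: the loop above is the consumer side of the chunking. k_fp8 and k_scale are now allocated per chunk, so the transient gather buffers scale with a chunk's total sequence length rather than with the sum over all prefill requests. A rough back-of-the-envelope sketch of the effect (all numbers below are made-up illustrations, not values from this diff):

head_dim = 128
seq_lens = [32768, 32768, 32768]       # three long prefill requests
max_prefill_buffer_size = 65536        # e.g. a hypothetical max_model_len * 2

# Old path: one k_fp8 buffer sized by the total prefill sequence length.
peak_before = sum(seq_lens) * head_dim                               # 12,582,912

# New path: one buffer per chunk, bounded by the prefill buffer size.
chunks = [(0, 2), (2, 3)]              # what split_prefill_chunks would return here
peak_after = max(sum(seq_lens[s:e]) for s, e in chunks) * head_dim   # 8,388,608

print(peak_before, peak_after)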

vllm/v1/attention/backends/mla/indexer.py

Lines changed: 90 additions & 41 deletions
@@ -49,14 +49,20 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
 
 
 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
 
 
 @dataclass
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
 
 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,15 +128,14 @@
     are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
 
     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device
 
     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]
 
     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
 
     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -151,22 +155,51 @@
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)
 
     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end
 
-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)
 
 
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
+    such that the total sequence length of each chunk is less than the
+    maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of tuples of (reqs_start, reqs_end).
+    """
+    chunk_seq_ids = []
+    total_seq_lens = 0
+    for i in range(reqs_start, len(seq_lens_cpu)):
+        cur_seq_len = seq_lens_cpu[i].item()
+        assert cur_seq_len <= max_prefill_buffer_size
+        total_seq_lens += cur_seq_len
+        if total_seq_lens > max_prefill_buffer_size:
+            chunk_seq_ids.append((reqs_start, i))
+            reqs_start = i
+            total_seq_lens = cur_seq_len
+    if total_seq_lens > 0:
+        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
+    return chunk_seq_ids
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
@@ -201,6 +234,33 @@ def __init__(self, *args, **kwargs):
             dtype=torch.int32,
             device=self.device)
 
+    def build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                query_start_loc_cpu, seq_lens_cpu,
+                                block_table):
+        prefill_query_start_loc = query_start_loc_cpu[
+            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+            self.device)
+        token_start = query_start_loc_cpu[reqs_start].item()
+        token_end = query_start_loc_cpu[reqs_end].item()
+        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+        assert total_seq_lens <= self.max_prefill_buffer_size
+        cu_seq_lens = torch.cat([
+            torch.zeros(1, dtype=torch.int32),
+            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+        ]).to(torch.int32).to(self.device)
+        return DeepseekV32IndexerPrefillChunkMetadata(
+            cu_seqlen_ks=cu_seqlen_ks,
+            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seq_lens=cu_seq_lens,
+            total_seq_lens=total_seq_lens,
+            block_table=block_table[reqs_start:reqs_end],
+            token_start=token_start,
+            token_end=token_end,
+            num_reqs=reqs_end - reqs_start,
+        )
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -209,11 +269,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
 
-        device = self.device
-        block_table_tensor = common_attn_metadata.block_table_tensor
-
-        query_start_loc = common_attn_metadata.query_start_loc
-
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -224,27 +280,20 @@
 
         prefill_metadata = None
         if num_prefills > 0:
-            reqs_start = num_decodes
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
-            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-                prefill_query_start_loc,
-                common_attn_metadata.seq_lens[reqs_start:])
-            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
-            assert total_seq_lens < self.max_prefill_buffer_size
-            cu_seq_lens = torch.cat([
-                torch.zeros(1, dtype=torch.int32, device=device),
-                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
-            ]).to(torch.int32).cuda()
-            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                block_table=block_table_tensor[reqs_start:, ...],
-                query_start_loc=prefill_query_start_loc,
-                max_query_len=common_attn_metadata.max_query_len,
-                cu_seqlen_ks=cu_seqlen_ks,
-                cu_seqlen_ke=cu_seqlen_ke,
-                cu_seq_lens=cu_seq_lens,
-                total_seq_lens=total_seq_lens,
+            chunk_seq_ids = split_prefill_chunks(
+                common_attn_metadata.seq_lens_cpu,
+                self.max_prefill_buffer_size,
+                num_decodes,
             )
+            chunks = [
+                self.build_one_prefill_chunk(
+                    reqs_start, reqs_end, query_start_loc_cpu,
+                    common_attn_metadata.seq_lens_cpu,
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks, )
 
         decode_metadata = None
         if num_decodes > 0:
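
Note: as I read the helper above, the span semantics that kv_spans_from_batches documents can be traced by hand: the selected query tokens are the last counts[b] positions of each sequence, and each token's KV span runs from that sequence's start in the flattened KV up to (and including) the token's own position, with an exclusive end. A self-contained plain-Python trace of one small batch (values chosen purely for illustration):

start_seq_loc = [0, 2, 5]        # cumulative selected tokens per batch (B+1)
seq_len_per_batch = [4, 6]       # full KV length of each sequence

ks, ke = [], []
kv_start = 0
for b, seq_len in enumerate(seq_len_per_batch):
    count = start_seq_loc[b + 1] - start_seq_loc[b]
    for pos in range(1, count + 1):                    # 1-based within the block
        ks.append(kv_start)                            # span starts at the sequence's KV
        ke.append(kv_start + seq_len - count + pos)    # causal, exclusive end
    kv_start += seq_len

assert ks == [0, 0, 4, 4, 4]
assert ke == [3, 4, 8, 9, 10]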
