@@ -120,7 +120,7 @@ class AscendAttentionState(Enum):
@dataclass
class AscendMetadata:

-    # **************************** Basic Properties ****************************
+    # **************************** Basic Properties ************************** #
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -138,7 +138,7 @@ class AscendMetadata:
    # Maximum query length in the batch (None for decoding).
    max_query_len: Optional[int] = None

-    # ********************** KV Cache Related Properties ***********************
+    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (seq id -> list of physical blocks).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None
@@ -150,6 +150,7 @@ class AscendMetadata:
    # (num_tokens,)
    slot_mapping: torch.Tensor = None

+    # *************************** Other Properties *************************** #
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False

@@ -245,6 +246,144 @@ def __init__(
        self.key_cache = None
        self.value_cache = None

+    def _forward_prefill_no_cache(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
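+        """Full prompt attention for the PrefillNoCache state: q, k and v
+        come straight from the current batch, with no KV cache involved."""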
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        mask = attn_metadata.attn_mask
+
+        if is_310p():
+            # Align the q, k, v and output tensors.
+            query = aligned_16(query)
+            key = aligned_16(key)
+            value = aligned_16(value)
+            output = aligned_16(output)
+            # Reformat the mask in case of broadcasted tensors.
+            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+            mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                             ACL_FORMAT_FRACTAL_NZ)
+
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
+    def _forward_prefill_cache_hit(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
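+        """Prefill attention for the PrefillCacheHit state: keys and values
+        are read from the paged KV cache through the block table."""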
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        compress_mask = attn_metadata.attn_mask
+        batch_size = attn_metadata.query_lens.shape[0]
+        block_table = attn_metadata.block_tables[:batch_size, :]
+
+        torch_npu._npu_flash_attention_qlens(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            block_table=block_table,
+            mask=compress_mask,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
+    def _forward_decode_only(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
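+        """Decode attention for the DecodeOnly state: paged attention over
+        the KV cache, with one query token per sequence."""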
+        if is_310p():
+            # seq_lens_tensor needs to be transferred to the device for 310P.
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=self.key_cache,
+                                       value_cache=self.value_cache,
+                                       num_kv_heads=self.num_kv_heads,
+                                       num_heads=self.num_heads,
+                                       scale_value=self.scale,
+                                       block_table=attn_metadata.block_tables,
+                                       context_lens=attn_metadata.seq_lens,
+                                       out=output)
+        return output
+
+    def _forward_v1_style(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
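+        """Normal V1 (chunked-prefill) attention: split-fuse paged attention,
+        with a vanilla fallback while the kernel lacks head_size 192
+        support."""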
+        # Use chunked prefill for the head_size 192 scenario (e.g. deepseek);
+        # paged_attention_splitfuse may crash in that scenario.
+        # TODO: the vanilla path will be removed once the kernel supports
+        # head_size 192.
+        if self.head_size == 192:
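+            # Build cumulative sequence-length boundaries
+            # ([0, l0, l0 + l1, ...]) in the form the vanilla kernel expects.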
+            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
+            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
+            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
+            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
+            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
+            max_seqlen_q = torch.max(attn_metadata.query_lens)
+            max_seqlen_k = torch.max(attn_metadata.seq_lens)
+            vanilla_chunked_prefill(output, query, self.key_cache,
+                                    self.value_cache,
+                                    attn_metadata.block_tables, cu_seqlen_q,
+                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                                    self.scale, None, True)
+            return output
+
+        # Use paged attention.
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        if is_310p():
+            # Do reformat in case of broadcasted tensors.
+            attn_metadata.attn_mask = \
+                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                          ACL_FORMAT_FRACTAL_NZ)
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention_splitfuse(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
    def forward(
        self,
        layer: AttentionLayer,
@@ -325,109 +464,18 @@ def forward(

        # V0-Style scheduler situation.
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            mask = attn_metadata.attn_mask
-            if is_310p():
-                # align q k v output tensors
-                query = aligned_16(query)
-                key = aligned_16(key)
-                value = aligned_16(value)
-                output = aligned_16(output)
-
-                # do reformat in case of broadcasted tensors
-                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
-                mask = torch_npu.npu_format_cast(mask.contiguous(),
-                                                 ACL_FORMAT_FRACTAL_NZ)
-
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
-            output = output[:num_tokens, :, :]
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            compress_mask = attn_metadata.attn_mask
-            batch_size = attn_metadata.query_lens.shape[0]
-            block_table = attn_metadata.block_tables[:batch_size, :]
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                block_table=block_table,
-                mask=compress_mask,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+            output = self._forward_prefill_no_cache(
+                query, key, value, attn_metadata, output, num_tokens)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            output = self._forward_prefill_cache_hit(
+                query, attn_metadata, output)
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if is_310p():
-                # seq_lens_tensor needs to be transferred to the device for 310P
-                attn_metadata.seq_lens = \
-                    attn_metadata.seq_lens.to(device=query.device)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            output = self._forward_decode_only(query, attn_metadata,
+                                               output)
        # Normal V1 situation.
        else:
-            # use chunked prefill for head size 192 scenario, like deepseek
-            # paged_attention_splitfuse maybe crash at such scenario
-            # TODO: vanilla path will be removed after the kernel support
-            # head_size 192 scenario
-            if self.head_size == 192:
-                cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-                cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                cu_seqlen_q = torch.tensor(cu_seqlen_q,
-                                           device=query.device)
-                cu_seqlen_k = torch.tensor(cu_seqlen_k,
-                                           device=query.device)
-                cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-                cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-                max_seqlen_q = torch.max(attn_metadata.query_lens)
-                max_seqlen_k = torch.max(attn_metadata.seq_lens)
-                vanilla_chunked_prefill(output, query, self.key_cache,
-                                        self.value_cache,
-                                        attn_metadata.block_tables,
-                                        cu_seqlen_q, cu_seqlen_k,
-                                        max_seqlen_q, max_seqlen_k,
-                                        self.scale, None, True)
-            else:
-                # use paged attention
-                assert attn_metadata is not None
-                assert attn_metadata.attn_mask is not None
-                if is_310p():
-                    # do reformat in case of broadcasted tensors
-                    attn_metadata.attn_mask = \
-                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
-                    attn_metadata.seq_lens = \
-                        attn_metadata.seq_lens.to(device=query.device)
-                torch_npu._npu_paged_attention_splitfuse(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    mask=attn_metadata.attn_mask,
-                    block_table=attn_metadata.block_tables,
-                    seq_len=attn_metadata.query_lens,
-                    context_lens=attn_metadata.seq_lens,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    out=output)
+            output = self._forward_v1_style(query, attn_metadata, output)

        # to make in-place change to the output tensor
        if hasattr(layer, 'quant_method') and use_kv_cache_int8: