@@ -998,7 +998,7 @@ def _forward_decode(
998
998
decode_meta = attn_metadata .decode
999
999
assert decode_meta is not None
1000
1000
num_tokens = q_nope .size (0 )
1001
- if self .running_in_graph :
1001
+ if self .running_in_graph or self . running_chunkprefilll_with_torchair :
1002
1002
# shape of knope/k_pe for npu graph mode should be:
1003
1003
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
1004
1004
block_size = kv_c_and_k_pe_cache [0 ].shape [1 ]
@@ -1112,6 +1112,7 @@ def forward(
1112
1112
self .running_in_graph = self .torchair_graph_enabled and attn_metadata .attn_state in [
1113
1113
AscendAttentionState .DecodeOnly , AscendAttentionState .SpecDecoding
1114
1114
]
1115
+ self .running_chunkprefilll_with_torchair = self .torchair_graph_enabled and attn_metadata .attn_state == AscendAttentionState .ChunkedPrefill
1115
1116
num_actual_toks = attn_metadata .num_actual_tokens
1116
1117
if k_pe is None and not self .running_in_graph :
1117
1118
kv_c , k_pe = self .kv_a_proj_with_mqa (
@@ -1148,18 +1149,25 @@ def forward(
1148
1149
if has_decode :
1149
1150
decode_k_nope = None
1150
1151
assert attn_metadata .decode is not None
1151
- if self .running_in_graph :
1152
+ if self .running_in_graph or self . running_chunkprefilll_with_torchair :
1152
1153
cos = attn_metadata .decode .cos
1153
1154
sin = attn_metadata .decode .sin
1154
- with npu_stream_switch ("mla_secondary" ,
1155
- 0 ,
1156
- enabled = enable_multistream_mla ):
1157
- npu_wait_tensor (hidden_states_or_kv_c_normed ,
1158
- ckq ,
1159
- enabled = enable_multistream_mla )
1155
+ if self .running_chunkprefilll_with_torchair :
1156
+ decode_hs = (
1157
+ hidden_states_or_kv_c_normed [:num_decode_tokens ])
1158
+ slots = attn_metadata .slot_mapping [:num_decode_tokens ]
1160
1159
decode_k_pe , decode_k_nope , decode_kv = self .exec_kv (
1161
- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
1162
- attn_metadata .slot_mapping )
1160
+ decode_hs , cos , sin , kv_cache , slots )
1161
+ else :
1162
+ with npu_stream_switch ("mla_secondary" ,
1163
+ 0 ,
1164
+ enabled = enable_multistream_mla ):
1165
+ npu_wait_tensor (hidden_states_or_kv_c_normed ,
1166
+ ckq ,
1167
+ enabled = enable_multistream_mla )
1168
+ decode_k_pe , decode_k_nope , decode_kv = self .exec_kv (
1169
+ hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
1170
+ attn_metadata .slot_mapping )
1163
1171
# Without explicitly controlling the order, IndexByTensor operations
1164
1172
# would be placed after `matmul W_KV_T` hindering the overlapping of
1165
1173
# KvRmsNormRopeCache and SingleRope.
@@ -1183,6 +1191,8 @@ def forward(
1183
1191
decode_k_pe ,
1184
1192
enabled = enable_multistream_mla )
1185
1193
decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
1194
+ elif self .running_chunkprefilll_with_torchair :
1195
+ decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
1186
1196
else :
1187
1197
decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
1188
1198
attn_metadata .decode .input_positions ,
@@ -1221,16 +1231,15 @@ def forward(
1221
1231
kv_cache
1222
1232
) > 1 , "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
1223
1233
if self .torchair_graph_enabled :
1224
- if kv_cache [0 ].numel (
1225
- ) > 0 and attn_metadata .attn_state == AscendAttentionState .PrefillNoCache :
1234
+ if kv_cache [0 ].numel () > 0 and has_prefill :
1226
1235
slots = attn_metadata .slot_mapping
1227
1236
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
1228
- torch_npu ._npu_reshape_and_cache (key = kv_c_normed . view (
1229
- num_tokens , self .num_kv_heads , - 1 ),
1230
- value = prefill_k_pe ,
1231
- key_cache = kv_cache [0 ],
1232
- value_cache = kv_cache [1 ],
1233
- slot_indices = slots )
1237
+ torch_npu ._npu_reshape_and_cache (
1238
+ key = kv_c_normed . view ( num_tokens , self .num_kv_heads , - 1 ),
1239
+ value = prefill_k_pe ,
1240
+ key_cache = kv_cache [0 ],
1241
+ value_cache = kv_cache [1 ],
1242
+ slot_indices = slots [ num_decode_tokens :] )
1234
1243
else :
1235
1244
kv_c_normed = kv_c_normed .view (
1236
1245
[num_actual_toks , self .num_kv_heads , - 1 ])
0 commit comments