@@ -317,7 +317,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                flash_attn_varlen_func(
+                out = flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,13 +329,14 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
-                    out=output[:num_prefill_tokens],
                 )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                flash_attn_varlen_func(
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -347,12 +348,11 @@ def forward(
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
-                    out=output[:num_prefill_tokens],
                 )

         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            flash_attn_with_kvcache(
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,8 +361,7 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-                out=output[num_prefill_tokens:].unsqueeze(1),
-            )
+            ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
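Below is a minimal, runnable sketch of the pattern this change adopts: capture the kernel's return value and copy it into a slice of a preallocated output tensor (asserting the shapes match), rather than passing the slice via out=. The fake_attn helper is a hypothetical stand-in for flash_attn_varlen_func / flash_attn_with_kvcache, and the tensor sizes are illustrative only.

import torch


def fake_attn(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the flash-attn kernels: returns a tensor
    # with the same shape as its input.
    return q * 2.0


num_tokens, num_heads, head_size = 10, 4, 8
num_prefill_tokens = 6

query = torch.randn(num_tokens, num_heads, head_size)
output = torch.empty_like(query)

# Prefill path: capture the result, check the shape, copy into the slice.
out = fake_attn(query[:num_prefill_tokens])
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out

# Decode path: the kernel sees a [batch, 1, heads, head_size] view, so the
# singleton dimension is squeezed before assigning into the output slice.
decode_query = query[num_prefill_tokens:]
output[num_prefill_tokens:] = fake_attn(decode_query.unsqueeze(1)).squeeze(1)

print(output.shape)  # torch.Size([10, 4, 8])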