@@ -314,14 +314,13 @@ def forward(
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""

-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 slopes: List[float],
+                 num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads

         slopes = torch.tensor(slopes, dtype=torch.float32)
@@ -334,6 +333,11 @@ def set_attn_bias(self, input_metadata: InputMetadata) -> None:
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
             bias = bias[None, :] - bias[:, None]
             bias = bias.to(self.alibi_slopes.device)
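
A minimal sketch (not part of the diff) of the relative-position bias the note above refers to, assuming a prompt length of 4:

import torch

prompt_len = 4
bias = torch.arange(prompt_len)
# bias[None, :] - bias[:, None] gives bias[i, j] = j - i, the signed distance
# from query position i to key position j, as in the original ALiBi paper.
bias = bias[None, :] - bias[:, None]
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])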
@@ -363,10 +367,17 @@ def multi_query_kv_attention(
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
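
A minimal sketch (not part of the diff) of the repeat_interleave projection added above, assuming 8 query heads sharing 2 KV heads (so num_queries_per_kv = 4) and made-up tensor sizes:

import torch

num_prompt_tokens, num_heads, num_kv_heads, head_size = 16, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads
key = torch.randn(num_prompt_tokens, num_kv_heads, head_size)
# Each KV head is repeated num_queries_per_kv times along dim=1 so the key/value
# tensors line up with the [num_prompt_tokens, num_heads, head_size] query tensor.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
assert key.shape == (num_prompt_tokens, num_heads, head_size)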
@@ -400,9 +411,10 @@ def single_query_cached_kv_attention(
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
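
A minimal sketch (not part of the diff) of the cache shapes described in the docstring above; the sizes and the vectorization factor x are made up for illustration:

import torch

num_blocks, num_kv_heads, head_size, block_size, x = 1024, 2, 64, 16, 8
key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x)
value_cache = torch.empty(num_blocks, num_kv_heads, head_size, block_size)
# block_size is recovered from the value cache exactly as in the line above.
assert value_cache.shape[3] == block_size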