@@ -1,9 +1,11 @@
 """Multi-head attention."""
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
 
 from vllm import attention_ops
 from vllm import cache_ops
@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.
 
@@ -68,13 +78,14 @@ def multi_query_kv_attention(
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
@@ -112,6 +123,7 @@ def single_query_cached_kv_attention(
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )
 
     def forward(
@@ -154,12 +166,13 @@ def forward(
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )
 
         # Wait until the cache op is done.
@@ -219,7 +232,8 @@ def __init__(
         cache = torch.cat((cos, sin), dim=-1)
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ def forward(
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
+
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )
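Note: `set_attn_bias` in the base `PagedAttention` above packs all prompts into one sequence and relies on xformers' `BlockDiagonalCausalMask` to keep them independent and causal. Below is a rough sketch of the masking rule it encodes, written in plain PyTorch rather than the xformers API; the seqlens value is an arbitrary example, not taken from this change.

import torch

def dense_block_diagonal_causal_mask(seqlens):
    # Additive mask over packed prompts: query position q may attend to
    # key position k only if both belong to the same prompt and k <= q.
    total = sum(seqlens)
    mask = torch.full((total, total), float("-inf"))
    start = 0
    for seqlen in seqlens:
        end = start + seqlen
        # Causal block for this prompt: 0 on/below the diagonal, -inf above.
        mask[start:end, start:end] = torch.triu(
            torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
        start = end
    return mask

print(dense_block_diagonal_causal_mask([2, 3]))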
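Note: the diff takes the per-head ALiBi slopes as a plain `slopes: List[float]` argument and leaves their computation to the calling model code. For reference, one common recipe from the ALiBi paper is sketched below for power-of-two head counts only; the helper name and the constructor arguments in the trailing comment are illustrative assumptions, not part of this change.

import math

def get_alibi_slopes(num_heads: int) -> list:
    # Geometric sequence from the ALiBi paper: head i gets slope
    # 2 ** (-8 * (i + 1) / num_heads) when num_heads is a power of two.
    assert math.log2(num_heads).is_integer(), "sketch covers power-of-two head counts only"
    return [2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)]

# Example: 8 heads -> [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256]
# attn = PagedAttentionWithALiBi(num_heads=8, head_size=64,
#                                scale=64 ** -0.5, slopes=get_alibi_slopes(8))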
0 commit comments