
Commit e41f067

Add support for BLOOM (#331)
1 parent d6fa1be commit e41f067

11 files changed: +479 -18 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ vLLM is flexible and easy to use with:
 
 vLLM seamlessly supports many Huggingface models, including the following architectures:
 
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)

csrc/attention.cpp

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/util/Optional.h>
 
 void single_query_cached_kv_attention(
   torch::Tensor& out,
@@ -9,7 +10,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int block_size,
-  int max_context_len);
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(

csrc/attention/attention_kernels.cu

Lines changed: 19 additions & 6 deletions
@@ -80,6 +80,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
@@ -91,6 +92,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
   const int seq_idx = blockIdx.y;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
   // The vector size is configured in such a way that the threads in a thread group
@@ -167,12 +169,14 @@ __global__ void single_query_cached_kv_attention_kernel(
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
-      const bool mask = token_idx >= context_len;
-
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
+
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -328,6 +332,7 @@ __global__ void single_query_cached_kv_attention_kernel(
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
+    alibi_slopes_ptr, \
     query_stride);
 
 // TODO(woosuk): Tune NUM_THREADS.
@@ -343,7 +348,8 @@ void single_query_cached_kv_attention_launcher(
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -353,6 +359,11 @@ void single_query_cached_kv_attention_launcher(
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
 
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
@@ -411,7 +422,8 @@ void single_query_cached_kv_attention_launcher(
     scale, \
     block_tables, \
     context_lens, \
-    max_context_len);
+    max_context_len, \
+    alibi_slopes);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
@@ -458,7 +470,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {

docs/source/models/supported_models.rst

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,9 @@ Alongside each architecture, we include some popular models that use it.
   * - Architecture
     - Models
     - Example HuggingFace Models
+  * - :code:`BloomForCausalLM`
+    - BLOOM, BLOOMZ, BLOOMChat
+    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.

tests/kernels/test_attention.py

Lines changed: 1 addition & 0 deletions
@@ -216,6 +216,7 @@ def run_single_query_cached_kv_attention(
         context_lens,
         block_size,
         max_context_len,
+        None,  # ALiBi slopes.
     )
 
     ref_output = torch.empty_like(query)

vllm/model_executor/input_metadata.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Dict, List, Tuple
 
 import torch
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+from xformers.ops import AttentionBias
 
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData
@@ -38,7 +38,6 @@ def __init__(
         self.max_context_len = max_context_len
         self.block_tables = block_tables
 
-        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
@@ -50,6 +49,9 @@ def __init__(
         assert block_tables.shape[0] == self.num_generation_tokens
         assert context_lens.shape[0] == self.num_generation_tokens
 
+        # Set during the execution of the first attention op.
+        self.attn_bias: List[AttentionBias] = []
+
     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
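
The bias is now filled in lazily because its type depends on the attention variant (plain causal vs. ALiBi). For intuition, the sketch below materializes a dense equivalent of what xformers' `BlockDiagonalCausalMask.from_seqlens` represents: each packed prompt gets its own causal block, and tokens from different prompts cannot attend to each other. This is an illustration only, not code from the commit; the helper name is invented.

from typing import List
import torch

def dense_block_diag_causal_mask(prompt_lens: List[int]) -> torch.Tensor:
    # Additive mask: 0 where attention is allowed, -inf where it is forbidden.
    total = sum(prompt_lens)
    mask = torch.full((total, total), float("-inf"))
    start = 0
    for n in prompt_lens:
        # Causal (lower-triangular) block for one prompt.
        causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        mask[start:start + n, start:start + n] = causal
        start += n
    return mask  # add to the attention scores before the softmax

# Example: two prompts of lengths 2 and 3 packed into one batch of 5 tokens.
print(dense_block_diag_causal_mask([2, 3]))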

vllm/model_executor/layers/attention.py

Lines changed: 128 additions & 5 deletions
@@ -1,9 +1,11 @@
 """Multi-head attention."""
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
 
 from vllm import attention_ops
 from vllm import cache_ops
@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.
 
@@ -68,13 +78,14 @@ def multi_query_kv_attention(
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
@@ -112,6 +123,7 @@ def single_query_cached_kv_attention(
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )
 
     def forward(
@@ -154,12 +166,13 @@ def forward(
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )
 
         # Wait until the cache op is done.
@@ -219,7 +232,8 @@ def __init__(
         cache = torch.cat((cos, sin), dim=-1)
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
        torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ def forward(
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
+
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )
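
A small standalone sketch of the per-prompt bias tensor that `PagedAttentionWithALiBi.set_attn_bias` builds above, before it is wrapped in xformers' `LowerTriangularMaskWithTensorBias`: bias[h, i, j] = slopes[h] * (j - i), allocated inside a buffer padded to a multiple of 8 as xformers requires. The helper name is invented for this illustration; the causal masking itself is applied later by xformers.

import torch

def build_alibi_bias(slopes: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # slopes: [num_heads]; returns a view of shape [num_heads, prompt_len, prompt_len].
    num_heads = slopes.shape[0]
    pos = torch.arange(prompt_len)
    rel = pos[None, :] - pos[:, None]            # rel[i, j] = j - i
    padded_len = (prompt_len + 7) // 8 * 8       # pad to a multiple of 8 for xformers
    bias = torch.empty(num_heads, padded_len, padded_len)
    bias = bias[:, :prompt_len, :prompt_len].copy_(rel)   # keep the padded strides
    bias.mul_(slopes[:, None, None])             # scale the relative distance per head
    return bias

# Example: 4 heads, prompt of 5 tokens.
print(build_alibi_bias(torch.tensor([0.5, 0.25, 0.125, 0.0625]), 5).shape)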

vllm/model_executor/model_loader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from transformers import PretrainedConfig
77

88
from vllm.config import ModelConfig
9-
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
10-
GPTNeoXForCausalLM, LlamaForCausalLM,
11-
OPTForCausalLM)
9+
from vllm.model_executor.models import * # pylint: disable=wildcard-import
1210
from vllm.model_executor.weight_utils import initialize_dummy_weights
1311

1412
# TODO(woosuk): Lazy-load the model classes.
1513
_MODEL_REGISTRY = {
14+
"BloomForCausalLM": BloomForCausalLM,
1615
"GPT2LMHeadModel": GPT2LMHeadModel,
1716
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
1817
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,10 +1,12 @@
+from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
+    "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
