
Commit 1b0bd0f

Add Falcon support (new) (#592)
1 parent 20044ca commit 1b0bd0f

16 files changed: +680 -122 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - Baichuan-7B (`baichuan-inc/Baichuan-7B`)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
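With Falcon added to the supported list, a minimal usage sketch with vLLM's offline LLM API looks like the following (the model name comes from the list above; the prompt and sampling settings are illustrative only):

from vllm import LLM, SamplingParams

# Load one of the newly supported Falcon checkpoints.
llm = LLM(model="tiiuae/falcon-7b")

# Illustrative sampling settings; any SamplingParams work here.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["What is the meaning of life?"], sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)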

csrc/pos_encoding_kernels.cu

Lines changed: 29 additions & 13 deletions

@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel(
   scalar_t* __restrict__ key,                 // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
   const int num_kv_heads,
   const int head_size) {
@@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+    const int token_head = token_idx * query_stride + head_idx * head_size;

     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;

     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -39,13 +40,27 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);

-    if (head_idx < num_kv_heads) {
-      const scalar_t k_x = key[token_head + x_index];
-      const scalar_t k_y = key[token_head + y_index];
-      key[out_x] = k_x * cos - k_y * sin;
-      key[out_y] = k_y * cos + k_x * sin;
-    }
+    const scalar_t k_x = key[token_head + x_index];
+    const scalar_t k_y = key[token_head + y_index];
+    key[out_x] = k_x * cos - k_y * sin;
+    key[out_y] = k_y * cos + k_x * sin;
   }
 }
@@ -62,8 +77,8 @@ void rotary_embedding_neox(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int num_kv_heads = key.size(1) / head_size;
-  int stride = query.stride(0);
-  TORCH_CHECK(stride == key.stride(0));
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -80,7 +95,8 @@ void rotary_embedding_neox(
       key.data_ptr<scalar_t>(),
       cos_sin_cache.data_ptr<scalar_t>(),
       rot_dim,
-      stride,
+      query_stride,
+      key_stride,
       num_heads,
       num_kv_heads,
       head_size);
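The kernel change lets query and key carry different head counts and strides, rotating each in its own loop. As a reference for the math (not the kernel itself), here is a minimal PyTorch sketch of the same NeoX-style rotation; the function and argument names are mine, and it assumes the per-position cache is laid out as [cos halves | sin halves], which is what the x_index/y_index arithmetic above implies:

import torch

def rotary_neox_reference(positions: torch.Tensor,     # [num_tokens]
                          query: torch.Tensor,          # [num_tokens, num_heads, head_size]
                          key: torch.Tensor,            # [num_tokens, num_kv_heads, head_size]
                          cos_sin_cache: torch.Tensor,  # [max_position, rot_dim], cos then sin
                          rot_dim: int):
    half = rot_dim // 2
    cos, sin = cos_sin_cache[positions].split(half, dim=-1)
    cos = cos.unsqueeze(1)  # broadcast over the head dimension, whatever its size
    sin = sin.unsqueeze(1)

    def rotate(t: torch.Tensor) -> torch.Tensor:
        x, y = t[..., :half], t[..., half:rot_dim]
        out = t.clone()
        out[..., :half] = x * cos - y * sin
        out[..., half:rot_dim] = y * cos + x * sin
        return out

    # Query and key are rotated independently, so num_heads and num_kv_heads
    # are free to differ (multi-query / grouped-query attention).
    return rotate(query), rotate(key)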

docs/source/models/supported_models.rst

Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.

examples/llm_engine_example.py

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):

     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",

vllm/config.py

Lines changed: 7 additions & 2 deletions

@@ -94,8 +94,13 @@ def get_head_size(self) -> int:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
-        # For GPTBigCode:
-        if getattr(self.hf_config, "multi_query", False):
+        # For GPTBigCode & Falcon:
+        # Note: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        if (getattr(self.hf_config, "multi_query", False) and
+                not (self.hf_config.model_type == "falcon" and
+                     getattr(self.hf_config, "new_decoder_architecture", False))):
             # Multi-query attention, only one KV head.
             return 1
         # For Falcon:
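To make the branching concrete, here is a hedged sketch of how the three Falcon families are expected to resolve under this rule. The dict fields mirror HF config attributes, but the exact values are illustrative rather than copied from the checkpoints' configs, and the helper function is mine, not vLLM code:

def expected_num_kv_heads(cfg: dict) -> int:
    # Mirrors the branching in get_num_heads() above, over a plain dict.
    if cfg.get("multi_query", False) and not (
            cfg.get("model_type") == "falcon"
            and cfg.get("new_decoder_architecture", False)):
        return 1                            # multi-query attention: one shared KV head
    if cfg.get("n_head_kv") is not None:
        return cfg["n_head_kv"]             # new decoder architecture: grouped KV heads
    return cfg["num_attention_heads"]       # plain multi-head attention

falcon_7b = {"model_type": "falcon", "multi_query": True,
             "num_attention_heads": 71}                     # -> 1
falcon_40b = {"model_type": "falcon", "new_decoder_architecture": True,
              "n_head_kv": 8, "num_attention_heads": 128}   # -> 8
falcon_rw = {"model_type": "falcon", "multi_query": False,
             "num_attention_heads": 32}                     # -> 32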

vllm/model_executor/layers/attention.py

Lines changed: 24 additions & 12 deletions

@@ -314,14 +314,13 @@ def forward(
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""

-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 slopes: List[float],
+                 num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads

         slopes = torch.tensor(slopes, dtype=torch.float32)
@@ -334,6 +333,11 @@ def set_attn_bias(self, input_metadata: InputMetadata) -> None:
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
             bias = bias[None, :] - bias[:, None]
             bias = bias.to(self.alibi_slopes.device)
@@ -363,10 +367,17 @@ def multi_query_kv_attention(
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
@@ -400,9 +411,10 @@ def single_query_cached_kv_attention(
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
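The key change for Falcon here is the grouped-KV handling in multi_query_kv_attention: when num_kv_heads is smaller than num_heads, each KV head is duplicated for its group of query heads before the usual attention math. A minimal PyTorch sketch with toy shapes (no causal mask or ALiBi bias, and plain einsum attention rather than the xformers path vLLM actually uses):

import torch

# Toy shapes: 3 prompt tokens, 8 query heads sharing 2 KV heads, head_size 4.
num_tokens, num_heads, num_kv_heads, head_size = 3, 8, 2, 4
num_queries_per_kv = num_heads // num_kv_heads   # 4 query heads per KV head

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Duplicate each KV head for its group of query heads: expanded heads 0-3 all
# equal KV head 0, heads 4-7 all equal KV head 1.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
assert key.shape == (num_tokens, num_heads, head_size)

# After the expansion, ordinary multi-head attention applies unchanged.
scores = torch.einsum("qhd,khd->hqk", query, key) / head_size ** 0.5
attn = torch.softmax(scores, dim=-1)
out = torch.einsum("hqk,khd->qhd", attn, value)   # [num_tokens, num_heads, head_size]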

vllm/model_executor/model_loader.py

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,7 @@
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTJForCausalLM": GPTJForCausalLM,
@@ -22,6 +23,7 @@
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "RWForCausalLM": FalconForCausalLM,
 }
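Two registry entries are added so that both the transformers-native architecture name (FalconForCausalLM) and the custom-code name used by the original Hub checkpoints (RWForCausalLM) resolve to the same vLLM implementation. A hedged sketch of how such a registry is typically consulted; the registry name, helper name, and error message here are illustrative, not necessarily the file's actual identifiers:

from transformers import AutoConfig

def resolve_model_class(model_name: str):
    # config.architectures comes from the checkpoint's config.json, e.g.
    # ["RWForCausalLM"] for the original tiiuae/falcon-* uploads.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    for arch in getattr(config, "architectures", None) or []:
        if arch in _MODEL_REGISTRY:   # assuming the dict in the diff is named _MODEL_REGISTRY
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Model architectures {config.architectures} are not supported.")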

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_j import GPTJForCausalLM
@@ -12,6 +13,7 @@
     "BaiChuanForCausalLM",
     "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTJForCausalLM",
