Commit 9f6be86

Fix config for Falcon (#1164)
1 parent f187877 commit 9f6be86

2 files changed: 8 additions & 3 deletions

vllm/config.py

Lines changed: 6 additions & 1 deletion
@@ -135,7 +135,8 @@ def get_head_size(self) -> int:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
-    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -147,11 +148,15 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
             # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
             return 1
         # For Falcon:
         if getattr(self.hf_config, "n_head_kv", None) is not None:
             return (self.hf_config.n_head_kv //
                     parallel_config.tensor_parallel_size)
+        if getattr(self.hf_config, "num_kv_heads", None) is not None:
+            return (self.hf_config.num_kv_heads //
+                    parallel_config.tensor_parallel_size)
         # For LLaMA-2:
         if getattr(self.hf_config, "num_key_value_heads", None) is not None:
             return (self.hf_config.num_key_value_heads //
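
For reference, the lookup order that the renamed get_num_kv_heads follows can be sketched as a standalone helper. This is a minimal sketch, not the actual ModelConfig method: the helper name, the hf_config stand-in, and the new_decoder_arch_falcon argument (computed earlier in vllm/config.py, outside this hunk) are illustrative assumptions.

# Hypothetical sketch of the KV-head lookup order after this commit.
# `hf_config` is any object exposing HuggingFace-style config attributes;
# the real logic lives in ModelConfig.get_num_kv_heads in vllm/config.py.
def num_kv_heads_per_worker(hf_config,
                            new_decoder_arch_falcon: bool,
                            tensor_parallel_size: int) -> int:
    if not new_decoder_arch_falcon and getattr(hf_config, "multi_query", False):
        # Multi-query attention: a single KV head shared by all query heads
        # (tensor parallelism is not supported in this case).
        return 1
    # Grouped-query models expose the KV head count under different names:
    # Falcon uses n_head_kv, some Falcon variants use num_kv_heads, and
    # LLaMA-2 uses num_key_value_heads. The first attribute present wins,
    # and the heads are split evenly across tensor-parallel workers.
    for attr in ("n_head_kv", "num_kv_heads", "num_key_value_heads"):
        value = getattr(hf_config, attr, None)
        if value is not None:
            return value // tensor_parallel_size
    # Plain multi-head attention: one KV head per attention head.
    return hf_config.num_attention_heads // tensor_parallel_size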

vllm/worker/cache_engine.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def __init__(
 
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.num_heads = model_config.get_num_kv_heads(parallel_config)
         self.dtype = model_config.dtype
 
         self.block_size = cache_config.block_size
@@ -146,7 +146,7 @@ def get_cache_block_size(
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_heads(parallel_config)
+        num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
         key_cache_block = block_size * num_heads * head_size
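
The head-count switch matters here because get_cache_block_size sizes KV-cache blocks from it. A rough, self-contained sketch of that arithmetic follows; the helper name cache_block_size_bytes and the example numbers are hypothetical, and vLLM's actual dtype handling may differ.

# Hypothetical sketch of the block-size arithmetic used by the cache engine.
import torch

def cache_block_size_bytes(block_size: int, num_kv_heads: int,
                           head_size: int, num_layers: int,
                           dtype: torch.dtype) -> int:
    # Each block stores `block_size` token slots per layer, once for the key
    # cache and once for the value cache, with each slot holding
    # num_kv_heads * head_size elements.
    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block
    total_elements = num_layers * (key_cache_block + value_cache_block)
    return total_elements * torch.tensor([], dtype=dtype).element_size()

# Illustrative comparison: counting 8 KV heads instead of 128 attention
# heads shrinks each block 16x, so far more blocks fit in the same memory.
small = cache_block_size_bytes(16, 8, 64, 60, torch.float16)
big = cache_block_size_bytes(16, 128, 64, 60, torch.float16)
print(big // small)  # 16: the per-block footprint scales linearly with heads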
