
Commit de89472

Fix the issue for AquilaChat2-* models (#1339)
1 parent: e7c8555

3 files changed: 10 additions, 2 deletions


vllm/model_executor/model_loader.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "AquilaModel": AquilaForCausalLM,
+    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,

vllm/model_executor/models/aquila.py

Lines changed: 3 additions & 2 deletions
@@ -147,6 +147,7 @@ def __init__(
             rotary_dim=self.head_dim,
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
+            num_kv_heads=self.num_kv_heads,
         )

     def forward(
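
Forwarding num_kv_heads into the attention/RoPE layer matters because AquilaChat2 uses grouped-query attention: several query heads share one KV head, so the kernel needs the KV head count to build that grouping. An illustrative sketch of the mapping, with made-up sizes rather than values from a real checkpoint:

# Illustrative grouped-query attention mapping; sizes are assumptions.
num_heads = 32      # query heads
num_kv_heads = 8    # KV heads shared across query heads
queries_per_kv = num_heads // num_kv_heads            # 4 query heads per KV head
kv_head_for_query = [q // queries_per_kv for q in range(num_heads)]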
@@ -177,7 +178,7 @@ def __init__(self, config: AquilaConfig):
         self.self_attn = AquilaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
         )
@@ -308,7 +309,7 @@ def load_weights(self,
         q_proj_shard_size = (self.config.hidden_size // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_attention_heads // tp_size)
+                              self.config.num_key_value_heads // tp_size)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),

vllm/transformers_utils/configs/aquila.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,7 @@ def __init__(
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.006,
@@ -49,6 +50,11 @@ def __init__(
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
         self.num_attention_heads = num_attention_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
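
With the constructor shown above, configs that predate the num_key_value_heads field keep the old multi-head behaviour, while AquilaChat2 configs that set it get grouped-query attention. A quick sketch of both cases, assuming only the defaults visible in this hunk:

# Sketch of the backward-compatible default in AquilaConfig.
cfg_mha = AquilaConfig(num_attention_heads=32)
cfg_gqa = AquilaConfig(num_attention_heads=32, num_key_value_heads=8)
assert cfg_mha.num_key_value_heads == 32   # falls back to num_attention_heads
assert cfg_gqa.num_key_value_heads == 8    # explicit GQA setting is kept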
