@@ -135,7 +135,8 @@ def get_head_size(self) -> int:
135
135
# FIXME(woosuk): This may not be true for all models.
136
136
return self .hf_config .hidden_size // self .hf_config .num_attention_heads
137
137
138
- def get_num_heads (self , parallel_config : "ParallelConfig" ) -> int :
138
+ def get_num_kv_heads (self , parallel_config : "ParallelConfig" ) -> int :
139
+ """Returns the number of KV heads per GPU worker."""
139
140
# For GPTBigCode & Falcon:
140
141
# Note: for falcon, when new_decoder_architecture is True, the
141
142
# multi_query flag is ignored and we use n_head_kv for the number of
@@ -147,11 +148,15 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
147
148
if not new_decoder_arch_falcon and getattr (self .hf_config ,
148
149
"multi_query" , False ):
149
150
# Multi-query attention, only one KV head.
151
+ # Currently, tensor parallelism is not supported in this case.
150
152
return 1
151
153
# For Falcon:
152
154
if getattr (self .hf_config , "n_head_kv" , None ) is not None :
153
155
return (self .hf_config .n_head_kv //
154
156
parallel_config .tensor_parallel_size )
157
+ if getattr (self .hf_config , "num_kv_heads" , None ) is not None :
158
+ return (self .hf_config .num_kv_heads //
159
+ parallel_config .tensor_parallel_size )
155
160
# For LLaMA-2:
156
161
if getattr (self .hf_config , "num_key_value_heads" , None ) is not None :
157
162
return (self .hf_config .num_key_value_heads //
0 commit comments