
Commit 898285c

fix: CUDA error when inferencing with Falcon-40B base model (#992)

1 parent: a62de9e

1 file changed (+2, -1 lines)

vllm/config.py

Lines changed: 2 additions & 1 deletion
@@ -114,8 +114,9 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
-            self.hf_config.model_type == "falcon"
+            self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
