Skip to content

Commit 2b98101

Browse files
authored
Fix Baichuan2-7B-Chat (#1987)
1 parent 6ccc0bf commit 2b98101

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

vllm/model_executor/models/baichuan.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,12 +366,16 @@ def load_weights(self,
366366
weight_loader(param, loaded_weight)
367367

368368

369-
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
369+
class BaichuanForCausalLM(BaiChuanBaseForCausalLM
370+
): # baichuan 13b, baichuan2 13b, baichuan2 7b
370371

371372
def __init__(self,
372373
config,
373374
linear_method: Optional[LinearMethodBase] = None):
374-
super().__init__(config, "ALIBI", linear_method)
375+
if config.hidden_size == 4096: # baichuan2 7b
376+
super().__init__(config, "ROPE", linear_method)
377+
else: # baichuan 13b, baichuan2 13b
378+
super().__init__(config, "ALIBI", linear_method)
375379

376380

377381
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b

0 commit comments

Comments
 (0)