
Commit 28e616c

fix qwen-14b model (#1173)
1 parent 30e7752 commit 28e616c

2 files changed (+32 −43 lines)

vllm/model_executor/models/qwen.py

Lines changed: 7 additions & 7 deletions
@@ -141,17 +141,17 @@ class QWenBlock(nn.Module):

     def __init__(self, config: QWenConfig):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attn = QWenAttention(config.n_embd,
+        self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta)

-        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

-        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

     def forward(
         self,

@@ -190,11 +190,11 @@ def __init__(self, config: QWenConfig):

         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.n_embd,
+                                          config.hidden_size,
                                           perform_initialization=False)
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

     def forward(
         self,

@@ -230,7 +230,7 @@ def __init__(self, config: QWenConfig):

         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             vocab_size,
             bias=False,
             gather_output=False,
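
For context, a minimal illustrative sketch (not part of the commit) of the attribute accesses the diff above switches to. A SimpleNamespace stands in for the updated QWenConfig, using its new default values; only the attribute names matter here.

# Illustrative sketch only, not code from this commit.
# SimpleNamespace mimics the fields QWenBlock now reads from the config.
from types import SimpleNamespace

config = SimpleNamespace(
    hidden_size=4096,         # previously read as config.n_embd
    intermediate_size=22016,  # previously read as config.ffn_hidden_size
    num_attention_heads=32,
    max_position_embeddings=8192,
)

norm_dim = config.hidden_size                      # width passed to RMSNorm
mlp_dim = config.intermediate_size // 2            # value passed to QWenMLP in the diff above
rope_theta = getattr(config, "rope_theta", 10000)  # falls back to 10000 when absent
print(norm_dim, mlp_dim, rope_theta)               # 4096 11008 10000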

vllm/transformers_utils/configs/qwen.py

Lines changed: 25 additions & 36 deletions
@@ -7,65 +7,54 @@
 class QWenConfig(PretrainedConfig):
     model_type = "qwen"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {
-        "hidden_size": "n_embd",
-        "num_attention_heads": "n_head",
-        "max_position_embeddings": "n_positions",
-        "num_hidden_layers": "n_layer",
-    }

     def __init__(
         self,
-        vocab_size=151851,
-        n_embd=4096,
-        n_layer=32,
-        n_head=32,
-        n_inner=None,
-        embd_pdrop=0.0,
-        attn_pdrop=0.0,
-        layer_norm_epsilon=1e-5,
+        vocab_size=151936,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        emb_dropout_prob=0.0,
+        attn_dropout_prob=0.0,
+        layer_norm_epsilon=1e-6,
         initializer_range=0.02,
+        max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
-        eos_token_id=151643,
-        apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
-        use_dynamic_ntk=False,
-        use_logn_attn=False,
-        use_flash_attn=True,
-        ffn_hidden_size=22016,
+        use_dynamic_ntk=True,
+        use_logn_attn=True,
+        use_flash_attn="auto",
+        intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
         **kwargs,
     ):
-        self.eos_token_id = eos_token_id
-        super().__init__(eos_token_id=eos_token_id,
-                         tie_word_embeddings=tie_word_embeddings,
-                         **kwargs)
-
         self.vocab_size = vocab_size
-        self.n_embd = n_embd
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.n_inner = n_inner
-        self.embd_pdrop = embd_pdrop
-        self.attn_pdrop = attn_pdrop
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.emb_dropout_prob = emb_dropout_prob
+        self.attn_dropout_prob = attn_dropout_prob
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = (
-            apply_residual_connection_post_layernorm)
+        self.max_position_embeddings = max_position_embeddings
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
-        self.ffn_hidden_size = ffn_hidden_size
         self.no_bias = no_bias
-        self.tie_word_embeddings = tie_word_embeddings
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
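
A hedged usage sketch (again, not part of the commit): because the keyword names above now match the keys a Qwen checkpoint ships in its config, such a dictionary can be passed straight to QWenConfig without the old attribute_map indirection. The dictionary below simply repeats the new defaults and is abbreviated; a real checkpoint such as Qwen-14B would supply its own values.

# Illustrative sketch only; values repeat the defaults from the diff above.
from vllm.transformers_utils.configs.qwen import QWenConfig

config_dict = {
    "vocab_size": 151936,
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "intermediate_size": 22016,
    "layer_norm_epsilon": 1e-6,
    "max_position_embeddings": 8192,
}
config = QWenConfig(**config_dict)

# The model code reads these names directly, with no attribute_map mapping.
assert config.hidden_size == 4096
assert config.intermediate_size // 2 == 11008  # size handed to QWenMLP in qwen.py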

0 commit comments
