vllm/model_executor/models: 2 files changed (+10, -2 lines)

First changed file:
@@ -111,6 +111,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -136,7 +137,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -200,6 +203,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
Second changed file (same update):

@@ -165,6 +165,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -190,7 +191,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -252,6 +255,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
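For reference, the fallback behavior introduced by these hunks can be illustrated in isolation. The snippet below is a minimal sketch, not vLLM code: the SimpleNamespace objects are hypothetical stand-ins for MixtralConfig, carrying only the attributes needed to show how getattr resolves head_dim when the attribute is present versus absent.

from types import SimpleNamespace

# Hypothetical stand-ins for MixtralConfig with only the relevant attributes.
config_without_head_dim = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
config_with_head_dim = SimpleNamespace(hidden_size=4096, num_attention_heads=32,
                                       head_dim=96)

def resolve_head_dim(config) -> int:
    """Mirror the diff's logic: prefer config.head_dim, else derive it."""
    return getattr(config, "head_dim",
                   config.hidden_size // config.num_attention_heads)

print(resolve_head_dim(config_without_head_dim))  # 128 (4096 // 32, derived fallback)
print(resolve_head_dim(config_with_head_dim))     # 96  (taken from the config)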