Commit aaaec52

[Bugfix][Model] Mixtral: use unused head_dim config argument (#14961)
Signed-off-by: Quentin Torroba <[email protected]>
1 parent e1eb45d commit aaaec52

File tree

2 files changed: +10 −2 lines changed

vllm/model_executor/models/mixtral.py

Lines changed: 5 additions & 1 deletion
@@ -111,6 +111,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -136,7 +137,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -200,6 +203,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
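
The core of the fix is the getattr fallback on the config object. Below is a minimal sketch of that resolution logic, using a plain SimpleNamespace in place of the real transformers MixtralConfig and a hypothetical resolve_head_dim helper; the sizes are illustrative assumptions, not values from an actual checkpoint.

# Minimal sketch of the fallback added above; SimpleNamespace stands in
# for the actual MixtralConfig, and resolve_head_dim is a hypothetical
# helper mirroring the patched line.
from types import SimpleNamespace

def resolve_head_dim(config, hidden_size: int, total_num_heads: int) -> int:
    # Prefer an explicit head_dim on the config; otherwise fall back to
    # the derived value hidden_size // total_num_heads.
    return getattr(config, "head_dim", hidden_size // total_num_heads)

# Config without head_dim: the derived value is used (4096 // 32 == 128).
cfg_default = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
print(resolve_head_dim(cfg_default, 4096, 32))   # 128

# Config with an explicit head_dim: that value wins, even when it differs
# from hidden_size // num_attention_heads.
cfg_explicit = SimpleNamespace(hidden_size=4096, num_attention_heads=32,
                               head_dim=96)
print(resolve_head_dim(cfg_explicit, 4096, 32))  # 96
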

vllm/model_executor/models/mixtral_quant.py

Lines changed: 5 additions & 1 deletion
@@ -165,6 +165,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -190,7 +191,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -252,6 +255,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
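
The same change is mirrored in the quantized model path. To see why honoring an explicit head_dim matters, here is a short illustrative calculation of the fields the diff derives from it (q_size, kv_size, scaling); the numbers are assumed for illustration and do not come from a real Mixtral checkpoint. When config.head_dim differs from hidden_size // num_heads, these values change, which is why deriving head_dim purely from hidden_size silently ignored the config argument.

# Illustrative values only (assumed, not from a real checkpoint).
hidden_size = 4096
num_heads = 32
num_kv_heads = 8

# With an explicit config.head_dim of 96 instead of the derived 128:
head_dim = 96

q_size = num_heads * head_dim       # 3072, no longer equal to hidden_size
kv_size = num_kv_heads * head_dim   # 768
scaling = head_dim ** -0.5          # softmax scaling uses the true head size

print(q_size, kv_size, round(scaling, 4))
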
