Commit 2a05201

[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)
Follow-on to #4332 that enables FP8 checkpoint loading for Mixtral; supersedes #4436. This PR enables the following checkpoint-loading features for Mixtral:

- Supports loading fp8 checkpoints for Mixtral, such as the "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
- Supports static or dynamic activation quantization with static weight quantization (all per tensor)
- Supports different scales for each expert weight
- Supports Fp8 in the QKV layer

Notes:

- The expert gate/router always runs at half / full precision for now.
- If the separate QKV projections were serialized with different weight scales, they are re-quantized using layer.weight_scale.max() so a single GEMM can be used for performance (see the sketch below).
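A minimal sketch of the max-scale re-quantization mentioned above, assuming per-tensor fp8 scales and a fused QKV weight stored as shards stacked along dim 0. The helper requantize_to_max_scale and its logical_widths argument are hypothetical names used for illustration; this is not the code added by this PR.

import torch

def requantize_to_max_scale(weight_fp8: torch.Tensor,
                            weight_scale: torch.Tensor,
                            logical_widths: list) -> tuple:
    # weight_fp8: fused QKV weight, shards stacked along dim 0 (fp8 e4m3).
    # weight_scale: one float32 scale per shard.
    # logical_widths: number of output rows in each shard (hypothetical).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    max_scale = weight_scale.max()
    requantized = torch.empty_like(weight_fp8)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # Dequantize the shard with its own scale...
        shard = weight_fp8[start:end].to(torch.float32) * weight_scale[idx]
        # ...then re-quantize it with the shared max scale.
        requantized[start:end] = (shard / max_scale).clamp(
            -fp8_max, fp8_max).to(torch.float8_e4m3fn)
        start = end
    return requantized, max_scale

With one shared scale, the fused QKV projection can run as a single fp8 GEMM; the trade-off is a small precision loss for shards whose original scale was below the maximum.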
1 parent 36fb68f commit 2a05201

File tree

2 files changed: +122, -53 lines

tests/kernels/test_moe.py

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
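For readers unfamiliar with the renamed parameters: w13_weight stacks every expert's w1 and w3 (gate/up) projections into one tensor, and w2_weight stacks the down projections. A minimal sketch of that layout, with arbitrary small shapes not taken from the PR:

import torch

num_experts, intermediate_size, hidden_size = 4, 16, 8

# w13_weight[e] holds expert e's w1 and w3 concatenated along dim 0;
# w2_weight[e] holds expert e's down projection.
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size)

w1 = torch.randn(intermediate_size, hidden_size)
w3 = torch.randn(intermediate_size, hidden_size)
w13_weight[0][:] = torch.cat((w1, w3), dim=0)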

vllm/model_executor/models/mixtral.py

Lines changed: 120 additions & 51 deletions
@@ -78,6 +78,8 @@ def __init__(
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -86,55 +88,79 @@ def __init__(
         params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        self.ws = nn.Parameter(
+        if self.use_fp8:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        dtype=self.params_dtype))
-        self.w2s = nn.Parameter(
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        dtype=self.params_dtype))
+                        dtype=params_dtype))
 
-        set_weight_attrs(self.ws, {
+        set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
-        set_weight_attrs(self.w2s, {
+        set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })
 
-        # Scaling factors for FP8 weights
-        self.ws_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-        self.w2s_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-
-        # Scaling factors for FP8 activations
-        need_act_scales = (self.use_fp8
-                           and quant_config.activation_scheme == "static")
-        self.as_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-        self.a2s_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-
-        if need_act_scales:
-            set_weight_attrs(self.as_scale, {
-                "weight_loader": self.weight_loader,
-            })
-            set_weight_attrs(self.a2s_scale, {
-                "weight_loader": self.weight_loader,
-            })
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
@@ -149,38 +175,67 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name:
-            param_data[:] = param_data[:].max(loaded_weight)
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
 
     def process_weights_after_loading(self):
-        if self.use_fp8:
-            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
-            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
-                    self.ws.data[expert, :, :])
-                w2s[expert, :, :], self.w2s_scale[
-                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
-            self.ws = nn.Parameter(ws, requires_grad=False)
-            self.w2s = nn.Parameter(w2s, requires_grad=False)
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. ")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
+                                        self.w13_weight,
+                                        self.w2_weight,
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
                                         inplace=True,
                                         use_fp8=self.use_fp8,
-                                        w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale,
-                                        a1_scale=self.as_scale,
-                                        a2_scale=self.a2s_scale)
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -222,7 +277,9 @@ def __init__(self,
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if isinstance(quant_config, Fp8Config):
+        if isinstance(
+                quant_config,
+                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
@@ -461,16 +518,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ] + [
            # These are the activation scales for the experts
            # (param_name, weight_name, expert_id)
-            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
            f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
           for expert_id in range(self.config.num_local_experts)
           for weight_name in ["w1", "w2", "w3"]
@@ -512,3 +576,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
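For fp16 checkpoints, process_weights_after_loading() above quantizes each expert weight with ops.scaled_fp8_quant, a vLLM CUDA op. Below is a rough plain-PyTorch approximation of that per-expert, per-tensor quantization, for illustration only; the real op may differ in details such as minimum-scale handling.

import torch

def per_tensor_fp8_quant(weight: torch.Tensor):
    # Approximation of per-tensor dynamic fp8 quantization:
    # weight is roughly quantized.float() * scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = weight.abs().max().float() / fp8_max
    quantized = (weight.float() / scale).clamp(
        -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return quantized, scale

# One fp8 tensor and one float32 scale per expert, mirroring
# w13_weight / w13_scale above (shapes are arbitrary for the example).
num_experts, n, k = 8, 32, 16
w13 = torch.randn(num_experts, n, k, dtype=torch.float16)
w13_fp8 = torch.empty(num_experts, n, k, dtype=torch.float8_e4m3fn)
w13_scale = torch.ones(num_experts, dtype=torch.float32)
for expert in range(num_experts):
    w13_fp8[expert], w13_scale[expert] = per_tensor_fp8_quant(w13[expert])

The static-activation branch follows the same reasoning in reverse: the fused kernel accepts only one activation scale shared by all experts, so the per-expert act_scales loaded from the checkpoint are collapsed with .max(), which covers every expert's range at some precision cost for experts with smaller scales.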
