@@ -78,6 +78,8 @@ def __init__(
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -86,55 +88,79 @@ def __init__(
         params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        self.ws = nn.Parameter(
+        if self.use_fp8:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        dtype=self.params_dtype))
-        self.w2s = nn.Parameter(
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        dtype=self.params_dtype))
+                        dtype=params_dtype))
 
-        set_weight_attrs(self.ws, {
+        set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
-        set_weight_attrs(self.w2s, {
+        set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })
 
-        # Scaling factors for FP8 weights
-        self.ws_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-        self.w2s_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-
-        # Scaling factors for FP8 activations
-        need_act_scales = (self.use_fp8
-                           and quant_config.activation_scheme == "static")
-        self.as_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-        self.a2s_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-
-        if need_act_scales:
-            set_weight_attrs(self.as_scale, {
-                "weight_loader": self.weight_loader,
-            })
-            set_weight_attrs(self.a2s_scale, {
-                "weight_loader": self.weight_loader,
-            })
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading an fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()).
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
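Side note for reviewers: a minimal, self-contained sketch of the tensor layout this constructor builds, with made-up sizes (nothing below is read from a real config or checkpoint). w1 and w3 are packed into a single "w13" tensor per expert, and every fp8 scale is kept per expert until process_weights_after_loading() runs.

import torch

# Illustrative sizes only (toy-scale, not Mixtral's real dimensions).
num_experts, hidden_size, intermediate_size = 8, 64, 128

# Packed gate/up projection (w1 + w3) and down projection (w2), per expert.
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size,
                         dtype=torch.float8_e4m3fn)
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size,
                        dtype=torch.float8_e4m3fn)

# One fp32 weight scale per expert for each packed tensor, plus per-expert
# activation scales when the checkpoint uses static activation quantization.
w13_scale = torch.ones(num_experts, dtype=torch.float32)
w2_scale = torch.ones(num_experts, dtype=torch.float32)
a13_scale = torch.zeros(num_experts, dtype=torch.float32)
a2_scale = torch.zeros(num_experts, dtype=torch.float32)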
@@ -149,38 +175,67 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name:
-            param_data[:] = param_data[:].max(loaded_weight)
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
 
     def process_weights_after_loading(self):
-        if self.use_fp8:
-            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
-            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If the checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
-                    self.ws.data[expert, :, :])
-                w2s[expert, :, :], self.w2s_scale[
-                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
-            self.ws = nn.Parameter(ws, requires_grad=False)
-            self.w2s = nn.Parameter(w2s, requires_grad=False)
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If the checkpoint is fp8 + static, clean up the act_scales:
+        # the state_dict has one act_scale per expert, but our kernels
+        # take a single act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer.")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
+                                        self.w13_weight,
+                                        self.w2_weight,
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
                                         inplace=True,
                                         use_fp8=self.use_fp8,
-                                        w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale,
-                                        a1_scale=self.as_scale,
-                                        a2_scale=self.a2s_scale)
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
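For intuition about the fp16-checkpoint branch above: a rough pure-PyTorch sketch of per-tensor dynamic fp8 quantization in the spirit of ops.scaled_fp8_quant. This is not vLLM's actual kernel; it only assumes the op returns a float8_e4m3fn tensor together with a single fp32 scale, which is what the per-expert loop stores into w13_scale / w2_scale.

import torch

def scaled_fp8_quant_sketch(weight: torch.Tensor):
    # Map the tensor's max magnitude onto the representable range of
    # float8_e4m3fn, then cast. Illustrative only.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    quantized = (weight / scale).clamp(finfo.min, finfo.max).to(
        torch.float8_e4m3fn)
    return quantized, scale.to(torch.float32)

# Per-expert loop mirroring the fp16 path in process_weights_after_loading,
# at toy sizes: each expert slice gets its own scale.
num_experts = 4
w13 = torch.randn(num_experts, 16, 8, dtype=torch.float16)
w13_fp8 = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.ones(num_experts, dtype=torch.float32)
for expert in range(num_experts):
    w13_fp8[expert], w13_scale[expert] = scaled_fp8_quant_sketch(w13[expert])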
@@ -222,7 +277,9 @@ def __init__(self,
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if isinstance(quant_config, Fp8Config):
+        if isinstance(
+                quant_config,
+                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
@@ -461,16 +518,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
              f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
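To make the new mapping concrete: a small sketch of how load_weights consumes these (param_name, weight_name, expert_id) triples. The checkpoint key below is made up for illustration (real prefixes depend on the model), and the substring match/replace mirrors the pattern already used in this loader.

# Hypothetical fp8 checkpoint key for expert 3's w1 weight scale.
ckpt_name = "model.layers.0.block_sparse_moe.experts.3.w1.weight_scale"

expert_params_mapping = [
    ("w13_scale" if w in ["w1", "w3"] else "w2_scale",
     f"experts.{e}.{w}.weight_scale", e)
    for e in range(8) for w in ["w1", "w2", "w3"]
]

for param_name, weight_name, expert_id in expert_params_mapping:
    if weight_name in ckpt_name:
        # -> "model.layers.0.block_sparse_moe.w13_scale", expert_id == 3
        print(ckpt_name.replace(weight_name, param_name), expert_id)
        break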
@@ -512,3 +576,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
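A quick usage note on the new helper, with made-up values: all_close_1d reports whether every element of a 1-D tensor matches its first element, which is how process_weights_after_loading decides whether the per-expert act_scales can be collapsed into one shared scale (falling back to the per-layer maximum, with a warning, when they differ).

import torch

assert all_close_1d(torch.tensor([0.02, 0.02, 0.02]))      # shared scale is exact
assert not all_close_1d(torch.tensor([0.02, 0.05, 0.02]))  # max() is used instead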