1 file changed: +4 -3 lines changed.
Columns: original file line number | diff line number | diff line change
@@ -128,7 +128,7 @@ def __init__(
128128 in_features = intermediate_size ,
129129 out_features = hidden_size ,
130130 qmap = qmap + f".{ idx } .down" ,
131- out_dtype = torch . half ,
131+ out_dtype = self . out_dtype ,
132132 allow_input_padding = True ,
133133 )
134134
@@ -262,8 +262,9 @@ def forward(
262262 def mlp (exp_i , xc ):
263263 g = self .gates [exp_i ].forward (xc , params )
264264 u = self .ups [exp_i ].forward (xc , params )
265- self .activation_fn_call (g , u , u )
266- return self .downs [exp_i ].forward (u , params )
265+ a = u if self .interm_dtype == torch .half else torch .empty_like (u , dtype = torch .half )
266+ self .activation_fn_call (g , u , a )
267+ return self .downs [exp_i ].forward (a , params )
267268
268269 for expert_idx in range (self .num_experts ):
269270 if expert_count [expert_idx ] == 0 :
You can’t perform that action at this time.
0 commit comments