Commit 13fa9ca

BlockSparseMLP: Allow float32 output from down projection

Parent: 3254aa0


exllamav3/modules/block_sparse_mlp.py

Lines changed: 4 additions & 3 deletions
@@ -128,7 +128,7 @@ def __init__(
             in_features = intermediate_size,
             out_features = hidden_size,
             qmap = qmap + f".{idx}.down",
-            out_dtype = torch.half,
+            out_dtype = self.out_dtype,
             allow_input_padding = True,
         )

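The hunk above swaps a hard-coded torch.half for self.out_dtype, so the down projection's output dtype becomes configurable rather than always float16. A minimal sketch of the idea, using a simplified stand-in module (DownProj is hypothetical, not the real exllamav3 Linear API):

import torch

class DownProj:
    def __init__(self, weight, out_dtype = torch.half):
        # out_dtype is now configurable; previously the down projection
        # was hard-coded to torch.half, so the module could never
        # return float32 activations.
        self.weight = weight
        self.out_dtype = out_dtype

    def forward(self, x):
        # Accumulate the matmul in float32, then cast to the requested dtype.
        return (x.float() @ self.weight.float()).to(self.out_dtype)

down = DownProj(torch.randn(256, 64), out_dtype = torch.float)
y = down.forward(torch.randn(4, 256, dtype = torch.half))
print(y.dtype)  # torch.float32
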
@@ -262,8 +262,9 @@ def forward(
         def mlp(exp_i, xc):
             g = self.gates[exp_i].forward(xc, params)
             u = self.ups[exp_i].forward(xc, params)
-            self.activation_fn_call(g, u, u)
-            return self.downs[exp_i].forward(u, params)
+            a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
+            self.activation_fn_call(g, u, a)
+            return self.downs[exp_i].forward(a, params)

         for expert_idx in range(self.num_experts):
             if expert_count[expert_idx] == 0:
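The second hunk handles the activation buffer's dtype. Reading the diff, activation_fn_call appears to require a half-precision output buffer: when the intermediate dtype is already float16 it keeps writing in place over u, but with a float32 intermediate it allocates a separate half buffer a instead of clobbering u with data of the wrong dtype. A minimal pure-PyTorch sketch of that control flow, with a hypothetical stand-in for the fused activation kernel (not the actual exllamav3 code):

import torch
import torch.nn.functional as F

def activation_fn_call(g, u, out):
    # Hypothetical stand-in for the fused SiLU-multiply kernel, assumed
    # to write a float16 result into the preallocated `out` buffer.
    out.copy_(F.silu(g.float()) * u.float())

def mlp(xc, gate_w, up_w, down_w, interm_dtype, out_dtype):
    g = (xc @ gate_w).to(interm_dtype)  # gate projection
    u = (xc @ up_w).to(interm_dtype)    # up projection
    # If the intermediate dtype is float16, write the activation in
    # place over u; otherwise allocate a separate half buffer so the
    # float32 intermediate is not overwritten with half-precision data.
    a = u if interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
    activation_fn_call(g, u, a)
    # Down projection: output dtype now follows out_dtype, which per
    # this commit may be float32 rather than always torch.half.
    return (a.float() @ down_w.float()).to(out_dtype)

xc = torch.randn(4, 64, dtype = torch.half)
gate_w = torch.randn(64, 256, dtype = torch.half)
up_w = torch.randn(64, 256, dtype = torch.half)
down_w = torch.randn(256, 64, dtype = torch.half)
print(mlp(xc, gate_w, up_w, down_w, torch.float, torch.float).dtype)  # torch.float32

The extra buffer costs one half-precision allocation per expert call, but it preserves the activation kernel's apparent float16 output contract while letting the surrounding module run float32 intermediates and outputs.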
