
Commit 6ec1b13

MLP: Fix edge case when hidden_size == interm_size (happens in some TP setups)
1 parent: ce50be3

1 file changed: +3 -3 lines


exllamav3/modules/mlp.py

Lines changed: 3 additions & 3 deletions
@@ -546,9 +546,9 @@ def load_local(self, device: torch.Device, load_slice: int, **kwargs):
         if self.num_slices == 1 and self.multi_gu[0] is not None and self.downs[0].inner.bc is not None:
             mgu = self.multi_gu[0]
             self.bsz1_pa_args = [
-                (device, (2, 1, self.hidden_size), self.interm_dtype),
-                (device, (2, 1, mgu.out_features), self.interm_dtype),
-                (device, (1, 1, 1, mgu.out_features), torch.half)
+                (device, (2, 1, self.hidden_size), self.interm_dtype, "gu"),
+                (device, (2, 1, mgu.out_features), self.interm_dtype, "a1"),
+                (device, (1, 1, 1, mgu.out_features), torch.half, "a2")
             ]
             self.bc = ext.BC_GatedMLP(
                 *(g_tensor_cache.get(*arg) for arg in self.bsz1_pa_args),
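A plausible reading of the fix (the g_tensor_cache implementation is not shown in this hunk, so this is an assumption): if cached buffers are keyed only by (device, shape, dtype), then when hidden_size == interm_size the first two requests above produce identical keys and alias the same tensor, so one activation overwrites the other. Adding a distinct string tag per request keeps the cache entries separate. A minimal sketch, with a hypothetical TensorCache standing in for g_tensor_cache:

import torch

class TensorCache:
    # Hypothetical stand-in for g_tensor_cache (assumption, not the actual
    # exllamav3 implementation): a buffer is reused when the lookup key matches.
    def __init__(self):
        self.store = {}

    def get(self, device, shape, dtype, label = ""):
        # The label is part of the key; without it, two requests with the same
        # (device, shape, dtype) would return the very same buffer.
        key = (device, shape, dtype, label)
        if key not in self.store:
            self.store[key] = torch.empty(shape, dtype = dtype, device = device)
        return self.store[key]

cache = TensorCache()
hidden_size = interm_size = 4096   # the edge case: equal sizes

# With distinct labels the two requests get separate buffers even though
# their shapes and dtypes are identical:
gu = cache.get("cpu", (2, 1, hidden_size), torch.half, "gu")
a1 = cache.get("cpu", (2, 1, interm_size), torch.half, "a1")
assert gu.data_ptr() != a1.data_ptr()

The extra tag changes nothing in the common case where the shapes already differ; it only disambiguates the degenerate case called out in the commit message.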
