Skip to content

Commit abd8ddc

Browse files
committed
Conv: Fix regression for 2D input
1 parent 98e1c40 commit abd8ddc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

exllamav3/modules/conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ def forward(
8181
out_dtype: torch.dtype | None = None,
8282
) -> torch.Tensor:
8383

84-
bsz, seqlen, dim = x.shape
85-
8684
if self.dims == 2:
8785
y = F.conv2d(x, self.weight, self.bias, self.kernel_size)
8886
y = y.flatten(2).permute(0, 2, 1).contiguous()
8987

9088
elif self.dims == 3:
89+
bsz, seqlen, dim = x.shape
90+
9191
if self.flat:
9292
x_flat = x.view(-1, self.dim)
9393
w_flat = self.weight.view(self.weight.shape[0], -1).T.contiguous()

0 commit comments

Comments
 (0)