We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 98e1c40 commit abd8ddcCopy full SHA for abd8ddc
exllamav3/modules/conv.py
@@ -81,13 +81,13 @@ def forward(
81
out_dtype: torch.dtype | None = None,
82
) -> torch.Tensor:
83
84
- bsz, seqlen, dim = x.shape
85
-
86
if self.dims == 2:
87
y = F.conv2d(x, self.weight, self.bias, self.kernel_size)
88
y = y.flatten(2).permute(0, 2, 1).contiguous()
89
90
elif self.dims == 3:
+ bsz, seqlen, dim = x.shape
+
91
if self.flat:
92
x_flat = x.view(-1, self.dim)
93
w_flat = self.weight.view(self.weight.shape[0], -1).T.contiguous()
0 commit comments