1 parent 58f36bf commit bbfd4e7
exllamav3/architecture/gemma3.py
@@ -387,6 +387,9 @@ def __init__(
         self.patches_per_image = patches_per_image
         self.tokens_per_side = tokens_per_side
 
+    def optimizer_targets(self):
+        raise NotImplementedError()
+
     @override
     def load(self, device: torch.device, **kwargs):
         pass
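Note on the gemma3.py hunk: the vision component gains an optimizer_targets() stub that raises NotImplementedError, presumably so code that collects quantization/optimizer targets fails loudly when it reaches this module. A minimal sketch of how a caller might treat that signal; collect_targets and the module list are hypothetical illustrations, not part of this commit, and it assumes optimizer_targets() returns a list.

# Hypothetical helper for illustration only: gather optimizer targets across
# modules, skipping any module that declares itself unsupported.
def collect_targets(modules):
    targets = []
    for m in modules:
        try:
            targets += m.optimizer_targets()
        except NotImplementedError:
            # e.g. the Gemma3 vision component after this change
            continue
    return targets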
exllamav3/architecture/mistral3.py
@@ -256,6 +256,8 @@ def __init__(
 
         self.register_submodule(self.merging_layer)
 
     def forward(
exllamav3/exllamav3_ext/gnd.cu
@@ -54,7 +54,8 @@ void gated_delta_net_fused_op_kernel
     size_t Ng,
     size_t Hk,
     size_t Hv
-){
+)
+{
     const size_t Nv = Nk * Ng;
     const size_t Fseg = 2 * Hk + 2 * Ng * Hv;  // per-khead segment in mixed_qkvz
     const size_t Fba = 2 * Ng;                 // per-khead segment in mixed_ba
exllamav3/exllamav3_ext/hgemm.cu
@@ -28,19 +28,19 @@ void hgemm
 
     TORCH_CHECK_DTYPE(a, kHalf);
     TORCH_CHECK_DTYPE(b, kHalf);
-    TORCH_CHECK_DIM(a, 2);
     TORCH_CHECK_DIM(b, 2);
-    TORCH_CHECK_DIM(c, 2);
-    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
-    TORCH_CHECK_SHAPES(a, 1, b, 0, 1);
-    TORCH_CHECK_SHAPES(b, 1, c, 1, 1);
+    // TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
+    TORCH_CHECK_SHAPES(a, -1, b, 0, 1);
+    TORCH_CHECK_SHAPES(b, 1, c, -1, 1);
 
     const half* a_ptr = (const half*) a.data_ptr();
     const half* b_ptr = (const half*) b.data_ptr();
 
-    int size_m = a.size(0);
-    int size_k = a.size(1);
-    int size_n = b.size(1);
+    int size_m = 1;
+    int dim = a.dim();
+    for (int d = 0; d < dim - 1; ++d) size_m *= a.size(d);
+    int size_k = a.size(-1);
+    int size_n = b.size(-1);
 
     cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
     cublasSetStream(cublas_handle, stream);
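Note on the hgemm.cu hunk: the strict 2-D checks on a and c are dropped, all leading dimensions of a are folded into size_m, and only the last dimensions are matched against b. A minimal PyTorch sketch of the shape convention this implies, using plain matmul as a stand-in for the extension kernel (float32 on CPU here purely for illustration; the kernel itself runs fp16 through cuBLAS):

import torch

a = torch.randn(2, 8, 64)      # leading dims (2, 8) fold into m
b = torch.randn(64, 32)

m = a.numel() // a.shape[-1]   # size_m = 2 * 8 = 16
k = a.shape[-1]                # size_k = 64
n = b.shape[-1]                # size_n = 32

c_flat = a.reshape(m, k) @ b   # what the kernel computes internally
c = a @ b                      # what the caller sees, shape (2, 8, 32)
assert torch.allclose(c, c_flat.view(2, 8, n))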
exllamav3/exllamav3_ext/norm.cu
@@ -362,7 +362,7 @@ void gated_rms_norm
     TORCH_CHECK_DTYPE(g, kBFloat16);
     TORCH_CHECK_DIV(x, -1, 4);
     TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
-    TORCH_CHECK_SHAPES_FULL(x, y);
+    // TORCH_CHECK_SHAPES_FULL(x, y);
     TORCH_CHECK_SHAPES_FULL(x, g);
 
     bool output_fp32 = y.dtype() == at::kFloat;
exllamav3/exllamav3_ext/quant/exl3_gemm.cu
@@ -46,7 +46,7 @@ int exl3_gemm
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
 
     TORCH_CHECK_DIM(B, 3);
-    TORCH_CHECK_SHAPES(A, 1, B, 0, 16);
+    TORCH_CHECK_SHAPES(A, -1, B, 0, 16);
     TORCH_CHECK_SHAPES(C, -1, B, 1, 16);
     // TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
     TORCH_CHECK_DTYPE(A, kHalf);
@@ -82,8 +82,11 @@ int exl3_gemm
     const half* A_ptr = (const half*) A.data_ptr();
     const uint16_t* B_ptr = (const uint16_t*) B.data_ptr();
     void* C_ptr = (void*) C.data_ptr();
-    int size_m = A.size(0);
-    int size_k = A.size(1);
+    int size_m = 1;
+    int dim = A.dim();
+    for (int d = 0; d < dim - 1; ++d) size_m *= A.size(d);
+    int size_k = A.size(-1);
     int size_n = B.size(1) * 16;
 
     // Select kernel
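Note on the exl3_gemm.cu hunk: as in hgemm.cu, A may now carry arbitrary leading dimensions; size_m becomes their product and size_k the last dimension, which is checked against B.size(0) * 16. A small sketch of that size derivation (shapes are illustrative only; the real B is a packed 3-D uint16 weight tensor and is not constructed here, and the real A is fp16 on CUDA):

import math
import torch

A = torch.randn(4, 7, 256)         # any number of leading dims

size_m = math.prod(A.shape[:-1])   # 4 * 7 = 28: all leading dims fold into m
size_k = A.shape[-1]               # 256: must equal B.size(0) * 16
# size_n is still B.size(1) * 16, unchanged by this commit.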
exllamav3/modules/gated_delta_net.py
@@ -1,22 +1,16 @@
 from __future__ import annotations
-from dataclasses import dataclass
 from typing_extensions import override
 import torch
 import torch.nn.functional as F
 from ..model.config import Config
-from ..util.rope import RopeSettings, RoPE
 from ..util.tensor import get_for_device, to2
-from . import Module, Linear, RMSNorm, LayerNorm
-from ..constants import PAGE_SIZE
-from ..cache import Cache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+from . import Module, Linear
 from ..util import profile_opt
-from .multilinear import MultiLinear
 from ..ext import exllamav3_ext as ext
 from ..model.model_tp_alloc import TPAllocation
-import torch.distributed as dist
 from .gated_rmsnorm import GatedRMSNorm
 from ..cache import CacheableState
+from ..util.tensor import g_tensor_cache
 
 """
 causal_conv1d wrappers and fallback functions