
Commit bbfd4e7: Cleanup
1 parent 58f36bf

File tree (7 files changed, +24 -21 lines):

exllamav3/architecture/gemma3.py
exllamav3/architecture/mistral3.py
exllamav3/exllamav3_ext/gnd.cu
exllamav3/exllamav3_ext/hgemm.cu
exllamav3/exllamav3_ext/norm.cu
exllamav3/exllamav3_ext/quant/exl3_gemm.cu
exllamav3/modules/gated_delta_net.py

exllamav3/architecture/gemma3.py

Lines changed: 3 additions & 0 deletions

@@ -387,6 +387,9 @@ def __init__(
         self.patches_per_image = patches_per_image
         self.tokens_per_side = tokens_per_side

+    def optimizer_targets(self):
+        raise NotImplementedError()
+
     @override
     def load(self, device: torch.device, **kwargs):
         pass

exllamav3/architecture/mistral3.py

Lines changed: 2 additions & 0 deletions

@@ -256,6 +256,8 @@ def __init__(

         self.register_submodule(self.merging_layer)

+    def optimizer_targets(self):
+        raise NotImplementedError()

     @override
     def forward(
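Note: both vision components (gemma3.py and mistral3.py above) now define an explicit optimizer_targets() stub that raises NotImplementedError instead of inheriting a default. A minimal sketch of the pattern, using a hypothetical base class rather than the real exllamav3 Module hierarchy:

class Module:
    # Hypothetical base: subclasses enumerate the targets an optimizer may touch.
    def optimizer_targets(self):
        return []

class VisionTower(Module):
    # Explicit stub, as in the commit: calling this fails loudly instead of
    # silently returning an inherited (and possibly wrong) default.
    def optimizer_targets(self):
        raise NotImplementedError()

try:
    VisionTower().optimizer_targets()
except NotImplementedError:
    print("optimizer_targets() not supported for this module")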

exllamav3/exllamav3_ext/gnd.cu

Lines changed: 2 additions & 1 deletion

@@ -54,7 +54,8 @@ void gated_delta_net_fused_op_kernel
     size_t Ng,
     size_t Hk,
     size_t Hv
-){
+)
+{
     const size_t Nv = Nk * Ng;
     const size_t Fseg = 2 * Hk + 2 * Ng * Hv;  // per-khead segment in mixed_qkvz
     const size_t Fba = 2 * Ng;                 // per-khead segment in mixed_ba

exllamav3/exllamav3_ext/hgemm.cu

Lines changed: 8 additions & 8 deletions

@@ -28,19 +28,19 @@ void hgemm

     TORCH_CHECK_DTYPE(a, kHalf);
     TORCH_CHECK_DTYPE(b, kHalf);
-    TORCH_CHECK_DIM(a, 2);
     TORCH_CHECK_DIM(b, 2);
-    TORCH_CHECK_DIM(c, 2);
-    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
-    TORCH_CHECK_SHAPES(a, 1, b, 0, 1);
-    TORCH_CHECK_SHAPES(b, 1, c, 1, 1);
+    // TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
+    TORCH_CHECK_SHAPES(a, -1, b, 0, 1);
+    TORCH_CHECK_SHAPES(b, 1, c, -1, 1);

     const half* a_ptr = (const half*) a.data_ptr();
     const half* b_ptr = (const half*) b.data_ptr();

-    int size_m = a.size(0);
-    int size_k = a.size(1);
-    int size_n = b.size(1);
+    int size_m = 1;
+    int dim = a.dim();
+    for (int d = 0; d < dim - 1; ++d) size_m *= a.size(d);
+    int size_k = a.size(-1);
+    int size_n = b.size(-1);

     cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
     cublasSetStream(cublas_handle, stream);
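Note: hgemm no longer requires a and c to be strictly 2-D; every leading dimension of a is folded into size_m, and only the trailing dimensions are checked against b. The sketch below mirrors that folding at the PyTorch level under assumed shapes; it illustrates the reshaping only and is not a call into the extension:

import torch

# Assumed shapes: a 3-D activation [batch, seq_len, k] and a 2-D weight [k, n].
a = torch.randn(2, 5, 64)
b = torch.randn(64, 32)

# Mirror of the new kernel logic: collapse all leading dims of a into size_m.
size_m = 1
for d in a.shape[:-1]:
    size_m *= d                  # 2 * 5 = 10
size_k = a.shape[-1]             # 64
size_n = b.shape[-1]             # 32

# For contiguous row-major input, one flat [size_m, k] x [k, n] GEMM produces the
# same values as the batched matmul, just with the leading dims flattened away.
c_flat = a.reshape(size_m, size_k) @ b
c_batched = a @ b
assert torch.allclose(c_flat.view(2, 5, size_n), c_batched)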

exllamav3/exllamav3_ext/norm.cu

Lines changed: 1 addition & 1 deletion

@@ -362,7 +362,7 @@ void gated_rms_norm
     TORCH_CHECK_DTYPE(g, kBFloat16);
     TORCH_CHECK_DIV(x, -1, 4);
     TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
-    TORCH_CHECK_SHAPES_FULL(x, y);
+    // TORCH_CHECK_SHAPES_FULL(x, y);
     TORCH_CHECK_SHAPES_FULL(x, g);

     bool output_fp32 = y.dtype() == at::kFloat;

exllamav3/exllamav3_ext/quant/exl3_gemm.cu

Lines changed: 6 additions & 3 deletions

@@ -46,7 +46,7 @@ int exl3_gemm
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

     TORCH_CHECK_DIM(B, 3);
-    TORCH_CHECK_SHAPES(A, 1, B, 0, 16);
+    TORCH_CHECK_SHAPES(A, -1, B, 0, 16);
     TORCH_CHECK_SHAPES(C, -1, B, 1, 16);
     // TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
     TORCH_CHECK_DTYPE(A, kHalf);
@@ -82,8 +82,11 @@ int exl3_gemm
     const half* A_ptr = (const half*) A.data_ptr();
     const uint16_t* B_ptr = (const uint16_t*) B.data_ptr();
     void* C_ptr = (void*) C.data_ptr();
-    int size_m = A.size(0);
-    int size_k = A.size(1);
+
+    int size_m = 1;
+    int dim = A.dim();
+    for (int d = 0; d < dim - 1; ++d) size_m *= A.size(d);
+    int size_k = A.size(-1);
     int size_n = B.size(1) * 16;

     // Select kernel
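Note: exl3_gemm receives the same generalization as hgemm above. A may now carry any number of leading dimensions, which are multiplied together into size_m, while size_k is read from A's last dimension and checked against B.size(0) * 16 (the factors of 16 suggest B is a packed quantized tensor, and that packing is unchanged here). The check between A's and C's leading dimensions was already commented out before this commit.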

exllamav3/modules/gated_delta_net.py

Lines changed: 2 additions & 8 deletions

@@ -1,22 +1,16 @@
 from __future__ import annotations
-from dataclasses import dataclass
 from typing_extensions import override
 import torch
 import torch.nn.functional as F
 from ..model.config import Config
-from ..util.rope import RopeSettings, RoPE
 from ..util.tensor import get_for_device, to2
-from . import Module, Linear, RMSNorm, LayerNorm
-from ..constants import PAGE_SIZE
-from ..cache import Cache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+from . import Module, Linear
 from ..util import profile_opt
-from .multilinear import MultiLinear
 from ..ext import exllamav3_ext as ext
 from ..model.model_tp_alloc import TPAllocation
-import torch.distributed as dist
 from .gated_rmsnorm import GatedRMSNorm
 from ..cache import CacheableState
+from ..util.tensor import g_tensor_cache

 """
 causal_conv1d wrappers and fallback functions
