Skip to content

Commit 13dd99c

Browse files
committed
GatedDeltaNet: Move import
1 parent a8a3da5 commit 13dd99c

File tree

4 files changed, with 44 additions and 2 deletions.

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
14 | 14 |   #include "softcap.cuh"
15 | 15 |   #include "routing.cuh"
16 | 16 |   #include "gdn.cuh"
   | 17 | + #include "causal_conv1d.cuh"
17 | 18 |
18 | 19 |   #include "quant/quantize.cuh"
19 | 20 |   #include "quant/pack.cuh"
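exllamav3/exllamav3_ext/causal_conv1d.cu (file name missing from this capture; presumed from the "causal_conv1d.cuh" include added to bindings.cpp — confirm against the commit)

Lines changed: 34 additions & 0 deletions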
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
   |  1 | + #include <cuda_fp16.h>
   |  2 | + #include <cuda_fp16.hpp>
   |  3 | + #include "activation.cuh"
   |  4 | + #include <c10/cuda/CUDAGuard.h>
   |  5 | + #include <ATen/cuda/CUDAContext.h>
   |  6 | + #include "util.h"
   |  7 | + #include "util.cuh"
   |  8 | + #include "compat.cuh"
   |  9 | + #include <cmath>
   | 10 | +
   | 11 | + //__global__ __launch_bounds__(MAX_HEAD_DIM)
   | 12 | + //void causal_conv1d_kernel
   | 13 | + //(
   | 14 | + //
   | 15 | + //)
   | 16 | + //{
   | 17 | + //
   | 18 | + //}
   | 19 | + //
   | 20 | +
   | 21 | + void causal_conv1d
   | 22 | + (
   | 23 | +
   | 24 | +
   | 25 | +
   | 26 | +
   | 27 | +
   | 28 | +
   | 29 | + )
   | 30 | + {
   | 31 | +
   | 32 | +
   | 33 | + }
   | 34 | +
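exllamav3/exllamav3_ext/causal_conv1d.cuh (file name missing from this capture; inferred from the "causal_conv1d.cuh" include added to bindings.cpp — confirm against the commit)

Lines changed: 4 additions & 0 deletions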
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
   | 1 | + void causal_conv1d
   | 2 | + (
   | 3 | +
   | 4 | + );

exllamav3/modules/gated_delta_net.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,11 @@ def causal_conv1d_fwd_function_cu(
79 | 79 |   causal_conv1d_update_function = causal_conv1d_update_function_torch
80 | 80 |   causal_conv1d_fwd_function = causal_conv1d_fwd_function_torch
81 | 81 |
   | 82 | + try:
   | 83 | +     from fla.ops.gated_delta_rule import chunk_gated_delta_rule
   | 84 | + except ModuleNotFoundError:
   | 85 | +     chunk_gated_delta_rule = None
   | 86 | +
82 | 87 |   """
83 | 88 |   fla wrapper, reduce overhead by bypassing input_guard and torch custom ops stuff
84 | 89 |   """
(Leading indentation was lost in this capture; only the indentation that the try/except syntax requires has been restored.)
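@@ -465,8 +470,6 @@ def forward(
465 | 470 |   out_dtype: torch.dtype | None = None
466 | 471 |   ) -> torch.Tensor:
467 | 472 |
468 |     | - from fla.ops.gated_delta_rule import chunk_gated_delta_rule
469 |     | -
470 | 473 |   bsz, seqlen, _ = x.shape
471 | 474 |
472 | 475 |   # Previous state
(Columns: original file line number | diff line number | diff line change. Leading indentation inside the method was lost in this capture.)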

0 commit comments

Comments
 (0)