4 files changed (+44 / -2 lines)
File 1 of 4 (existing CUDA translation unit, gaining the new include):

 #include "softcap.cuh"
 #include "routing.cuh"
 #include "gdn.cuh"
+#include "causal_conv1d.cuh"

 #include "quant/quantize.cuh"
 #include "quant/pack.cuh"
File 2 of 4 (new file: CUDA source for the causal conv1d kernel, currently a stub):

+#include <cuda_fp16.h>
+#include <cuda_fp16.hpp>
+#include "activation.cuh"
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include "util.h"
+#include "util.cuh"
+#include "compat.cuh"
+#include <cmath>
+
+// Kernel skeleton, commented out until the implementation lands:
+// __global__ __launch_bounds__(MAX_HEAD_DIM)
+// void causal_conv1d_kernel
+// (
+// )
+// {
+// }
+
+// Host-side launcher; the parameter list is still to be defined
+void causal_conv1d
+(
+)
+{
+}
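Since the committed kernel body is still empty, here is a minimal sketch of one common way to write a depthwise causal conv1d kernel in CUDA, assuming float inputs laid out as [channels, seqlen] and a per-channel filter of length K. The function name, layout, and parameters are illustrative assumptions, not the PR's actual interface. Each thread produces one output element by summing the K most recent inputs of its own channel, with positions before t = 0 treated as zero:

#include <cuda_runtime.h>

// Hypothetical depthwise causal conv1d (names and layout are assumptions):
// x      : [channels, seqlen], row-major
// weight : [channels, K]
// y[c][t] = sum_{k=0}^{K-1} weight[c][k] * x[c][t - (K-1) + k],
// where inputs at negative time indices are treated as zero.
__global__ void causal_conv1d_sketch
(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    float* __restrict__ y,
    int seqlen,
    int K
)
{
    int c = blockIdx.y;                                // one grid row per channel
    int t = blockIdx.x * blockDim.x + threadIdx.x;     // time position
    if (t >= seqlen) return;

    float acc = 0.0f;
    for (int k = 0; k < K; ++k)
    {
        int src = t - (K - 1) + k;                     // causal: only past/current inputs
        if (src >= 0)
            acc += weight[c * K + k] * x[c * seqlen + src];
    }
    y[c * seqlen + t] = acc;
}

A launch such as causal_conv1d_sketch<<<dim3((seqlen + 255) / 256, channels), 256>>>(x, w, y, seqlen, K) covers every (channel, time) pair; real implementations usually stage x in shared memory or registers, since each input element is reused K times.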
File 3 of 4 (new file: header with the matching forward declaration, parameters to follow):

+void causal_conv1d
+(
+);
File 4 of 4 (Python wrapper module):

@@ -79,6 +79,11 @@ def causal_conv1d_fwd_function_cu(
 causal_conv1d_update_function = causal_conv1d_update_function_torch
 causal_conv1d_fwd_function = causal_conv1d_fwd_function_torch

+try:
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+except ModuleNotFoundError:
+    chunk_gated_delta_rule = None
+
 """
 fla wrapper, reduce overhead by bypassing input_guard and torch custom ops stuff
 """
@@ -465,8 +470,6 @@ def forward(
         out_dtype: torch.dtype | None = None
     ) -> torch.Tensor:

-        from fla.ops.gated_delta_rule import chunk_gated_delta_rule
-
         bsz, seqlen, _ = x.shape

         # Previous state
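The net effect of the Python change, presumably the motivation for the module-level try/except: the fla import moves out of forward and up to module scope, so it is resolved once at import time rather than on every call, and an environment without the fla package degrades gracefully, leaving chunk_gated_delta_rule as None with the torch fallbacks assigned above still in place, instead of raising ModuleNotFoundError mid-forward.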