Commit a8a3da5

Add C++ modules
1 parent 8c4a542 commit a8a3da5

File tree

16 files changed: +606 -132 lines


exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -40,6 +40,9 @@
 #include "parallel/all_reduce.cuh"

 #include "libtorch/blocksparse_mlp.h"
+#include "libtorch/gated_delta_net.h"
+#include "libtorch/linear.h"
+#include "libtorch/gated_rmsnorm.h"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
@@ -116,4 +119,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("histogram", &histogram, "histogram");

     m.def("blocksparse_mlp_routing", &blocksparse_mlp_routing, "blocksparse_mlp_routing");
+
+    #include "libtorch/linear_bc.h"
+    #include "libtorch/gated_delta_net_bc.h"
+    #include "libtorch/gated_rmsnorm_bc.h"
 }
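
The three *_bc.h headers are included inside the module body rather than at file scope, so each header's py::class_ registration is expanded at a point where m is in scope. As an illustration (not part of the commit), this is roughly what the compiler sees after preprocessing, trimmed to the BC_GatedRMSNorm binding shown further down; the other classes follow the same pattern:

// Illustrative only: the *_bc.h includes inside PYBIND11_MODULE expand to the
// registrations directly within the module body.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // ... existing m.def(...) registrations ...

    // expanded from libtorch/gated_rmsnorm_bc.h:
    py::class_<BC_GatedRMSNorm, std::shared_ptr<BC_GatedRMSNorm>>(m, "BC_GatedRMSNorm")
        .def
        (
            py::init<at::Tensor, float, float>(),
            py::arg("weight"), py::arg("rms_norm_eps"), py::arg("constant_bias")
        )
        .def("run", &BC_GatedRMSNorm::run);

    // BC_LinearFP16, BC_LinearEXL3 and BC_GatedDeltaNet are registered the same way
    // from linear_bc.h and gated_delta_net_bc.h.
}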
exllamav3/exllamav3_ext/libtorch/gated_delta_net.cpp (new file; name inferred from its includes and the bindings above)

Lines changed: 59 additions & 0 deletions
#include <Python.h>
#include "gated_delta_net.h"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../util.h"
#include "../hgemm.cuh"
#include "../quant/exl3_gemm.cuh"
#include "../gdn.cuh"

using namespace torch::indexing;

at::Tensor BC_GatedDeltaNet::run_bsz1_a
(
    const at::Tensor& x
)
{
    py::gil_scoped_release _;

    qkvz_proj->run(x, qkvz);
    ba_proj->run(x, ba);

    gated_delta_net_fused_op
    (
        qkvz, ba,
        dt_bias, a_log,
        mixed_qkv, z, beta, g,
        num_k_heads,
        num_v_heads,
        k_head_dim,
        v_head_dim
    );

    return mixed_qkv;
}

void BC_GatedDeltaNet::run_bsz1_b
(
    at::Tensor& mixed_qkv,
    at::Tensor& y,
    at::Tensor& recurrent_state
)
{
    cuda_recurrent_gated_delta_rule
    (
        mixed_qkv.transpose(1, 2),
        g,
        beta,
        recurrent_state,
        core_attn_out,
        num_k_heads,
        num_v_heads,
        k_head_dim,
        v_head_dim
    );

    norm->run(core_attn_out, core_attn_out_f, z);
    o_proj->run(core_attn_out_f, y);
}
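
The decode path is split into two calls so the caller can update mixed_qkv in between; the struct carries conv1d_weight, conv1d_bias and the conv_temp_* buffers, but this file never applies them, so the short-convolution step is assumed to happen elsewhere. A hedged sketch of how a caller might drive one bsz-1 token:

// Hedged caller sketch; decode_token_sketch and the conv1d comment are assumptions,
// not part of this commit.
void decode_token_sketch
(
    BC_GatedDeltaNet& gdn,
    const at::Tensor& x,            // (1, 1, hidden_size) hidden states for one token
    at::Tensor& y,                  // preallocated output buffer
    at::Tensor& recurrent_state     // persistent delta-rule state for the sequence
)
{
    // Stage A: QKVZ/BA projections plus the fused op that splits and activates
    // them into mixed_qkv, z, beta and g.
    at::Tensor mixed_qkv = gdn.run_bsz1_a(x);

    // (assumed) the causal conv1d over mixed_qkv is applied here by the caller,
    // using gdn.conv1d_weight / gdn.conv1d_bias and the conv_temp buffers.

    // Stage B: recurrent gated delta rule, gated RMSNorm with z as the gate,
    // and the output projection into y.
    gdn.run_bsz1_b(mixed_qkv, y, recurrent_state);
}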
exllamav3/exllamav3_ext/libtorch/gated_delta_net.h (new file; name inferred)

Lines changed: 95 additions & 0 deletions
#pragma once

#include <ATen/Tensor.h>
#include <vector>
#include <pybind11/pybind11.h>
namespace py = pybind11;
#include "linear.h"
#include "gated_rmsnorm.h"

struct BC_GatedDeltaNet
{
    at::Tensor mixed_qkv;
    at::Tensor z;
    at::Tensor beta;
    at::Tensor g;
    at::Tensor qkvz;
    at::Tensor ba;
    at::Tensor conv_temp_a;
    at::Tensor conv_temp_b;
    at::Tensor core_attn_out;
    at::Tensor core_attn_out_f;
    std::shared_ptr<BC_LinearEXL3> qkvz_proj;
    std::shared_ptr<BC_LinearFP16> ba_proj;
    at::Tensor dt_bias;
    at::Tensor a_log;
    int num_k_heads;
    int num_v_heads;
    int k_head_dim;
    int v_head_dim;
    at::Tensor conv1d_weight;
    c10::optional<at::Tensor> conv1d_bias;
    std::shared_ptr<BC_GatedRMSNorm> norm;
    std::shared_ptr<BC_LinearEXL3> o_proj;

    BC_GatedDeltaNet
    (
        at::Tensor _mixed_qkv,
        at::Tensor _z,
        at::Tensor _beta,
        at::Tensor _g,
        at::Tensor _qkvz,
        at::Tensor _ba,
        at::Tensor _conv_temp_a,
        at::Tensor _conv_temp_b,
        at::Tensor _core_attn_out,
        at::Tensor _core_attn_out_f,
        std::shared_ptr<BC_LinearEXL3> _qkvz_proj,
        std::shared_ptr<BC_LinearFP16> _ba_proj,
        at::Tensor _dt_bias,
        at::Tensor _a_log,
        int _num_k_heads,
        int _num_v_heads,
        int _k_head_dim,
        int _v_head_dim,
        at::Tensor _conv1d_weight,
        c10::optional<at::Tensor> _conv1d_bias,
        std::shared_ptr<BC_GatedRMSNorm> _norm,
        std::shared_ptr<BC_LinearEXL3> _o_proj
    ) :
        mixed_qkv       (std::move(_mixed_qkv)),
        z               (std::move(_z)),
        beta            (std::move(_beta)),
        g               (std::move(_g)),
        qkvz            (std::move(_qkvz)),
        ba              (std::move(_ba)),
        conv_temp_a     (std::move(_conv_temp_a)),
        conv_temp_b     (std::move(_conv_temp_b)),
        core_attn_out   (std::move(_core_attn_out)),
        core_attn_out_f (std::move(_core_attn_out_f)),
        qkvz_proj       (_qkvz_proj),
        ba_proj         (_ba_proj),
        dt_bias         (std::move(_dt_bias)),
        a_log           (std::move(_a_log)),
        num_k_heads     (_num_k_heads),
        num_v_heads     (_num_v_heads),
        k_head_dim      (_k_head_dim),
        v_head_dim      (_v_head_dim),
        conv1d_weight   (std::move(_conv1d_weight)),
        conv1d_bias     (std::move(_conv1d_bias)),
        norm            (_norm),
        o_proj          (_o_proj)
    {}

    at::Tensor run_bsz1_a
    (
        const at::Tensor& x
    );

    void run_bsz1_b
    (
        at::Tensor& mixed_qkv,
        at::Tensor& y,
        at::Tensor& recurrent_state
    );
};
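
The constructor takes every intermediate tensor by value, so the host side can allocate the full set of scratch buffers once and let the per-token calls run allocation-free. A minimal sketch of that pattern, with placeholder shapes that are assumptions rather than values from this commit:

// Placeholder shapes only; the real sizes come from the model config on the
// Python side. The point is that the buffers are created once, then handed to
// BC_GatedDeltaNet so run_bsz1_a / run_bsz1_b never allocate.
void preallocate_sketch()
{
    auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);

    at::Tensor qkvz      = at::empty({1, 1, 12288}, opts);   // assumed fused QKVZ width
    at::Tensor ba        = at::empty({1, 1, 96},    opts);   // assumed fused b/a width
    at::Tensor mixed_qkv = at::empty({1, 1, 8192},  opts);   // assumed Q+K+V width

    // ... the remaining buffers (z, beta, g, conv_temp_*, core_attn_out, ...) are
    // built the same way and passed to the constructor above together with the
    // projections, the norm module and the head-geometry integers.
}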
exllamav3/exllamav3_ext/libtorch/gated_delta_net_bc.h (new file; name inferred)

Lines changed: 51 additions & 0 deletions
py::class_<BC_GatedDeltaNet, std::shared_ptr<BC_GatedDeltaNet>>(m, "BC_GatedDeltaNet").def
(
    py::init<
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        at::Tensor,
        std::shared_ptr<BC_LinearEXL3>,
        std::shared_ptr<BC_LinearFP16>,
        at::Tensor,
        at::Tensor,
        int,
        int,
        int,
        int,
        at::Tensor,
        c10::optional<at::Tensor>,
        std::shared_ptr<BC_GatedRMSNorm>,
        std::shared_ptr<BC_LinearEXL3>
    >(),
    py::arg("mixed_qkv"),
    py::arg("z"),
    py::arg("beta"),
    py::arg("g"),
    py::arg("qkvz"),
    py::arg("ba"),
    py::arg("conv_temp_a"),
    py::arg("conv_temp_b"),
    py::arg("core_attn_out"),
    py::arg("core_attn_out_f"),
    py::arg("qkvz_proj"),
    py::arg("ba_proj"),
    py::arg("dt_bias"),
    py::arg("a_log"),
    py::arg("num_k_heads"),
    py::arg("num_v_heads"),
    py::arg("k_head_dim"),
    py::arg("v_head_dim"),
    py::arg("conv1d_weight"),
    py::arg("conv1d_bias"),
    py::arg("norm"),
    py::arg("o_proj")
)
.def("run_bsz1_a", &BC_GatedDeltaNet::run_bsz1_a)
.def("run_bsz1_b", &BC_GatedDeltaNet::run_bsz1_b);
exllamav3/exllamav3_ext/libtorch/gated_rmsnorm.cpp (new file; name inferred)

Lines changed: 12 additions & 0 deletions
#include <Python.h>
#include "gated_rmsnorm.h"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../util.h"
#include "../norm.cuh"

void BC_GatedRMSNorm::run(const at::Tensor& x, at::Tensor& y, const at::Tensor& gate)
{
    gated_rms_norm(x, weight, y, gate, rms_norm_eps, constant_bias);
}
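
For reference, a non-fused formulation of what the gated_rms_norm kernel is assumed to compute: RMS-normalize x over the last dimension, gate with SiLU(gate), and scale by weight plus constant_bias. The gate activation and the role of constant_bias are assumptions based on common gated-RMSNorm variants, not read from the kernel source:

// Hedged reference implementation; not the fused kernel, and the exact ordering /
// constant_bias handling are assumptions.
at::Tensor gated_rms_norm_reference
(
    const at::Tensor& x,
    const at::Tensor& weight,
    const at::Tensor& gate,
    double eps,
    double constant_bias
)
{
    at::Tensor xf = x.to(at::kFloat);
    at::Tensor var = xf.pow(2).mean({-1}, /*keepdim=*/ true);
    at::Tensor normed = xf * at::rsqrt(var + eps);
    normed = normed * at::silu(gate.to(at::kFloat));                 // assumed SiLU gating
    return (normed * (weight.to(at::kFloat) + constant_bias)).to(x.scalar_type());
}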
exllamav3/exllamav3_ext/libtorch/gated_rmsnorm.h (new file; name inferred)

Lines changed: 26 additions & 0 deletions
#pragma once

#include <ATen/Tensor.h>
#include <vector>
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct BC_GatedRMSNorm
{
    at::Tensor weight;
    float rms_norm_eps;
    float constant_bias;

    BC_GatedRMSNorm
    (
        at::Tensor _weight,
        float _rms_norm_eps,
        float _constant_bias
    ) :
        weight(std::move(_weight)),
        rms_norm_eps(_rms_norm_eps),
        constant_bias(_constant_bias)
    {}

    void run(const at::Tensor& x, at::Tensor& y, const at::Tensor& gate);
};
exllamav3/exllamav3_ext/libtorch/gated_rmsnorm_bc.h (new file; name inferred)

Lines changed: 12 additions & 0 deletions
py::class_<BC_GatedRMSNorm, std::shared_ptr<BC_GatedRMSNorm>>(m, "BC_GatedRMSNorm").def
(
    py::init<
        at::Tensor,
        float,
        float
    >(),
    py::arg("weight"),
    py::arg("rms_norm_eps"),
    py::arg("constant_bias")
)
.def("run", &BC_GatedRMSNorm::run);
exllamav3/exllamav3_ext/libtorch/linear.cpp (new file; name inferred)

Lines changed: 36 additions & 0 deletions
#include <Python.h>
#include "linear.h"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../util.h"
#include "../hgemm.cuh"
#include "../quant/exl3_gemm.cuh"

void BC_LinearFP16::run(const at::Tensor& x, at::Tensor& y)
{
    // y is a caller-provided output buffer: y = x @ weight (+ bias)
    if (x.dtype() == y.dtype())
        at::matmul_out(y, x, weight);
    else
        hgemm(x, weight, y);

    if (bias)
        y.add_(bias.value());
}


void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y)
{
    // Single row (bsz 1): reuse the preallocated xh scratch; otherwise allocate a temporary
    if (x.numel() == x.size(-1))
    {
        exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult);
    }
    else
    {
        at::Tensor xh_ = at::empty_like(x);
        exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult);
    }

    if (bias) y.add_(bias.value());
}
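
Two details worth spelling out: the x.numel() == x.size(-1) test is true exactly when every leading dimension of x is 1, i.e. a single row, which is what the preallocated xh scratch tensor is sized for; and both run methods write into a caller-provided output. A hedged reference for the FP16 path (the x @ weight layout is taken from the matmul call above, the rest is an allocating restatement, not code from the commit):

// Hedged reference for BC_LinearFP16::run, allocating variant.
at::Tensor linear_fp16_reference
(
    const at::Tensor& x,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias
)
{
    at::Tensor y = at::matmul(x, weight);      // y = x @ weight
    if (bias) y = y + bias.value();
    return y;
}

// The single-row test used to decide whether the preallocated xh scratch can be reused:
inline bool is_single_row(const at::Tensor& x)
{
    return x.numel() == x.size(-1);            // true iff all leading dims are 1
}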
exllamav3/exllamav3_ext/libtorch/linear.h (new file; name inferred)

Lines changed: 58 additions & 0 deletions
#pragma once

#include <ATen/Tensor.h>
#include <vector>
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct BC_LinearFP16
{
    at::Tensor weight;
    c10::optional<at::Tensor> bias;

    BC_LinearFP16
    (
        at::Tensor _weight,
        c10::optional<at::Tensor> _bias
    ) :
        weight(std::move(_weight)),
        bias(std::move(_bias))
    {}

    void run(const at::Tensor& x, at::Tensor& y);
};

struct BC_LinearEXL3
{
    at::Tensor trellis;
    at::Tensor suh;
    at::Tensor svh;
    int K;
    c10::optional<at::Tensor> bias;
    int mcg_mult;
    int mul1_mult;
    at::Tensor xh;

    BC_LinearEXL3
    (
        at::Tensor _trellis,
        at::Tensor _suh,
        at::Tensor _svh,
        int _K,
        c10::optional<at::Tensor> _bias,
        int _mcg_mult,
        int _mul1_mult,
        at::Tensor _xh
    ) :
        trellis(std::move(_trellis)),
        suh(std::move(_suh)),
        svh(std::move(_svh)),
        K(_K),
        bias(std::move(_bias)),
        mcg_mult(_mcg_mult),
        mul1_mult(_mul1_mult),
        xh(std::move(_xh))
    {}

    void run(const at::Tensor& x, at::Tensor& y);
};
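
A hedged usage sketch of the caller-provided-output convention both structs share; the shape of y assumes a 3-D (bsz, seq, hidden) input and a weight laid out for x @ weight, consistent with the matmul in linear.cpp:

// Hedged usage sketch, not from the commit: y is preallocated by the caller with the
// output feature dimension taken from the weight's last axis.
void linear_usage_sketch(BC_LinearFP16& proj, const at::Tensor& x)
{
    at::Tensor y = at::empty({x.size(0), x.size(1), proj.weight.size(-1)}, x.options());
    proj.run(x, y);   // writes x @ weight (+ bias) into y
}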
