
Commit 6cea948

More C++ modules
1 parent 13dd99c commit 6cea948

16 files changed: +596 additions, -68 deletions


exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 4 additions & 1 deletion
@@ -40,10 +40,11 @@
 #include "parallel/gather.cuh"
 #include "parallel/all_reduce.cuh"
 
-#include "libtorch/blocksparse_mlp.h"
 #include "libtorch/gated_delta_net.h"
 #include "libtorch/linear.h"
 #include "libtorch/gated_rmsnorm.h"
+#include "libtorch/mlp.h"
+#include "libtorch/blocksparse_mlp.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
@@ -124,4 +125,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 #include "libtorch/linear_bc.h"
 #include "libtorch/gated_delta_net_bc.h"
 #include "libtorch/gated_rmsnorm_bc.h"
+#include "libtorch/mlp_bc.h"
+#include "libtorch/blocksparse_mlp_bc.h"
 }
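A note on the mechanism, since the second hunk looks unusual: the *_bc.h files are binding fragments rather than ordinary headers; they are #include'd inside the PYBIND11_MODULE body so the py::class_ statements they contain can refer to the module handle m directly. A self-contained sketch of what such a fragment expands to (BC_Example is a made-up stand-in, not a class from this commit):

#include <memory>
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct BC_Example                        // hypothetical class, for illustration only
{
    int k;
    explicit BC_Example(int _k) : k(_k) {}
    int run(int x) const { return x * k; }
};

PYBIND11_MODULE(example_ext, m)
{
    // In the real extension this statement lives in a *_bc.h fragment that is
    // #include'd at this point, so it is pasted in with `m` already in scope.
    py::class_<BC_Example, std::shared_ptr<BC_Example>>(m, "BC_Example")
        .def(py::init<int>(), py::arg("k"))
        .def("run", &BC_Example::run, py::arg("x"));
}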

exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp

Lines changed: 85 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include <torch/extension.h>
 #include "../util.h"
 #include "../hgemm.cuh"
+#include "../quant/exl3_gemm.cuh"
+#include "../activation.cuh"
 
 std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
     int bsz,
@@ -56,4 +58,87 @@
 
     return {selected_experts, routing_weights};
 }
+}
+
+void BC_BlockSparseMLP::run_bsz1
+(
+    const at::Tensor& y,
+    at::Tensor& selected_experts,
+    at::Tensor& routing_weights
+)
+{
+    py::gil_scoped_release _;
+    const at::Tensor& yi = y.unsqueeze(0);
+
+    exl3_mgemm
+    (
+        yi,
+        gate_ptrs_trellis,
+        interm_g,
+        gate_ptrs_suh,
+        yh,
+        gate_ptrs_svh,
+        selected_experts,
+        {},
+        gate_K,
+        -1,
+        gate_mcg_mult,
+        gate_mul1_mult,
+        min_expert,
+        max_expert
+    );
+
+    exl3_mgemm(
+        yi,
+        up_ptrs_trellis,
+        interm_u,
+        up_ptrs_suh,
+        yh,
+        up_ptrs_svh,
+        selected_experts,
+        {},
+        up_K,
+        -1,
+        up_mcg_mult,
+        up_mul1_mult,
+        min_expert,
+        max_expert
+    );
+
+    if (act_silu)
+        silu_mul(interm_g, interm_u, interm_a);
+    else if (act_gelu)
+        gelu_mul(interm_g, interm_u, interm_a);
+
+    exl3_mgemm(
+        interm_a,
+        down_ptrs_trellis,
+        out_d,
+        down_ptrs_suh,
+        interm_a,
+        down_ptrs_svh,
+        selected_experts,
+        routing_weights,
+        down_K,
+        -1,
+        down_mcg_mult,
+        down_mul1_mult,
+        min_expert,
+        max_expert
+    );
+
+    if (shared_experts)
+    {
+        shared_experts->run_bsz1(yi, out_d_sh.value());
+        if (shared_gate)
+        {
+            shared_gate->run_cublas(yi, z.value());
+            add_sigmoid_gate(out_d_sh.value(), z.value(), out_d);
+        }
+        else
+        {
+            out_d.add_(out_d_sh.value());
+        }
+    }
+
 }
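To make the new method easier to scan, here is its dataflow in comment form; this only restates the code above, it does not add behavior:

// BC_BlockSparseMLP::run_bsz1, batch-size-1 MoE forward (summary of the code above):
//   yi       = y.unsqueeze(0)
//   interm_g = exl3_mgemm(yi, gate expert weights,  experts = selected_experts)
//   interm_u = exl3_mgemm(yi, up expert weights,    experts = selected_experts)
//   interm_a = silu_mul(interm_g, interm_u)   or   gelu_mul(...), per the act_* flags
//   out_d    = exl3_mgemm(interm_a, down expert weights,
//                         experts = selected_experts, weights = routing_weights)
//   if shared_experts:
//       shared_experts->run_bsz1(yi, out_d_sh)
//       with shared_gate:    shared_gate->run_cublas(yi, z); add_sigmoid_gate(out_d_sh, z, out_d)
//       without shared_gate: out_d += out_d_sh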

exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h

Lines changed: 113 additions & 2 deletions
@@ -3,12 +3,123 @@
 #include <ATen/Tensor.h>
 #include <vector>
 #include <pybind11/pybind11.h>
-
 namespace py = pybind11;
 
+#include "mlp.h"
+#include "linear.h"
+
 std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
     int bsz,
     const py::object& cfg,
     const at::Tensor& y,
     const py::dict& params
-);
+);
+
+struct BC_BlockSparseMLP
+{
+    at::Tensor yh;
+    at::Tensor interm_g;
+    at::Tensor interm_u;
+    at::Tensor interm_a;
+    at::Tensor out_d;
+    c10::optional<at::Tensor> out_d_sh;
+    c10::optional<at::Tensor> z;
+    int min_expert;
+    int max_expert;
+    at::Tensor gate_ptrs_trellis;
+    at::Tensor gate_ptrs_suh;
+    at::Tensor gate_ptrs_svh;
+    int gate_K;
+    uint32_t gate_mcg_mult;
+    uint32_t gate_mul1_mult;
+    at::Tensor up_ptrs_trellis;
+    at::Tensor up_ptrs_suh;
+    at::Tensor up_ptrs_svh;
+    int up_K;
+    uint32_t up_mcg_mult;
+    uint32_t up_mul1_mult;
+    at::Tensor down_ptrs_trellis;
+    at::Tensor down_ptrs_suh;
+    at::Tensor down_ptrs_svh;
+    int down_K;
+    uint32_t down_mcg_mult;
+    uint32_t down_mul1_mult;
+    bool act_silu;
+    bool act_gelu;
+    std::shared_ptr<BC_GatedMLP> shared_experts;
+    std::shared_ptr<BC_LinearFP16> shared_gate;
+
+    BC_BlockSparseMLP
+    (
+        at::Tensor _yh,
+        at::Tensor _interm_g,
+        at::Tensor _interm_u,
+        at::Tensor _interm_a,
+        at::Tensor _out_d,
+        c10::optional<at::Tensor> _out_d_sh,
+        c10::optional<at::Tensor> _z,
+        int _min_expert,
+        int _max_expert,
+        at::Tensor _gate_ptrs_trellis,
+        at::Tensor _gate_ptrs_suh,
+        at::Tensor _gate_ptrs_svh,
+        int _gate_K,
+        uint32_t _gate_mcg_mult,
+        uint32_t _gate_mul1_mult,
+        at::Tensor _up_ptrs_trellis,
+        at::Tensor _up_ptrs_suh,
+        at::Tensor _up_ptrs_svh,
+        int _up_K,
+        uint32_t _up_mcg_mult,
+        uint32_t _up_mul1_mult,
+        at::Tensor _down_ptrs_trellis,
+        at::Tensor _down_ptrs_suh,
+        at::Tensor _down_ptrs_svh,
+        int _down_K,
+        uint32_t _down_mcg_mult,
+        uint32_t _down_mul1_mult,
+        bool _act_silu,
+        bool _act_gelu,
+        std::shared_ptr<BC_GatedMLP> _shared_experts,
+        std::shared_ptr<BC_LinearFP16> _shared_gate
+    ) :
+        yh                  (std::move(_yh)),
+        interm_g            (std::move(_interm_g)),
+        interm_u            (std::move(_interm_u)),
+        interm_a            (std::move(_interm_a)),
+        out_d               (std::move(_out_d)),
+        out_d_sh            (std::move(_out_d_sh)),
+        z                   (std::move(_z)),
+        min_expert          (_min_expert),
+        max_expert          (_max_expert),
+        gate_ptrs_trellis   (std::move(_gate_ptrs_trellis)),
+        gate_ptrs_suh       (std::move(_gate_ptrs_suh)),
+        gate_ptrs_svh       (std::move(_gate_ptrs_svh)),
+        gate_K              (_gate_K),
+        gate_mcg_mult       (_gate_mcg_mult),
+        gate_mul1_mult      (_gate_mul1_mult),
+        up_ptrs_trellis     (std::move(_up_ptrs_trellis)),
+        up_ptrs_suh         (std::move(_up_ptrs_suh)),
+        up_ptrs_svh         (std::move(_up_ptrs_svh)),
+        up_K                (_up_K),
+        up_mcg_mult         (_up_mcg_mult),
+        up_mul1_mult        (_up_mul1_mult),
+        down_ptrs_trellis   (std::move(_down_ptrs_trellis)),
+        down_ptrs_suh       (std::move(_down_ptrs_suh)),
+        down_ptrs_svh       (std::move(_down_ptrs_svh)),
+        down_K              (_down_K),
+        down_mcg_mult       (_down_mcg_mult),
+        down_mul1_mult      (_down_mul1_mult),
+        act_silu            (_act_silu),
+        act_gelu            (_act_gelu),
+        shared_experts      (_shared_experts),
+        shared_gate         (_shared_gate)
+    {}
+
+    void run_bsz1
+    (
+        const at::Tensor& y,
+        at::Tensor& selected_experts,
+        at::Tensor& routing_weights
+    );
+};
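Reading the new struct at a glance: it holds pre-allocated work buffers plus one parameter group per projection, and everything in those groups is passed straight through to exl3_mgemm in run_bsz1. A condensed view (the grouping is inferred from the member names, not spelled out in the commit):

// BC_BlockSparseMLP, condensed (full declaration above):
//   work buffers   : yh, interm_g, interm_u, interm_a, out_d,
//                    plus optional out_d_sh / z for the shared-expert path
//   expert range   : min_expert .. max_expert
//   per projection : gate_* / up_* / down_* groups, each with
//                    *_ptrs_trellis, *_ptrs_suh, *_ptrs_svh   (per-expert tensors)
//                    *_K, *_mcg_mult, *_mul1_mult             (forwarded to exl3_mgemm)
//   activation     : act_silu or act_gelu
//   shared path    : shared_experts (BC_GatedMLP), shared_gate (BC_LinearFP16)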
exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSparseMLP").def
+(
+    py::init<
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        c10::optional<at::Tensor>,
+        c10::optional<at::Tensor>,
+        int,
+        int,
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        int,
+        int,
+        int,
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        int,
+        int,
+        int,
+        at::Tensor,
+        at::Tensor,
+        at::Tensor,
+        int,
+        int,
+        int,
+        bool,
+        bool,
+        std::shared_ptr<BC_GatedMLP>,
+        std::shared_ptr<BC_LinearFP16>
+    >(),
+    py::arg("yh"),
+    py::arg("interm_g"),
+    py::arg("interm_u"),
+    py::arg("interm_a"),
+    py::arg("out_d"),
+    py::arg("out_d_sh"),
+    py::arg("z"),
+    py::arg("min_expert"),
+    py::arg("max_expert"),
+    py::arg("gate_ptrs_trellis"),
+    py::arg("gate_ptrs_suh"),
+    py::arg("gate_ptrs_svh"),
+    py::arg("gate_K"),
+    py::arg("gate_mcg_mult"),
+    py::arg("gate_mul1_mult"),
+    py::arg("up_ptrs_trellis"),
+    py::arg("up_ptrs_suh"),
+    py::arg("up_ptrs_svh"),
+    py::arg("up_K"),
+    py::arg("up_mcg_mult"),
+    py::arg("up_mul1_mult"),
+    py::arg("down_ptrs_trellis"),
+    py::arg("down_ptrs_suh"),
+    py::arg("down_ptrs_svh"),
+    py::arg("down_K"),
+    py::arg("down_mcg_mult"),
+    py::arg("down_mul1_mult"),
+    py::arg("act_silu"),
+    py::arg("act_gelu"),
+    py::arg("shared_experts"),
+    py::arg("shared_gate")
+)
+.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);
exllamav3/exllamav3_ext/libtorch/gated_rmsnorm_bc.h

Lines changed: 11 additions & 11 deletions
@@ -1,12 +1,12 @@
 py::class_<BC_GatedRMSNorm, std::shared_ptr<BC_GatedRMSNorm>>(m, "BC_GatedRMSNorm").def
-(
-    py::init<
-        at::Tensor,
-        float,
-        float
-    >(),
-    py::arg("weight"),
-    py::arg("rms_norm_eps"),
-    py::arg("constant_bias")
-)
-.def("run", &BC_GatedRMSNorm::run);
+(
+    py::init<
+        at::Tensor,
+        float,
+        float
+    >(),
+    py::arg("weight"),
+    py::arg("rms_norm_eps"),
+    py::arg("constant_bias")
+)
+.def("run", &BC_GatedRMSNorm::run);

exllamav3/exllamav3_ext/libtorch/linear.cpp

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,14 @@ void BC_LinearFP16::run(const at::Tensor& x, at::Tensor& y)
 }
 
 
+void BC_LinearFP16::run_cublas(const at::Tensor& x, at::Tensor& y)
+{
+    hgemm(x, weight, y);
+    if (bias)
+        y.add_(bias.value());
+}
+
+
 void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y)
 {
     if (x.numel() == x.size(-1))
exllamav3/exllamav3_ext/libtorch/linear.h

Lines changed: 5 additions & 4 deletions
@@ -20,6 +20,7 @@ struct BC_LinearFP16
     {}
 
     void run(const at::Tensor& x, at::Tensor& y);
+    void run_cublas(const at::Tensor& x, at::Tensor& y);
 };
 
 struct BC_LinearEXL3
@@ -29,8 +30,8 @@ struct BC_LinearEXL3
     at::Tensor svh;
     int K;
     c10::optional<at::Tensor> bias;
-    int mcg_mult;
-    int mul1_mult;
+    uint32_t mcg_mult;
+    uint32_t mul1_mult;
     at::Tensor xh;
 
     BC_LinearEXL3
@@ -40,8 +41,8 @@ struct BC_LinearEXL3
         at::Tensor _svh,
         int _K,
         c10::optional<at::Tensor> _bias,
-        int _mcg_mult,
-        int _mul1_mult,
+        uint32_t _mcg_mult,
+        uint32_t _mul1_mult,
         at::Tensor _xh
     ) :
        trellis(std::move(_trellis)),