#pragma once

#include <ATen/Tensor.h>
#include <pybind11/pybind11.h>

#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace py = pybind11;

#include "mlp.h"
#include "linear.h"
911std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing (
1012 int bsz,
1113 const py::object& cfg,
1214 const at::Tensor& y,
1315 const py::dict& params
14- );
16+ );
17+
18+ struct BC_BlockSparseMLP
19+ {
20+ at::Tensor yh;
21+ at::Tensor interm_g;
22+ at::Tensor interm_u;
23+ at::Tensor interm_a;
24+ at::Tensor out_d;
25+ c10::optional<at::Tensor> out_d_sh;
26+ c10::optional<at::Tensor> z;
27+ int min_expert;
28+ int max_expert;
29+ at::Tensor gate_ptrs_trellis;
30+ at::Tensor gate_ptrs_suh;
31+ at::Tensor gate_ptrs_svh;
32+ int gate_K;
33+ uint32_t gate_mcg_mult;
34+ uint32_t gate_mul1_mult;
35+ at::Tensor up_ptrs_trellis;
36+ at::Tensor up_ptrs_suh;
37+ at::Tensor up_ptrs_svh;
38+ int up_K;
39+ uint32_t up_mcg_mult;
40+ uint32_t up_mul1_mult;
41+ at::Tensor down_ptrs_trellis;
42+ at::Tensor down_ptrs_suh;
43+ at::Tensor down_ptrs_svh;
44+ int down_K;
45+ uint32_t down_mcg_mult;
46+ uint32_t down_mul1_mult;
47+ bool act_silu;
48+ bool act_gelu;
49+ std::shared_ptr<BC_GatedMLP> shared_experts;
50+ std::shared_ptr<BC_LinearFP16> shared_gate;
51+
52+ BC_BlockSparseMLP
53+ (
54+ at::Tensor _yh,
55+ at::Tensor _interm_g,
56+ at::Tensor _interm_u,
57+ at::Tensor _interm_a,
58+ at::Tensor _out_d,
59+ c10::optional<at::Tensor> _out_d_sh,
60+ c10::optional<at::Tensor> _z,
61+ int _min_expert,
62+ int _max_expert,
63+ at::Tensor _gate_ptrs_trellis,
64+ at::Tensor _gate_ptrs_suh,
65+ at::Tensor _gate_ptrs_svh,
66+ int _gate_K,
67+ uint32_t _gate_mcg_mult,
68+ uint32_t _gate_mul1_mult,
69+ at::Tensor _up_ptrs_trellis,
70+ at::Tensor _up_ptrs_suh,
71+ at::Tensor _up_ptrs_svh,
72+ int _up_K,
73+ uint32_t _up_mcg_mult,
74+ uint32_t _up_mul1_mult,
75+ at::Tensor _down_ptrs_trellis,
76+ at::Tensor _down_ptrs_suh,
77+ at::Tensor _down_ptrs_svh,
78+ int _down_K,
79+ uint32_t _down_mcg_mult,
80+ uint32_t _down_mul1_mult,
81+ bool _act_silu,
82+ bool _act_gelu,
83+ std::shared_ptr<BC_GatedMLP> _shared_experts,
84+ std::shared_ptr<BC_LinearFP16> _shared_gate
85+ ) :
86+ yh (std::move(_yh)),
87+ interm_g (std::move(_interm_g)),
88+ interm_u (std::move(_interm_u)),
89+ interm_a (std::move(_interm_a)),
90+ out_d (std::move(_out_d)),
91+ out_d_sh (std::move(_out_d_sh)),
92+ z (std::move(_z)),
93+ min_expert (_min_expert),
94+ max_expert (_max_expert),
95+ gate_ptrs_trellis (std::move(_gate_ptrs_trellis)),
96+ gate_ptrs_suh (std::move(_gate_ptrs_suh)),
97+ gate_ptrs_svh (std::move(_gate_ptrs_svh)),
98+ gate_K (_gate_K),
99+ gate_mcg_mult (_gate_mcg_mult),
100+ gate_mul1_mult (_gate_mul1_mult),
101+ up_ptrs_trellis (std::move(_up_ptrs_trellis)),
102+ up_ptrs_suh (std::move(_up_ptrs_suh)),
103+ up_ptrs_svh (std::move(_up_ptrs_svh)),
104+ up_K (_up_K),
105+ up_mcg_mult (_up_mcg_mult),
106+ up_mul1_mult (_up_mul1_mult),
107+ down_ptrs_trellis (std::move(_down_ptrs_trellis)),
108+ down_ptrs_suh (std::move(_down_ptrs_suh)),
109+ down_ptrs_svh (std::move(_down_ptrs_svh)),
110+ down_K (_down_K),
111+ down_mcg_mult (_down_mcg_mult),
112+ down_mul1_mult (_down_mul1_mult),
113+ act_silu (_act_silu),
114+ act_gelu (_act_gelu),
115+ shared_experts (_shared_experts),
116+ shared_gate (_shared_gate)
117+ {}
118+
119+ void run_bsz1
120+ (
121+ const at::Tensor& y,
122+ at::Tensor& selected_experts,
123+ at::Tensor& routing_weights
124+ );
125+ };