#pragma once

#include <ATen/Tensor.h>
#include <pybind11/pybind11.h>

#include <memory>
#include <utility>
#include <vector>

namespace py = pybind11;

#include "linear.h"
#include "gated_rmsnorm.h"
9+
10+ struct BC_GatedDeltaNet
11+ {
12+ at::Tensor mixed_qkv;
13+ at::Tensor z;
14+ at::Tensor beta;
15+ at::Tensor g;
16+ at::Tensor qkvz;
17+ at::Tensor ba;
18+ at::Tensor conv_temp_a;
19+ at::Tensor conv_temp_b;
20+ at::Tensor core_attn_out;
21+ at::Tensor core_attn_out_f;
22+ std::shared_ptr<BC_LinearEXL3> qkvz_proj;
23+ std::shared_ptr<BC_LinearFP16> ba_proj;
24+ at::Tensor dt_bias;
25+ at::Tensor a_log;
26+ int num_k_heads;
27+ int num_v_heads;
28+ int k_head_dim;
29+ int v_head_dim;
30+ at::Tensor conv1d_weight;
31+ c10::optional<at::Tensor> conv1d_bias;
32+ std::shared_ptr<BC_GatedRMSNorm> norm;
33+ std::shared_ptr<BC_LinearEXL3> o_proj;
34+
35+ BC_GatedDeltaNet
36+ (
37+ at::Tensor _mixed_qkv,
38+ at::Tensor _z,
39+ at::Tensor _beta,
40+ at::Tensor _g,
41+ at::Tensor _qkvz,
42+ at::Tensor _ba,
43+ at::Tensor _conv_temp_a,
44+ at::Tensor _conv_temp_b,
45+ at::Tensor _core_attn_out,
46+ at::Tensor _core_attn_out_f,
47+ std::shared_ptr<BC_LinearEXL3> _qkvz_proj,
48+ std::shared_ptr<BC_LinearFP16> _ba_proj,
49+ at::Tensor _dt_bias,
50+ at::Tensor _a_log,
51+ int _num_k_heads,
52+ int _num_v_heads,
53+ int _k_head_dim,
54+ int _v_head_dim,
55+ at::Tensor _conv1d_weight,
56+ c10::optional<at::Tensor> _conv1d_bias,
57+ std::shared_ptr<BC_GatedRMSNorm> _norm,
58+ std::shared_ptr<BC_LinearEXL3> _o_proj
59+ ) :
60+ mixed_qkv (std::move(_mixed_qkv)),
61+ z (std::move(_z)),
62+ beta (std::move(_beta)),
63+ g (std::move(_g)),
64+ qkvz (std::move(_qkvz)),
65+ ba (std::move(_ba)),
66+ conv_temp_a (std::move(_conv_temp_a)),
67+ conv_temp_b (std::move(_conv_temp_b)),
68+ core_attn_out (std::move(_core_attn_out)),
69+ core_attn_out_f (std::move(_core_attn_out_f)),
70+ qkvz_proj (_qkvz_proj),
71+ ba_proj (_ba_proj),
72+ dt_bias (std::move(_dt_bias)),
73+ a_log (std::move(_a_log)),
74+ num_k_heads (_num_k_heads),
75+ num_v_heads (_num_v_heads),
76+ k_head_dim (_k_head_dim),
77+ v_head_dim (_v_head_dim),
78+ conv1d_weight (std::move(_conv1d_weight)),
79+ conv1d_bias (std::move(_conv1d_bias)),
80+ norm (_norm),
81+ o_proj (_o_proj)
82+ {}
83+
84+ at::Tensor run_bsz1_a
85+ (
86+ const at::Tensor& x
87+ );
88+
89+ void run_bsz1_b
90+ (
91+ at::Tensor& mixed_qkv,
92+ at::Tensor& y,
93+ at::Tensor& recurrent_state
94+ );
95+ };