@@ -32,16 +32,18 @@ int main(int argc, char const *argv[]) {
3232 const int kH = testu::params::H;
3333 const int kTu = testu::params::Tu;
3434 const int kNTu = testu::params::MaxNumTu;
35- const int kZTu = testu::params::ZTu;
35+ const int kZTu = 8 ; // testu::params::ZTu;
3636 const int kNTv = testu::params::MaxNumTv;
3737 const int kZTv = testu::params::ZTv;
3838
3939 const int kNumActiveInputs = 1 ; // testu::params::N;
40- const int kInputSize_tmp = testu::params::I / 1 ;
40+ const int kInputSize_tmp = testu::params::I / 16 ;
4141 const int kInputSize = (kInputSize_tmp > testu::params::I) ? testu::params::I : kInputSize_tmp ;
4242 const int kNumTilesU = kInputSize / testu::params::Tu;
4343
4444 typedef typename testu::params::ActivationD ActivationType;
45+ typedef ap_uint<testu::params::NumGTuBitsAligned> IndexType;
46+
4547 typedef hls::vector<ActivationType, testu::params::N> VectN_Type;
4648 typedef hls::vector<ActivationType, testu::params::G> VectG_Type;
4749 typedef hls::vector<ActivationType, testu::params::Tu> VectTuAct_Type;
@@ -105,23 +107,21 @@ int main(int argc, char const *argv[]) {
105107 auto f_weight = f_gate->fix_data ();
106108 auto c_weight = c_gate->fix_data ();
107109 auto o_weight = o_gate->fix_data ();
110+ auto i_weight_pruned = i_gate->fix_pruned_data ();
111+ auto f_weight_pruned = f_gate->fix_pruned_data ();
112+ auto c_weight_pruned = c_gate->fix_pruned_data ();
113+ auto o_weight_pruned = o_gate->fix_pruned_data ();
108114 for (int i = 0 ; i < max_num_refinements; ++i) {
109115 for (int j = 0 ; j < kInputSize ; ++j) {
116+ // std::cout << i_weight[i * kInputSize + j] << " ";
110117 for (int ii = 0 ; ii < testu::params::N; ++ii) {
111118 xu[i][ii][0 ] += i_weight[i * kInputSize + j] * storage.get_fix_x (ii)[j];
112119 xu[i][ii][1 ] += f_weight[i * kInputSize + j] * storage.get_fix_x (ii)[j];
113120 xu[i][ii][2 ] += c_weight[i * kInputSize + j] * storage.get_fix_x (ii)[j];
114121 xu[i][ii][3 ] += o_weight[i * kInputSize + j] * storage.get_fix_x (ii)[j];
115122 }
116123 }
117- }
118- std::cout << " [INFO] Generating gold results." << std::endl;
119- for (int i = 0 ; i < max_num_refinements; ++i) {
120- for (int j = 0 ; j < testu::params::N; ++j) {
121- for (int k = 0 ; k < testu::params::G; ++k) {
122- // xu_gold[i * testu::params::G + k][j] = xu[i][j][k];
123- }
124- }
124+ // std::cout << std::endl;
125125 }
126126
127127#if 1
@@ -149,27 +149,45 @@ int main(int argc, char const *argv[]) {
149149 for (int j = 0 ; j < kNumTilesU - kZTu ; ++j) {
150150 VectTuAct_Type u_val;
151151 for (int k = 0 ; k < testu::params::Tu; ++k) {
152- u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx (i, j) * kTu + k];
152+ // u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx(i, j) * kTu + k];
153+ u_val[k] = i_weight_pruned[i * kInputSize + j * kTu + k];
153154 }
154155 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
155156 for (int k = 0 ; k < testu::params::Tu; ++k) {
156- u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx (i, j) * kTu + k];
157+ // u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx(i, j) * kTu + k];
158+ u_val[k] = f_weight_pruned[i * kInputSize + j * kTu + k];
157159 }
158160 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
159161 for (int k = 0 ; k < testu::params::Tu; ++k) {
160- u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx (i, j) * kTu + k];
162+ // u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx(i, j) * kTu + k];
163+ u_val[k] = c_weight_pruned[i * kInputSize + j * kTu + k];
161164 }
162165 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
163166 for (int k = 0 ; k < testu::params::Tu; ++k) {
164- u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx (i, j) * kTu + k];
167+ // u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx(i, j) * kTu + k];
168+ u_val[k] = o_weight_pruned[i * kInputSize + j * kTu + k];
165169 }
166170 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
167171 }
168172 }
173+
174+ std::cout << " [INFO] Sending nzu." << std::endl;
175+ for (int i = 0 ; i < num_refinements[kNumActiveInputs - 1 ]; ++i) {
176+ for (int j = 0 ; j < kNumTilesU - kZTu ; ++j) {
177+ const int bits = testu::params::NumTuBits;
178+ IndexType nzu_val;
179+ nzu_val.range (1 * bits - 1 , 0 * bits) = i_gate->get_nz_idx (i, j);
180+ nzu_val.range (2 * bits - 1 , 1 * bits) = f_gate->get_nz_idx (i, j);
181+ nzu_val.range (3 * bits - 1 , 2 * bits) = c_gate->get_nz_idx (i, j);
182+ nzu_val.range (4 * bits - 1 , 3 * bits) = o_gate->get_nz_idx (i, j);
183+ // std::cout << i_gate->get_nz_idx(i, j) << std::endl;
184+ unz_idx_interface.Push <IndexType>(nzu_val);
185+ }
186+ }
187+
169188 std::cout << " [INFO] Starting HlsKernelU." << std::endl;
170189 // HlsKernelU(kNumActiveInputs, kInputSize, refinements_tmp, false, x_axis, u_axis, xu_axis);
171- const int ztu = 0 ; // kZTu;
172- HlsKernelU_Pruned (kNumActiveInputs , kInputSize , num_refinements, ztu, unz_idx_axis, x_axis, u_axis, xu_axis);
190+ HlsKernelU_Pruned (kNumActiveInputs , kInputSize , num_refinements, kZTu , unz_idx_axis, x_axis, u_axis, xu_axis);
173191
174192 testu::params::VectG_Type xu_g_val;
175193 int total_cnt = 0 ;
0 commit comments