@@ -26,21 +26,22 @@ int main(int argc, char const *argv[]) {
2626 int R_tmp = testu::params::R - 2 * (testu::params::N - i - 1 );
2727 num_refinements[i] = R_tmp > 0 ? R_tmp : 1 ;
2828 }
29+
30+ const int kNumActiveInputs = 1 ; // testu::params::N;
31+ const int kInputSize_tmp = testu::params::I / 16 ;
32+ const int kInputSize = (kInputSize_tmp > testu::params::I) ? testu::params::I : kInputSize_tmp ;
33+ const int kNumTilesU = kInputSize / testu::params::Tu;
2934 const int kN = testu::params::N;
3035 const int kR = testu::params::R;
3136 const int kI = testu::params::I;
3237 const int kH = testu::params::H;
3338 const int kTu = testu::params::Tu;
3439 const int kNTu = testu::params::MaxNumTu;
35- const int kZTu = 8 ; // testu::params::ZTu;
40+ const int kZTu_tmp = 10 ;
41+ const int kZTu = kZTu_tmp >= kNumTilesU ? 0 : kZTu_tmp ; // testu::params::ZTu;
3642 const int kNTv = testu::params::MaxNumTv;
3743 const int kZTv = testu::params::ZTv;
3844
39- const int kNumActiveInputs = 1 ; // testu::params::N;
40- const int kInputSize_tmp = testu::params::I / 16 ;
41- const int kInputSize = (kInputSize_tmp > testu::params::I) ? testu::params::I : kInputSize_tmp ;
42- const int kNumTilesU = kInputSize / testu::params::Tu;
43-
4445 typedef typename testu::params::ActivationD ActivationType;
4546 typedef ap_uint<testu::params::NumGTuBitsAligned> IndexType;
4647
@@ -149,23 +150,19 @@ int main(int argc, char const *argv[]) {
149150 for (int j = 0 ; j < kNumTilesU - kZTu ; ++j) {
150151 VectTuAct_Type u_val;
151152 for (int k = 0 ; k < testu::params::Tu; ++k) {
152- // u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx(i, j) * kTu + k];
153- u_val[k] = i_weight_pruned[i * kInputSize + j * kTu + k];
153+ u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx (i, j) * kTu + k];
154154 }
155155 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
156156 for (int k = 0 ; k < testu::params::Tu; ++k) {
157- // u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx(i, j) * kTu + k];
158- u_val[k] = f_weight_pruned[i * kInputSize + j * kTu + k];
157+ u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx (i, j) * kTu + k];
159158 }
160159 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
161160 for (int k = 0 ; k < testu::params::Tu; ++k) {
162- // u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx(i, j) * kTu + k];
163- u_val[k] = c_weight_pruned[i * kInputSize + j * kTu + k];
161+ u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx (i, j) * kTu + k];
164162 }
165163 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
166164 for (int k = 0 ; k < testu::params::Tu; ++k) {
167- // u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx(i, j) * kTu + k];
168- u_val[k] = o_weight_pruned[i * kInputSize + j * kTu + k];
165+ u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx (i, j) * kTu + k];
169166 }
170167 u_interface.PushVector <ActivationType, testu::params::Tu>(u_val);
171168 }
0 commit comments