Skip to content

Commit e38dc51

Browse files
committed
Test Pruned Kernel-U: fixed sending pruned U weight. Now it works.
1 parent a7985be commit e38dc51

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/testbenches/test_u_kernel_pruned.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,22 @@ int main(int argc, char const *argv[]) {
2626
int R_tmp = testu::params::R - 2 * (testu::params::N - i - 1);
2727
num_refinements[i] = R_tmp > 0 ? R_tmp : 1;
2828
}
29+
30+
const int kNumActiveInputs = 1; // testu::params::N;
31+
const int kInputSize_tmp = testu::params::I / 16;
32+
const int kInputSize = (kInputSize_tmp > testu::params::I) ? testu::params::I : kInputSize_tmp;
33+
const int kNumTilesU = kInputSize / testu::params::Tu;
2934
const int kN = testu::params::N;
3035
const int kR = testu::params::R;
3136
const int kI = testu::params::I;
3237
const int kH = testu::params::H;
3338
const int kTu = testu::params::Tu;
3439
const int kNTu = testu::params::MaxNumTu;
35-
const int kZTu = 8; // testu::params::ZTu;
40+
const int kZTu_tmp = 10;
41+
const int kZTu = kZTu_tmp >= kNumTilesU ? 0 : kZTu_tmp; // testu::params::ZTu;
3642
const int kNTv = testu::params::MaxNumTv;
3743
const int kZTv = testu::params::ZTv;
3844

39-
const int kNumActiveInputs = 1; // testu::params::N;
40-
const int kInputSize_tmp = testu::params::I / 16;
41-
const int kInputSize = (kInputSize_tmp > testu::params::I) ? testu::params::I : kInputSize_tmp;
42-
const int kNumTilesU = kInputSize / testu::params::Tu;
43-
4445
typedef typename testu::params::ActivationD ActivationType;
4546
typedef ap_uint<testu::params::NumGTuBitsAligned> IndexType;
4647

@@ -149,23 +150,19 @@ int main(int argc, char const *argv[]) {
149150
for (int j = 0; j < kNumTilesU - kZTu; ++j) {
150151
VectTuAct_Type u_val;
151152
for (int k = 0; k < testu::params::Tu; ++k) {
152-
// u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx(i, j) * kTu + k];
153-
u_val[k] = i_weight_pruned[i * kInputSize + j * kTu + k];
153+
u_val[k] = i_weight[i * kInputSize + i_gate->get_nz_idx(i, j) * kTu + k];
154154
}
155155
u_interface.PushVector<ActivationType, testu::params::Tu>(u_val);
156156
for (int k = 0; k < testu::params::Tu; ++k) {
157-
// u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx(i, j) * kTu + k];
158-
u_val[k] = f_weight_pruned[i * kInputSize + j * kTu + k];
157+
u_val[k] = f_weight[i * kInputSize + f_gate->get_nz_idx(i, j) * kTu + k];
159158
}
160159
u_interface.PushVector<ActivationType, testu::params::Tu>(u_val);
161160
for (int k = 0; k < testu::params::Tu; ++k) {
162-
// u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx(i, j) * kTu + k];
163-
u_val[k] = c_weight_pruned[i * kInputSize + j * kTu + k];
161+
u_val[k] = c_weight[i * kInputSize + c_gate->get_nz_idx(i, j) * kTu + k];
164162
}
165163
u_interface.PushVector<ActivationType, testu::params::Tu>(u_val);
166164
for (int k = 0; k < testu::params::Tu; ++k) {
167-
// u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx(i, j) * kTu + k];
168-
u_val[k] = o_weight_pruned[i * kInputSize + j * kTu + k];
165+
u_val[k] = o_weight[i * kInputSize + o_gate->get_nz_idx(i, j) * kTu + k];
169166
}
170167
u_interface.PushVector<ActivationType, testu::params::Tu>(u_val);
171168
}

0 commit comments

Comments
 (0)