Skip to content

Commit 32b21d4

Browse files
committed
KernelU_Pruned: initial work on pruned version of Kernel-U.
1 parent f652133 commit 32b21d4

File tree

8 files changed

+453
-18
lines changed

8 files changed

+453
-18
lines changed

include/kernel/u_kernel.h

Lines changed: 183 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ void KernelU(const int num_active_inputs,
199199
const int input_size,
200200
const int num_refinements[params::N],
201201
const bool pad_output,
202-
// hls::stream<typename params::IndexU_Type>& z_idx_port,
203202
hls::stream<typename params::VectTuAxiPacketType>& x_port,
204203
hls::stream<typename params::VectTuAxiPacketType>& u_port,
205204
hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
@@ -210,13 +209,14 @@ void KernelU(const int num_active_inputs,
210209
#pragma HLS STABLE variable=u_port
211210
#pragma HLS STABLE variable=xu_port
212211
typedef typename params::ActivationD ActivationType;
213-
const int kNumTilesU = input_size / params::Tu;
214-
const int kMaxNumTilesU = params::I / params::Tu;
215-
const int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
216-
const int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
217-
const int kStreamDepth_XU = 2 + params::G;
218-
assert(num_active_inputs <= params::N);
212+
const unsigned int kNumTilesU = input_size / params::Tu;
213+
const unsigned int kMaxNumTilesU = params::I / params::Tu;
214+
const unsigned int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
215+
const unsigned int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
216+
const unsigned int kStreamDepth_XU = 2 + params::G;
219217
assert(num_active_inputs > 0);
218+
assert(kNumTilesU > 0);
219+
assert(num_active_inputs <= params::N);
220220
assert(params::I % params::Tu == 0);
221221
assert(input_size % params::Tu == 0);
222222
assert(input_size <= params::I);
@@ -344,6 +344,180 @@ void KernelU(const int num_active_inputs,
344344
}
345345
}
346346
}
347+
348+
349+
template <
350+
typename params,
351+
typename WrapperAxisG = svd::AxiStreamPort<params::VectG_AxiWidth>
352+
>
353+
void KernelU_Pruned(const int num_active_inputs,
354+
const int input_size,
355+
const int num_refinements[params::N],
356+
const int num_zero_tiles_u,
357+
hls::stream<typename params::VectGZTuAxiPacketType>& uz_idx_port,
358+
hls::stream<typename params::VectTuAxiPacketType>& x_port,
359+
hls::stream<typename params::VectTuAxiPacketType>& u_port,
360+
hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
361+
#pragma HLS TOP name=KernelU
362+
#pragma HLS DATAFLOW
363+
#pragma HLS INLINE
364+
#pragma HLS STABLE variable=x_port
365+
#pragma HLS STABLE variable=u_port
366+
#pragma HLS STABLE variable=xu_port
367+
typedef typename params::ActivationD ActivationType;
368+
const unsigned int kNumTilesU = input_size / params::Tu;
369+
const unsigned int kMaxNumTilesU = params::I / params::Tu;
370+
const unsigned int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
371+
const unsigned int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
372+
const unsigned int kStreamDepth_XU = 2 + params::G;
373+
assert(num_active_inputs > 0);
374+
assert(kNumTilesU > 0);
375+
assert(num_active_inputs <= params::N);
376+
assert(params::I % params::Tu == 0);
377+
assert(input_size % params::Tu == 0);
378+
assert(input_size <= params::I);
379+
assert(kNumTilesU <= kMaxNumTilesU);
380+
auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(uz_idx_port);
381+
auto x_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(x_port);
382+
auto u_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(u_port);
383+
auto xu_axis = svd::AxiStreamInterface<WrapperAxisG>(xu_port);
384+
hls::stream<typename params::VectTuType> x_stream[params::G];
385+
hls::stream<typename params::VectTuType> u_streams[params::G];
386+
hls::stream<ActivationType> xu_streams[params::G];
387+
ActivationType x_buffer[params::N][params::Tu][kMaxNumTilesU];
388+
#pragma HLS STREAM variable=x_stream depth=kStreamDepth_X
389+
#pragma HLS STREAM variable=u_streams depth=kStreamDepth_U
390+
#pragma HLS STREAM variable=xu_streams depth=kStreamDepth_XU
391+
#pragma HLS ARRAY_PARTITION variable=u_streams complete dim=1
392+
#pragma HLS ARRAY_PARTITION variable=x_buffer complete dim=1
393+
#pragma HLS ARRAY_PARTITION variable=x_buffer complete dim=2
394+
#pragma HLS BIND_STORAGE variable=x_buffer type=ram_t2p impl=bram latency=1
395+
/*
396+
* Ideally, if the Rs are ordered, it would be: R0 * N + (R1-R0) * (N-1) +
397+
* (R2-R1) * (N-2)
398+
*
399+
* Imagine we have: R0 = 2, R1 = 3, R2 = 6
400+
*
401+
* This means:
402+
* - till refinement 2 we have input 0 to process
403+
* - till refinement 3 we have input 1 to process
404+
* - till refinement 6 we have input 2 to process
405+
*
406+
* So it would become:
407+
*
408+
* R_total = 2 * 3 + (3-2) * (3-1) + (6-3) * (3-2)
409+
*/
410+
// ===========================================================================
411+
// TODO: Same as non-pruned version -> wrap into a function
412+
// ===========================================================================
413+
int R_max = num_refinements[0];
414+
int R_total = num_refinements[0] * num_active_inputs; // Total elements.
415+
Get_Total_R:
416+
for (int i = 1; i < num_active_inputs; ++i) {
417+
#pragma HLS PIPELINE II=1 style=frp
418+
if (num_refinements[i] > R_max) {
419+
R_max = num_refinements[i];
420+
}
421+
assert(num_refinements[i] >= num_refinements[i - 1]);
422+
R_total += (num_refinements[i] - num_refinements[i - 1]) * (num_active_inputs - i);
423+
}
424+
425+
// Added
426+
X_DAM_in:
427+
for (int i = 0; i < num_active_inputs; ++i) {
428+
for (int j = 0; j < kNumTilesU; ++j) {
429+
#pragma HLS LOOP_FLATTEN
430+
#pragma HLS PIPELINE II=1 style=frp
431+
auto x_val = x_axis.template PopVector<ActivationType, params::Tu>();
432+
for (int k = 0; k < params::Tu; ++k) {
433+
x_buffer[i][k][j] = x_val[k];
434+
}
435+
}
436+
}
437+
438+
// Changed
439+
int R_prev = 0;
440+
X_DMA_dispatcher:
441+
for (int ii = 0; ii < num_active_inputs; ++ii) {
442+
for (int i = 0; i < num_refinements[ii] - R_prev; ++i) {
443+
assert(num_refinements[ii] - R_prev >= 1);
444+
for (int j = 0; j < kNumTilesU; ++j) {
445+
// Read z_idx
446+
auto z_idx = uz_axis.template PopVector<ActivationType, params::G>();
447+
for (int k = 0; k < num_active_inputs - ii; ++k) {
448+
#pragma HLS PIPELINE II=1 style=frp
449+
assert(num_active_inputs - ii >= 1);
450+
assert(k + ii < params::N);
451+
for (int kk = 0; kk < params::G; ++kk) {
452+
typename params::VectTuType x_val;
453+
for (int jj = 0; jj < params::Tu; ++jj) {
454+
x_val[jj] = x_buffer[k + ii][jj][z_idx[kk]];
455+
}
456+
x_stream[kk] << x_val;
457+
}
458+
}
459+
}
460+
}
461+
R_prev = num_refinements[ii];
462+
}
463+
464+
// ===========================================================================
465+
// TODO: Same as non-pruned version -> wrap into a function
466+
// ===========================================================================
467+
U_DMA:
468+
for (int i = 0; i < R_max; ++i) {
469+
#pragma HLS LOOP_TRIPCOUNT min=params::R max=params::R
470+
for (int j = 0; j < kNumTilesU; ++j) {
471+
for (int k = 0; k < params::G; ++k) {
472+
auto u_val = u_axis.template PopVector<ActivationType, params::Tu>();
473+
for (int ii = 0; ii < num_active_inputs; ++ii) {
474+
#pragma HLS PIPELINE II=1 style=frp
475+
if (i < num_refinements[ii]) {
476+
u_streams[k] << u_val;
477+
}
478+
}
479+
}
480+
}
481+
}
482+
483+
// Changed
484+
U_Kernel:
485+
for (int i = 0; i < R_total; ++i) {
486+
for (int j = 0; j < kNumTilesU; ++j) {
487+
#pragma HLS PIPELINE II=1 style=frp
488+
for (int k = 0; k < params::G; ++k) {
489+
xu_streams[k] << hlsutils::adder_tree<ActivationType, params::Tu>(
490+
x_stream[k].read() * u_streams[k].read());
491+
}
492+
}
493+
}
494+
495+
// ===========================================================================
496+
// TODO: Same as non-pruned version -> wrap into a function
497+
// ===========================================================================
498+
int iter_cnt = 0;
499+
XU_DMA:
500+
for (int i = 0; i < R_max; ++i) {
501+
typename params::VectG_Type xu_out[params::N] = {typename params::VectG_Type(0)};
502+
#pragma HLS ARRAY_PARTITION variable=xu_out complete dim=1
503+
for (int j = 0; j < kNumTilesU; ++j) {
504+
for (int k = 0; k < num_active_inputs; ++k) {
505+
#pragma HLS PIPELINE II=1 style=frp
506+
for (int ii = 0; ii < params::G; ++ii) {
507+
if (i < num_refinements[k]) {
508+
xu_out[k][ii] += xu_streams[ii].read();
509+
#pragma HLS BIND_OP variable=xu_out[k][ii] op=add impl=dsp
510+
}
511+
}
512+
if (i < num_refinements[k] && j == kNumTilesU - 1) {
513+
const bool kIsLast = iter_cnt == R_total - 1;
514+
xu_axis.template PushVector<ActivationType, params::G>(xu_out[k], kIsLast);
515+
++iter_cnt;
516+
}
517+
}
518+
}
519+
}
520+
}
347521
#endif // end __VITIS_HLS__
348522

349523
} // svd
@@ -404,6 +578,8 @@ void HlsKernelU(const int num_active_inputs,
404578
const int input_size,
405579
const int num_refinements[testu::params::N],
406580
const bool pad_output,
581+
// const int num_zero_tiles_u,
582+
// hls::stream<ap_uint<testu::NumTuBits> >& uz_idx_port,
407583
hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
408584
hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
409585
hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);

include/layers/lstm/hls/lstm_svd.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,4 +407,32 @@ void HlsWrapperLstmSvd(
407407
typename svd::lstm_params::ActivationD* h_curr,
408408
typename svd::lstm_params::ActivationD* c_curr);
409409

410+
// C-linkage (extern "C") declaration of the LSTM-SVD wrapper; the arguments
// are plain ints and raw float/int pointers only.
//
// The arguments are grouped as: problem sizes and per-input refinement counts,
// the "current" gate arrays, the "recurrent" gate arrays, and the
// non-linearity arrays. h_curr_in and c_curr_in are the only non-const
// pointers, so they are presumably the outputs despite the "_in" suffix --
// TODO confirm against the definition.
//
// NOTE(review): num_zero_tiles_u/v and the uz_idx_* / vz_idx_* arrays look like
// the pruning information (number of zero tiles and tile indexes); their exact
// layout and meaning are not visible from this declaration -- verify against
// the definition.
extern "C" void C_WrapperLstmSvd(
411+
const int num_timesteps,
412+
const int num_active_inputs,
413+
const int input_size,
414+
const int output_size,
415+
const int num_refinements[svd::lstm_params::N],
416+
const int num_zero_tiles_u,
417+
const int num_zero_tiles_v,
418+
// Current Gates
419+
const float* x_in,
420+
const float* u_cur_in,
421+
const float* s_cur_in,
422+
const float* v_cur_in,
423+
const int* uz_idx_cur_in,
424+
const int* vz_idx_cur_in,
425+
// Recurrent Gates
426+
const float* h_in,
427+
const float* u_rec_in,
428+
const float* s_rec_in,
429+
const float* v_rec_in,
430+
const int* uz_idx_rec_in,
431+
const int* vz_idx_rec_in,
432+
// Non-Linearities
433+
const float* bias_in,
434+
const float* c_prev_in,
435+
float* h_curr_in,
436+
float* c_curr_in);
437+
410438
#endif // end LSTM_HLS_LSTM_SVD_H_

include/svd_params.h

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ template <int Ni, int Ii, int Hi, int Ri, int Tui, int Tvi, int ZTui = 0,
3737
typename WeightD_tp = ap_fixed<8, 3>,
3838
typename AccumulationD_tp = ap_fixed<16, 3> >
3939
struct SvdParameters {
40+
static_assert(Ni > 0, "ERROR. Found negative value: N <= 0");
41+
static_assert(Ii > 0, "ERROR. Found negative value: I <= 0");
42+
static_assert(Hi > 0, "ERROR. Found negative value: H <= 0");
43+
static_assert(Tui > 0, "ERROR. Found negative value: Tu <= 0");
44+
static_assert(Tvi > 0, "ERROR. Found negative value: Tv <= 0");
4045
static const int N = Ni;
4146
static const int I = Ii;
4247
static const int H = Hi;
@@ -51,24 +56,28 @@ struct SvdParameters {
5156
static const int PeU = MaxNumTu - ZTu;
5257
static const int PeV = H / MaxNumTv;
5358
private:
54-
static const int TuBits_tmp = hlsutils::log2<MaxNumTu>::value;
55-
static const int TvBits_tmp = hlsutils::log2<MaxNumTv>::value;
59+
static const int NumTuBits_tmp = hlsutils::log2<MaxNumTu>::value;
60+
static const int NumTvBits_tmp = hlsutils::log2<MaxNumTv>::value;
5661
public:
57-
static const int TuBits = TuBits_tmp > 0 ? TuBits_tmp : 1;
58-
static const int TvBits = TvBits_tmp > 0 ? TvBits_tmp : 1;
59-
typedef ap_uint<MaxNumTu> IndexU_Type;
60-
typedef ap_uint<MaxNumTv> IndexV_Type;
62+
static const int NumTuBits = NumTuBits_tmp > 0 ? NumTuBits_tmp : 1;
63+
static const int NumTvBits = NumTvBits_tmp > 0 ? NumTvBits_tmp : 1;
64+
static const int NumTuBitsAligned = (NumTuBits + 7) & (-8); // align to 8bit
65+
static const int NumTvBitsAligned = (NumTvBits + 7) & (-8); // align to 8bit
66+
static const int NumGTuBitsAligned = (NumTuBits * G + 7) & (-8); // align to 8bit
67+
static const int NumGTvBitsAligned = (NumTvBits * G + 7) & (-8); // align to 8bit
68+
typedef ap_uint<NumTuBits> IndexU_Type; // deprecated
69+
typedef ap_uint<NumTvBits> IndexV_Type; // deprecated
6170
typedef ap_uint<MaxNumTu> UnzD;
6271
typedef ap_uint<MaxNumTv> VnzD;
63-
typedef ap_uint<TuBits> UnzIdxD;
64-
typedef ap_uint<TvBits> VnzIdxD;
72+
typedef ap_uint<NumTuBits> UnzIdxD;
73+
typedef ap_uint<NumTvBits> VnzIdxD;
6574
typedef ActivationD_tp ActivationD;
6675
typedef WeightD_tp WeightD;
6776
typedef AccumulationD_tp AccumulationD;
6877
typedef hls::stream<UnzD> UnzS;
6978
typedef hls::stream<VnzD> VnzS;
70-
typedef hls::stream<ap_uint<TuBits> > UnzIdxS;
71-
typedef hls::stream<ap_uint<TvBits> > VnzIdxS;
79+
typedef hls::stream<ap_uint<NumTuBits> > UnzIdxS;
80+
typedef hls::stream<ap_uint<NumTvBits> > VnzIdxS;
7281
typedef hls::stream<ActivationD> ActivationS;
7382
typedef hls::stream<WeightD> WeightS;
7483
typedef hls::stream<AccumulationD> AccumulationS;
@@ -87,6 +96,10 @@ struct SvdParameters {
8796
static const int VectG_AxiWidth = ActivationWidth * G;
8897
static const int VectGN_AxiWidth = ActivationWidth * G * N;
8998
static const int VectGTvAxiWidth = ActivationWidth * G * Tv;
99+
100+
typedef typename svd::AxiStreamPort<NumGTuBitsAligned>::PacketType VectGZTuAxiPacketType;
101+
typedef typename svd::AxiStreamPort<NumGTvBitsAligned>::PacketType VectGZTvAxiPacketType;
102+
90103
typedef typename svd::AxiStreamPort<VectTuAxiWidth>::PacketType VectTuAxiPacketType;
91104
typedef typename svd::AxiStreamPort<VectTvAxiWidth>::PacketType VectTvAxiPacketType;
92105
typedef typename svd::AxiStreamPort<VectN_AxiWidth>::PacketType VectN_AxiPacketType;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#ifndef TESTBENCHES_TEST_U_KERNEL_H_
2+
#define TESTBENCHES_TEST_U_KERNEL_H_
3+
4+
#include "kernel/u_kernel.h"
5+
#include "hls_utils/hls_debugging.h"
6+
7+
#endif // end TESTBENCHES_TEST_U_KERNEL_H_

src/layers/lstm/hls/lstm_svd.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,6 @@ void HlsLstmSvd(const int num_active_inputs,
416416
const int input_size,
417417
const int output_size,
418418
const int num_refinements[svd::lstm_params::N],
419-
// const hls::vector<int, svd::lstm_params::N> num_refinements,
420419
// Current Gates
421420
hls::stream<typename svd::lstm_params::VectTuAxiPacketType>& x_port,
422421
hls::stream<typename svd::lstm_params::VectTuAxiPacketType>& u_cur_port,
@@ -557,16 +556,22 @@ extern "C" void C_WrapperLstmSvd(
557556
const int input_size,
558557
const int output_size,
559558
const int num_refinements[svd::lstm_params::N],
559+
const int num_zero_tiles_u,
560+
const int num_zero_tiles_v,
560561
// Current Gates
561562
const float* x_in,
562563
const float* u_cur_in,
563564
const float* s_cur_in,
564565
const float* v_cur_in,
566+
const int* uz_idx_cur_in,
567+
const int* vz_idx_cur_in,
565568
// Recurrent Gates
566569
const float* h_in,
567570
const float* u_rec_in,
568571
const float* s_rec_in,
569572
const float* v_rec_in,
573+
const int* uz_idx_rec_in,
574+
const int* vz_idx_rec_in,
570575
// Non-Linearities
571576
const float* bias_in,
572577
const float* c_prev_in,

src/svd.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ int main(int argc, char const *argv[]) {
1515

1616
const bool kTestSoftwareAccelerator = false;
1717
const int kN = 2;
18+
const int kNumActiveInputs = 1;
1819
const int kR = svd::lstm_params::R;
1920
const int kI = svd::lstm_params::I;
2021
const int kH = svd::lstm_params::H;
@@ -178,6 +179,36 @@ int main(int argc, char const *argv[]) {
178179
storage.get_h(j));
179180
}
180181
}
182+
183+
// int num_refinements[kN] = {kR};
184+
// C_WrapperLstmSvd(
185+
// NUM_TIMESTEPS,
186+
// kNumActiveInputs,
187+
// kI,
188+
// kH,
189+
// num_refinements,
190+
// kZTu,
191+
// kZTv,
192+
// // Current Gates
193+
// x_in,
194+
// u_cur_in,
195+
// s_cur_in,
196+
// v_cur_in,
197+
// uz_idx_cur_in,
198+
// vz_idx_cur_in,
199+
// // Recurrent Gates
200+
// h_in,
201+
// u_rec_in,
202+
// s_rec_in,
203+
// v_rec_in,
204+
// uz_idx_rec_in,
205+
// vz_idx_rec_in,
206+
// // Non-Linearities
207+
// bias_in,
208+
// c_prev_in,
209+
// h_curr_in,
210+
// c_curr_in);
211+
181212
storage.ResetLstmOutputs();
182213
std::cout << "Cleaning up." << std::endl;
183214
delete[] h_prev_hls;

0 commit comments

Comments
 (0)