Skip to content

Commit 6e4f2da

Browse files
committed
Pruned versions of U and V kernels: they are backward compatible when num_zero_tiles == 0.
1 parent 32b21d4 commit 6e4f2da

File tree

13 files changed

+418
-119
lines changed

13 files changed

+418
-119
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ add_compile_definitions(HIDDEN_SIZE=512)
4646
add_compile_definitions(NUM_GATES=4)
4747
add_compile_definitions(NUM_SAMPLES=2)
4848
add_compile_definitions(NUM_TILES_U=8)
49-
add_compile_definitions(NUM_ZERO_TILES_U=1)
49+
add_compile_definitions(NUM_ZERO_TILES_U=0)
5050
add_compile_definitions(NUM_TILES_V=8)
51-
add_compile_definitions(NUM_ZERO_TILES_V=1)
51+
add_compile_definitions(NUM_ZERO_TILES_V=0)
5252
add_compile_definitions(NUM_TIMESTEPS=28)
5353
add_compile_definitions(FIX_WIDTH=16)
5454
add_compile_definitions(FIX_FRACT_WIDTH=5)

include/kernel/u_kernel.h

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,14 @@ void KernelU_Pruned(const int num_active_inputs,
354354
const int input_size,
355355
const int num_refinements[params::N],
356356
const int num_zero_tiles_u,
357-
hls::stream<typename params::VectGZTuAxiPacketType>& uz_idx_port,
357+
hls::stream<typename params::VectGZTuAxiPacketType>& unz_idx_port,
358358
hls::stream<typename params::VectTuAxiPacketType>& x_port,
359359
hls::stream<typename params::VectTuAxiPacketType>& u_port,
360360
hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
361361
#pragma HLS TOP name=KernelU
362362
#pragma HLS DATAFLOW
363-
#pragma HLS INLINE
363+
// #pragma HLS INLINE
364+
#pragma HLS STABLE variable=unz_idx_port
364365
#pragma HLS STABLE variable=x_port
365366
#pragma HLS STABLE variable=u_port
366367
#pragma HLS STABLE variable=xu_port
@@ -377,7 +378,7 @@ void KernelU_Pruned(const int num_active_inputs,
377378
assert(input_size % params::Tu == 0);
378379
assert(input_size <= params::I);
379380
assert(kNumTilesU <= kMaxNumTilesU);
380-
auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(uz_idx_port);
381+
auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(unz_idx_port);
381382
auto x_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(x_port);
382383
auto u_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(u_port);
383384
auto xu_axis = svd::AxiStreamInterface<WrapperAxisG>(xu_port);
@@ -407,19 +408,35 @@ void KernelU_Pruned(const int num_active_inputs,
407408
*
408409
* R_total = 2 * 3 + (3-2) * (3-1) + (6-3) * (3-2)
409410
*/
411+
int num_refinements_init[params::N];
412+
int num_refinements_x_dma[params::N];
413+
int num_refinements_u_dma[params::N];
414+
int num_refinements_xu_dma[params::N];
415+
#pragma HLS ARRAY_PARTITION variable=num_refinements_init complete
416+
#pragma HLS ARRAY_PARTITION variable=num_refinements_x_dma complete
417+
#pragma HLS ARRAY_PARTITION variable=num_refinements_u_dma complete
418+
#pragma HLS ARRAY_PARTITION variable=num_refinements_xu_dma complete
419+
for (int i = 0; i < params::N; ++i) {
420+
#pragma HLS UNROLL
421+
num_refinements_init[i] = num_refinements[i];
422+
num_refinements_x_dma[i] = num_refinements[i];
423+
num_refinements_u_dma[i] = num_refinements[i];
424+
num_refinements_xu_dma[i] = num_refinements[i];
425+
}
426+
410427
// ===========================================================================
411-
// TODO: Same as non-pruned version -> wrap into a function
428+
// TODO: Same as non-pruned version -> wrap into a function (be careful to NTu-ZTu)
412429
// ===========================================================================
413-
int R_max = num_refinements[0];
414-
int R_total = num_refinements[0] * num_active_inputs; // Total elements.
430+
int R_max = num_refinements_init[0];
431+
int R_total = num_refinements_init[0] * num_active_inputs; // Total elements.
415432
Get_Total_R:
416433
for (int i = 1; i < num_active_inputs; ++i) {
417434
#pragma HLS PIPELINE II=1 style=frp
418-
if (num_refinements[i] > R_max) {
419-
R_max = num_refinements[i];
435+
if (num_refinements_init[i] > R_max) {
436+
R_max = num_refinements_init[i];
420437
}
421-
assert(num_refinements[i] >= num_refinements[i - 1]);
422-
R_total += (num_refinements[i] - num_refinements[i - 1]) * (num_active_inputs - i);
438+
assert(num_refinements_init[i] >= num_refinements_init[i - 1]);
439+
R_total += (num_refinements_init[i] - num_refinements_init[i - 1]) * (num_active_inputs - i);
423440
}
424441

425442
// Added
@@ -435,44 +452,61 @@ void KernelU_Pruned(const int num_active_inputs,
435452
}
436453
}
437454

455+
typedef ap_uint<params::NumGTuBitsAligned> ZIndexType;
456+
auto get_idx = [](const ZIndexType nz_idx, const int i) {
457+
const int kHi = (i + 1) * params::NumTuBits - 1;
458+
const int kLo = i * params::NumTuBits;
459+
return nz_idx.range(kHi, kLo).to_int();
460+
};
461+
auto set_nz_idx = [](const int x) {
462+
#pragma HLS PIPELINE II=1 style=frp
463+
ZIndexType nz_idx;
464+
const auto tmp = ap_uint<params::NumTuBits>(x);
465+
for (int i = 0; i < params::G; ++i) {
466+
const int kHi = (i + 1) * params::NumTuBits - 1;
467+
const int kLo = i * params::NumTuBits;
468+
nz_idx.range(kHi, kLo) = tmp.range();
469+
}
470+
return nz_idx;
471+
};
438472
// Changed
439473
int R_prev = 0;
440-
X_DMA_dispatcher:
474+
X_DMA_Dispatcher:
441475
for (int ii = 0; ii < num_active_inputs; ++ii) {
442-
for (int i = 0; i < num_refinements[ii] - R_prev; ++i) {
443-
assert(num_refinements[ii] - R_prev >= 1);
444-
for (int j = 0; j < kNumTilesU; ++j) {
445-
// Read z_idx
446-
auto z_idx = uz_axis.template PopVector<ActivationType, params::G>();
476+
for (int i = 0; i < num_refinements_x_dma[ii] - R_prev; ++i) {
477+
assert(num_refinements_x_dma[ii] - R_prev >= 1);
478+
for (int j = 0; j < kNumTilesU - num_zero_tiles_u; ++j) {
479+
auto nz_idx = num_zero_tiles_u > 0 ? uz_axis.template Pop<ZIndexType>() : set_nz_idx(j);
447480
for (int k = 0; k < num_active_inputs - ii; ++k) {
448-
#pragma HLS PIPELINE II=1 style=frp
449481
assert(num_active_inputs - ii >= 1);
450482
assert(k + ii < params::N);
451483
for (int kk = 0; kk < params::G; ++kk) {
484+
#pragma HLS LOOP_FLATTEN
485+
#pragma HLS PIPELINE II=1 style=frp
452486
typename params::VectTuType x_val;
453487
for (int jj = 0; jj < params::Tu; ++jj) {
454-
x_val[jj] = x_buffer[k + ii][jj][z_idx[kk]];
488+
x_val[jj] = x_buffer[k + ii][jj][get_idx(nz_idx, kk)];
455489
}
456490
x_stream[kk] << x_val;
457491
}
458492
}
459493
}
460494
}
461-
R_prev = num_refinements[ii];
495+
R_prev = num_refinements_x_dma[ii];
462496
}
463497

464498
// ===========================================================================
465-
// TODO: Same as non-pruned version -> wrap into a function
499+
// TODO: Same as non-pruned version -> wrap into a function (be careful to NTv-ZTv)
466500
// ===========================================================================
467501
U_DMA:
468502
for (int i = 0; i < R_max; ++i) {
469503
#pragma HLS LOOP_TRIPCOUNT min=params::R max=params::R
470-
for (int j = 0; j < kNumTilesU; ++j) {
504+
for (int j = 0; j < kNumTilesU - num_zero_tiles_u; ++j) {
471505
for (int k = 0; k < params::G; ++k) {
472506
auto u_val = u_axis.template PopVector<ActivationType, params::Tu>();
473507
for (int ii = 0; ii < num_active_inputs; ++ii) {
474508
#pragma HLS PIPELINE II=1 style=frp
475-
if (i < num_refinements[ii]) {
509+
if (i < num_refinements_u_dma[ii]) {
476510
u_streams[k] << u_val;
477511
}
478512
}
@@ -483,7 +517,7 @@ void KernelU_Pruned(const int num_active_inputs,
483517
// Changed
484518
U_Kernel:
485519
for (int i = 0; i < R_total; ++i) {
486-
for (int j = 0; j < kNumTilesU; ++j) {
520+
for (int j = 0; j < kNumTilesU - num_zero_tiles_u; ++j) {
487521
#pragma HLS PIPELINE II=1 style=frp
488522
for (int k = 0; k < params::G; ++k) {
489523
xu_streams[k] << hlsutils::adder_tree<ActivationType, params::Tu>(
@@ -493,23 +527,23 @@ void KernelU_Pruned(const int num_active_inputs,
493527
}
494528

495529
// ===========================================================================
496-
// TODO: Same as non-pruned version -> wrap into a function
530+
// TODO: Same as non-pruned version -> wrap into a function (be careful to NTv-ZTv)
497531
// ===========================================================================
498532
int iter_cnt = 0;
499533
XU_DMA:
500534
for (int i = 0; i < R_max; ++i) {
501535
typename params::VectG_Type xu_out[params::N] = {typename params::VectG_Type(0)};
502536
#pragma HLS ARRAY_PARTITION variable=xu_out complete dim=1
503-
for (int j = 0; j < kNumTilesU; ++j) {
537+
for (int j = 0; j < kNumTilesU - num_zero_tiles_u; ++j) {
504538
for (int k = 0; k < num_active_inputs; ++k) {
505539
#pragma HLS PIPELINE II=1 style=frp
506540
for (int ii = 0; ii < params::G; ++ii) {
507-
if (i < num_refinements[k]) {
541+
if (i < num_refinements_xu_dma[k]) {
508542
xu_out[k][ii] += xu_streams[ii].read();
509543
#pragma HLS BIND_OP variable=xu_out[k][ii] op=add impl=dsp
510544
}
511545
}
512-
if (i < num_refinements[k] && j == kNumTilesU - 1) {
546+
if (i < num_refinements_xu_dma[k] && j == kNumTilesU - num_zero_tiles_u - 1) {
513547
const bool kIsLast = iter_cnt == R_total - 1;
514548
xu_axis.template PushVector<ActivationType, params::G>(xu_out[k], kIsLast);
515549
++iter_cnt;
@@ -524,14 +558,14 @@ void KernelU_Pruned(const int num_active_inputs,
524558

525559
namespace testu {
526560

527-
static const int kNumInputs = 4;
561+
static const int kNumInputs = 2;
528562
static const int kInputSize = 1024;
529563
static const int Tu = 4;
564+
static const int ZTu = 0;
530565
// NOTE: The rest of the parameters are unused for now.
531566
static const int kDummySize = 1;
532567
static const int R = 8;
533568
static const int Tv = 1;
534-
static const int ZTu = 0;
535569
static const int ZTv = 0;
536570
static const int G = 4;
537571

@@ -579,11 +613,20 @@ void HlsKernelU(const int num_active_inputs,
579613
const int num_refinements[testu::params::N],
580614
const bool pad_output,
581615
// const int num_zero_tiles_u,
582-
// hls::stream<ap_uint<testu::NumTuBits> >& uz_idx_port,
616+
// hls::stream<ap_uint<testu::NumTuBits> >& unz_idx_port,
583617
hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
584618
hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
585619
hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);
586620

621+
void HlsKernelU_Pruned(const int num_active_inputs,
622+
const int input_size,
623+
const int num_refinements[testu::params::N],
624+
const int num_zero_tiles_u,
625+
hls::stream<typename testu::params::VectGZTuAxiPacketType>& unz_idx_port,
626+
hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
627+
hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
628+
hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);
629+
587630
#endif // end __VITIS_HLS__
588631

589632
#endif // end KERNEL_U_KERNEL_H_

0 commit comments

Comments
 (0)