@@ -354,13 +354,14 @@ void KernelU_Pruned(const int num_active_inputs,
354354 const int input_size,
355355 const int num_refinements[params::N],
356356 const int num_zero_tiles_u,
357- hls::stream<typename params::VectGZTuAxiPacketType>& uz_idx_port ,
357+ hls::stream<typename params::VectGZTuAxiPacketType>& unz_idx_port ,
358358 hls::stream<typename params::VectTuAxiPacketType>& x_port,
359359 hls::stream<typename params::VectTuAxiPacketType>& u_port,
360360 hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
361361#pragma HLS TOP name=KernelU
362362#pragma HLS DATAFLOW
363- #pragma HLS INLINE
363+ // #pragma HLS INLINE
364+ #pragma HLS STABLE variable=unz_idx_port
364365#pragma HLS STABLE variable=x_port
365366#pragma HLS STABLE variable=u_port
366367#pragma HLS STABLE variable=xu_port
@@ -377,7 +378,7 @@ void KernelU_Pruned(const int num_active_inputs,
377378 assert (input_size % params::Tu == 0 );
378379 assert (input_size <= params::I);
379380 assert (kNumTilesU <= kMaxNumTilesU );
380- auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(uz_idx_port );
381+ auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(unz_idx_port );
381382 auto x_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(x_port);
382383 auto u_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(u_port);
383384 auto xu_axis = svd::AxiStreamInterface<WrapperAxisG>(xu_port);
@@ -407,19 +408,35 @@ void KernelU_Pruned(const int num_active_inputs,
407408 *
408409 * R_total = 2 * 3 + (3-2) * (3-1) + (6-3) * (3-2)
409410 */
411+ int num_refinements_init[params::N];
412+ int num_refinements_x_dma[params::N];
413+ int num_refinements_u_dma[params::N];
414+ int num_refinements_xu_dma[params::N];
415+ #pragma HLS ARRAY_PARTITION variable=num_refinements_init complete
416+ #pragma HLS ARRAY_PARTITION variable=num_refinements_x_dma complete
417+ #pragma HLS ARRAY_PARTITION variable=num_refinements_u_dma complete
418+ #pragma HLS ARRAY_PARTITION variable=num_refinements_xu_dma complete
419+ for (int i = 0 ; i < params::N; ++i) {
420+ #pragma HLS UNROLL
421+ num_refinements_init[i] = num_refinements[i];
422+ num_refinements_x_dma[i] = num_refinements[i];
423+ num_refinements_u_dma[i] = num_refinements[i];
424+ num_refinements_xu_dma[i] = num_refinements[i];
425+ }
426+
410427 // ===========================================================================
411- // TODO: Same as non-pruned version -> wrap into a function
428+ // TODO: Same as non-pruned version -> wrap into a function (be careful to NTu-ZTu)
412429 // ===========================================================================
413- int R_max = num_refinements [0 ];
414- int R_total = num_refinements [0 ] * num_active_inputs; // Total elements.
430+ int R_max = num_refinements_init [0 ];
431+ int R_total = num_refinements_init [0 ] * num_active_inputs; // Total elements.
415432 Get_Total_R:
416433 for (int i = 1 ; i < num_active_inputs; ++i) {
417434#pragma HLS PIPELINE II=1 style=frp
418- if (num_refinements [i] > R_max) {
419- R_max = num_refinements [i];
435+ if (num_refinements_init [i] > R_max) {
436+ R_max = num_refinements_init [i];
420437 }
421- assert (num_refinements [i] >= num_refinements [i - 1 ]);
422- R_total += (num_refinements [i] - num_refinements [i - 1 ]) * (num_active_inputs - i);
438+ assert (num_refinements_init [i] >= num_refinements_init [i - 1 ]);
439+ R_total += (num_refinements_init [i] - num_refinements_init [i - 1 ]) * (num_active_inputs - i);
423440 }
424441
425442 // Added
@@ -435,44 +452,61 @@ void KernelU_Pruned(const int num_active_inputs,
435452 }
436453 }
437454
455+ typedef ap_uint<params::NumGTuBitsAligned> ZIndexType;
456+ auto get_idx = [](const ZIndexType nz_idx, const int i) {
457+ const int kHi = (i + 1 ) * params::NumTuBits - 1 ;
458+ const int kLo = i * params::NumTuBits;
459+ return nz_idx.range (kHi , kLo ).to_int ();
460+ };
461+ auto set_nz_idx = [](const int x) {
462+ #pragma HLS PIPELINE II=1 style=frp
463+ ZIndexType nz_idx;
464+ const auto tmp = ap_uint<params::NumTuBits>(x);
465+ for (int i = 0 ; i < params::G; ++i) {
466+ const int kHi = (i + 1 ) * params::NumTuBits - 1 ;
467+ const int kLo = i * params::NumTuBits;
468+ nz_idx.range (kHi , kLo ) = tmp.range ();
469+ }
470+ return nz_idx;
471+ };
438472 // Changed
439473 int R_prev = 0 ;
440- X_DMA_dispatcher :
474+ X_DMA_Dispatcher :
441475 for (int ii = 0 ; ii < num_active_inputs; ++ii) {
442- for (int i = 0 ; i < num_refinements[ii] - R_prev; ++i) {
443- assert (num_refinements[ii] - R_prev >= 1 );
444- for (int j = 0 ; j < kNumTilesU ; ++j) {
445- // Read z_idx
446- auto z_idx = uz_axis.template PopVector <ActivationType, params::G>();
476+ for (int i = 0 ; i < num_refinements_x_dma[ii] - R_prev; ++i) {
477+ assert (num_refinements_x_dma[ii] - R_prev >= 1 );
478+ for (int j = 0 ; j < kNumTilesU - num_zero_tiles_u; ++j) {
479+ auto nz_idx = num_zero_tiles_u > 0 ? uz_axis.template Pop <ZIndexType>() : set_nz_idx (j);
447480 for (int k = 0 ; k < num_active_inputs - ii; ++k) {
448- #pragma HLS PIPELINE II=1 style=frp
449481 assert (num_active_inputs - ii >= 1 );
450482 assert (k + ii < params::N);
451483 for (int kk = 0 ; kk < params::G; ++kk) {
484+ #pragma HLS LOOP_FLATTEN
485+ #pragma HLS PIPELINE II=1 style=frp
452486 typename params::VectTuType x_val;
453487 for (int jj = 0 ; jj < params::Tu; ++jj) {
454- x_val[jj] = x_buffer[k + ii][jj][z_idx[kk] ];
488+ x_val[jj] = x_buffer[k + ii][jj][get_idx (nz_idx, kk) ];
455489 }
456490 x_stream[kk] << x_val;
457491 }
458492 }
459493 }
460494 }
461- R_prev = num_refinements [ii];
495+ R_prev = num_refinements_x_dma [ii];
462496 }
463497
464498 // ===========================================================================
465- // TODO: Same as non-pruned version -> wrap into a function
499+ // TODO: Same as non-pruned version -> wrap into a function (be careful to NTu-ZTu)
466500 // ===========================================================================
467501 U_DMA:
468502 for (int i = 0 ; i < R_max; ++i) {
469503#pragma HLS LOOP_TRIPCOUNT min=params::R max=params::R
470- for (int j = 0 ; j < kNumTilesU ; ++j) {
504+ for (int j = 0 ; j < kNumTilesU - num_zero_tiles_u ; ++j) {
471505 for (int k = 0 ; k < params::G; ++k) {
472506 auto u_val = u_axis.template PopVector <ActivationType, params::Tu>();
473507 for (int ii = 0 ; ii < num_active_inputs; ++ii) {
474508#pragma HLS PIPELINE II=1 style=frp
475- if (i < num_refinements [ii]) {
509+ if (i < num_refinements_u_dma [ii]) {
476510 u_streams[k] << u_val;
477511 }
478512 }
@@ -483,7 +517,7 @@ void KernelU_Pruned(const int num_active_inputs,
483517 // Changed
484518 U_Kernel:
485519 for (int i = 0 ; i < R_total; ++i) {
486- for (int j = 0 ; j < kNumTilesU ; ++j) {
520+ for (int j = 0 ; j < kNumTilesU - num_zero_tiles_u ; ++j) {
487521#pragma HLS PIPELINE II=1 style=frp
488522 for (int k = 0 ; k < params::G; ++k) {
489523 xu_streams[k] << hlsutils::adder_tree<ActivationType, params::Tu>(
@@ -493,23 +527,23 @@ void KernelU_Pruned(const int num_active_inputs,
493527 }
494528
495529 // ===========================================================================
496- // TODO: Same as non-pruned version -> wrap into a function
530+ // TODO: Same as non-pruned version -> wrap into a function (be careful to NTu-ZTu)
497531 // ===========================================================================
498532 int iter_cnt = 0 ;
499533 XU_DMA:
500534 for (int i = 0 ; i < R_max; ++i) {
501535 typename params::VectG_Type xu_out[params::N] = {typename params::VectG_Type (0 )};
502536#pragma HLS ARRAY_PARTITION variable=xu_out complete dim=1
503- for (int j = 0 ; j < kNumTilesU ; ++j) {
537+ for (int j = 0 ; j < kNumTilesU - num_zero_tiles_u ; ++j) {
504538 for (int k = 0 ; k < num_active_inputs; ++k) {
505539#pragma HLS PIPELINE II=1 style=frp
506540 for (int ii = 0 ; ii < params::G; ++ii) {
507- if (i < num_refinements [k]) {
541+ if (i < num_refinements_xu_dma [k]) {
508542 xu_out[k][ii] += xu_streams[ii].read ();
509543#pragma HLS BIND_OP variable=xu_out[k][ii] op=add impl=dsp
510544 }
511545 }
512- if (i < num_refinements [k] && j == kNumTilesU - 1 ) {
546+ if (i < num_refinements_xu_dma [k] && j == kNumTilesU - num_zero_tiles_u - 1 ) {
513547 const bool kIsLast = iter_cnt == R_total - 1 ;
514548 xu_axis.template PushVector <ActivationType, params::G>(xu_out[k], kIsLast );
515549 ++iter_cnt;
@@ -524,14 +558,14 @@ void KernelU_Pruned(const int num_active_inputs,
524558
525559namespace testu {
526560
527- static const int kNumInputs = 4 ;
561+ static const int kNumInputs = 2 ;
528562static const int kInputSize = 1024 ;
529563static const int Tu = 4 ;
564+ static const int ZTu = 0 ;
530565// NOTE: The rest of the parameters are unused for now.
531566static const int kDummySize = 1 ;
532567static const int R = 8 ;
533568static const int Tv = 1 ;
534- static const int ZTu = 0 ;
535569static const int ZTv = 0 ;
536570static const int G = 4 ;
537571
@@ -579,11 +613,20 @@ void HlsKernelU(const int num_active_inputs,
579613 const int num_refinements[testu::params::N],
580614 const bool pad_output,
581615 // const int num_zero_tiles_u,
582- // hls::stream<ap_uint<testu::NumTuBits> >& uz_idx_port ,
616+ // hls::stream<ap_uint<testu::NumTuBits> >& unz_idx_port ,
583617 hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
584618 hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
585619 hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);
586620
621+ void HlsKernelU_Pruned (const int num_active_inputs,
622+ const int input_size,
623+ const int num_refinements[testu::params::N],
624+ const int num_zero_tiles_u,
625+ hls::stream<typename testu::params::VectGZTuAxiPacketType>& unz_idx_port,
626+ hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
627+ hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
628+ hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);
629+
587630#endif // end __VITIS_HLS__
588631
589632#endif // end KERNEL_U_KERNEL_H_
0 commit comments