@@ -199,7 +199,6 @@ void KernelU(const int num_active_inputs,
199199 const int input_size,
200200 const int num_refinements[params::N],
201201 const bool pad_output,
202- // hls::stream<typename params::IndexU_Type>& z_idx_port,
203202 hls::stream<typename params::VectTuAxiPacketType>& x_port,
204203 hls::stream<typename params::VectTuAxiPacketType>& u_port,
205204 hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
@@ -210,13 +209,14 @@ void KernelU(const int num_active_inputs,
210209#pragma HLS STABLE variable=u_port
211210#pragma HLS STABLE variable=xu_port
212211 typedef typename params::ActivationD ActivationType;
213- const int kNumTilesU = input_size / params::Tu;
214- const int kMaxNumTilesU = params::I / params::Tu;
215- const int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
216- const int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
217- const int kStreamDepth_XU = 2 + params::G;
218- assert (num_active_inputs <= params::N);
212+ const unsigned int kNumTilesU = input_size / params::Tu;
213+ const unsigned int kMaxNumTilesU = params::I / params::Tu;
214+ const unsigned int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
215+ const unsigned int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
216+ const unsigned int kStreamDepth_XU = 2 + params::G;
219217 assert (num_active_inputs > 0 );
218+ assert (kNumTilesU > 0 );
219+ assert (num_active_inputs <= params::N);
220220 assert (params::I % params::Tu == 0 );
221221 assert (input_size % params::Tu == 0 );
222222 assert (input_size <= params::I);
@@ -344,6 +344,180 @@ void KernelU(const int num_active_inputs,
344344 }
345345 }
346346}
347+
348+
/**
 * @brief Pruned variant of KernelU: computes the x * U product of the SVD
 *        kernel, skipping zero tiles of U by fetching, per refinement step,
 *        the indices of the non-zero tiles from uz_idx_port and gathering the
 *        corresponding x tiles from an on-chip buffer.
 *
 * Structure (HLS DATAFLOW stages): buffer all inputs (X_DAM_in), dispatch the
 * selected x tiles per gate (X_DMA_dispatcher), stream in U (U_DMA), multiply
 * and reduce per tile (U_Kernel), then accumulate per refinement and emit
 * (XU_DMA).
 *
 * @param num_active_inputs  Number of inputs to process; must be in
 *                           (0, params::N].
 * @param input_size         Elements per input; multiple of params::Tu, at
 *                           most params::I.
 * @param num_refinements    Per-input refinement counts; asserted to be
 *                           monotonically non-decreasing over inputs.
 * @param num_zero_tiles_u   Number of pruned (zero) tiles per U component.
 *                           NOTE(review): currently UNUSED in this body — the
 *                           tile loops still iterate over all kNumTilesU
 *                           tiles; presumably WIP (see "Changed" markers).
 * @param uz_idx_port        Stream of non-zero-tile index vectors (one index
 *                           per gate G) consumed by the dispatcher.
 * @param x_port             Input activation stream, params::Tu elements per
 *                           packet.
 * @param u_port             U weight stream, params::Tu elements per packet.
 * @param xu_port            Output stream of G-element dot-product vectors.
 */
template <
  typename params,
  typename WrapperAxisG = svd::AxiStreamPort<params::VectG_AxiWidth>
>
void KernelU_Pruned(const int num_active_inputs,
    const int input_size,
    const int num_refinements[params::N],
    const int num_zero_tiles_u,
    hls::stream<typename params::VectGZTuAxiPacketType>& uz_idx_port,
    hls::stream<typename params::VectTuAxiPacketType>& x_port,
    hls::stream<typename params::VectTuAxiPacketType>& u_port,
    hls::stream<typename WrapperAxisG::PacketType>& xu_port) {
#pragma HLS TOP name=KernelU
#pragma HLS DATAFLOW
#pragma HLS INLINE
#pragma HLS STABLE variable=x_port
#pragma HLS STABLE variable=u_port
#pragma HLS STABLE variable=xu_port
  typedef typename params::ActivationD ActivationType;
  // Tile counts and FIFO depths; depths are sized for the worst case
  // (kMaxNumTilesU) so the streams never stall the dataflow pipeline.
  const unsigned int kNumTilesU = input_size / params::Tu;
  const unsigned int kMaxNumTilesU = params::I / params::Tu;
  const unsigned int kStreamDepth_X = 2 + kMaxNumTilesU * params::N;
  const unsigned int kStreamDepth_U = 8 + kMaxNumTilesU * params::N;
  const unsigned int kStreamDepth_XU = 2 + params::G;
  assert(num_active_inputs > 0);
  assert(kNumTilesU > 0);
  assert(num_active_inputs <= params::N);
  assert(params::I % params::Tu == 0);
  assert(input_size % params::Tu == 0);
  assert(input_size <= params::I);
  assert(kNumTilesU <= kMaxNumTilesU);
  // AXI-stream wrappers around the raw ports.
  auto uz_axis = svd::AxiStreamPort<params::NumGTuBitsAligned>(uz_idx_port);
  auto x_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(x_port);
  auto u_axis = svd::AxiStreamPort<params::VectTuAxiWidth>(u_port);
  auto xu_axis = svd::AxiStreamInterface<WrapperAxisG>(xu_port);
  // Internal dataflow channels: one x/u/xu stream per gate G.
  hls::stream<typename params::VectTuType> x_stream[params::G];
  hls::stream<typename params::VectTuType> u_streams[params::G];
  hls::stream<ActivationType> xu_streams[params::G];
  // Full copy of all inputs, laid out [input][element-in-tile][tile] so the
  // dispatcher can gather arbitrary tiles selected by z_idx.
  ActivationType x_buffer[params::N][params::Tu][kMaxNumTilesU];
#pragma HLS STREAM variable=x_stream depth=kStreamDepth_X
#pragma HLS STREAM variable=u_streams depth=kStreamDepth_U
#pragma HLS STREAM variable=xu_streams depth=kStreamDepth_XU
#pragma HLS ARRAY_PARTITION variable=u_streams complete dim=1
#pragma HLS ARRAY_PARTITION variable=x_buffer complete dim=1
#pragma HLS ARRAY_PARTITION variable=x_buffer complete dim=2
#pragma HLS BIND_STORAGE variable=x_buffer type=ram_t2p impl=bram latency=1
  /*
   * Ideally, if the Rs are ordered, it would be: R0 * N + (R1-R0) * (N-1) +
   * (R2-R1) * (N-2)
   *
   * Imagine we have: R0 = 2, R1 = 3, R2 = 6
   *
   * This means:
   * - till refinement 2 we have input 0 to process
   * - till refinement 3 we have input 1 to process
   * - till refinement 6 we have input 2 to process
   *
   * So it would become:
   *
   * R_total = 2 * 3 + (3-2) * (3-1) + (6-3) * (3-2)
   */
  // ===========================================================================
  // TODO: Same as non-pruned version -> wrap into a function
  // ===========================================================================
  // R_max drives the U/XU loops (longest-running input); R_total counts the
  // total (refinement, input) work items, used for the TLAST flag below.
  int R_max = num_refinements[0];
  int R_total = num_refinements[0] * num_active_inputs;  // Total elements.
  Get_Total_R:
  for (int i = 1; i < num_active_inputs; ++i) {
#pragma HLS PIPELINE II=1 style=frp
    if (num_refinements[i] > R_max) {
      R_max = num_refinements[i];
    }
    // Inputs must be sorted by refinement count for the formula above.
    assert(num_refinements[i] >= num_refinements[i - 1]);
    R_total += (num_refinements[i] - num_refinements[i - 1]) * (num_active_inputs - i);
  }

  // Added
  // Buffer every input tile so tiles can later be re-read in z_idx order.
  // NOTE(review): label reads "DAM", presumably a typo for "DMA".
  X_DAM_in:
  for (int i = 0; i < num_active_inputs; ++i) {
    for (int j = 0; j < kNumTilesU; ++j) {
#pragma HLS LOOP_FLATTEN
#pragma HLS PIPELINE II=1 style=frp
      auto x_val = x_axis.template PopVector<ActivationType, params::Tu>();
      for (int k = 0; k < params::Tu; ++k) {
        x_buffer[i][k][j] = x_val[k];
      }
    }
  }

  // Changed
  // Dispatch: for each refinement step still owed by inputs ii..N-1, read one
  // index vector (one tile index per gate) and push the selected x tiles.
  // NOTE(review): z_idx is popped as an ActivationType vector but used as an
  // array index below — confirm the index type survives the conversion.
  int R_prev = 0;
  X_DMA_dispatcher:
  for (int ii = 0; ii < num_active_inputs; ++ii) {
    for (int i = 0; i < num_refinements[ii] - R_prev; ++i) {
      assert(num_refinements[ii] - R_prev >= 1);
      for (int j = 0; j < kNumTilesU; ++j) {
        // Read z_idx
        auto z_idx = uz_axis.template PopVector<ActivationType, params::G>();
        for (int k = 0; k < num_active_inputs - ii; ++k) {
#pragma HLS PIPELINE II=1 style=frp
          assert(num_active_inputs - ii >= 1);
          assert(k + ii < params::N);
          for (int kk = 0; kk < params::G; ++kk) {
            // Gather the tile selected for gate kk of input (k + ii).
            typename params::VectTuType x_val;
            for (int jj = 0; jj < params::Tu; ++jj) {
              x_val[jj] = x_buffer[k + ii][jj][z_idx[kk]];
            }
            x_stream[kk] << x_val;
          }
        }
      }
    }
    R_prev = num_refinements[ii];
  }

  // ===========================================================================
  // TODO: Same as non-pruned version -> wrap into a function
  // ===========================================================================
  // Stream U in; each (refinement, tile, gate) value is replicated to every
  // input that still has refinements left at step i.
  U_DMA:
  for (int i = 0; i < R_max; ++i) {
#pragma HLS LOOP_TRIPCOUNT min=params::R max=params::R
    for (int j = 0; j < kNumTilesU; ++j) {
      for (int k = 0; k < params::G; ++k) {
        auto u_val = u_axis.template PopVector<ActivationType, params::Tu>();
        for (int ii = 0; ii < num_active_inputs; ++ii) {
#pragma HLS PIPELINE II=1 style=frp
          if (i < num_refinements[ii]) {
            u_streams[k] << u_val;
          }
        }
      }
    }
  }

  // Changed
  // Per-gate elementwise multiply of matching x/u tiles, reduced with an
  // adder tree to one partial dot product per tile.
  U_Kernel:
  for (int i = 0; i < R_total; ++i) {
    for (int j = 0; j < kNumTilesU; ++j) {
#pragma HLS PIPELINE II=1 style=frp
      for (int k = 0; k < params::G; ++k) {
        xu_streams[k] << hlsutils::adder_tree<ActivationType, params::Tu>(
          x_stream[k].read() * u_streams[k].read());
      }
    }
  }

  // ===========================================================================
  // TODO: Same as non-pruned version -> wrap into a function
  // ===========================================================================
  // Accumulate the per-tile partials across tiles; after the last tile of a
  // refinement, push the G-element result, raising TLAST on the final one.
  int iter_cnt = 0;
  XU_DMA:
  for (int i = 0; i < R_max; ++i) {
    typename params::VectG_Type xu_out[params::N] = {typename params::VectG_Type(0)};
#pragma HLS ARRAY_PARTITION variable=xu_out complete dim=1
    for (int j = 0; j < kNumTilesU; ++j) {
      for (int k = 0; k < num_active_inputs; ++k) {
#pragma HLS PIPELINE II=1 style=frp
        for (int ii = 0; ii < params::G; ++ii) {
          if (i < num_refinements[k]) {
            xu_out[k][ii] += xu_streams[ii].read();
#pragma HLS BIND_OP variable=xu_out[k][ii] op=add impl=dsp
          }
        }
        if (i < num_refinements[k] && j == kNumTilesU - 1) {
          const bool kIsLast = iter_cnt == R_total - 1;
          xu_axis.template PushVector<ActivationType, params::G>(xu_out[k], kIsLast);
          ++iter_cnt;
        }
      }
    }
  }
}
347521#endif // end __VITIS_HLS__
348522
349523} // svd
@@ -404,6 +578,8 @@ void HlsKernelU(const int num_active_inputs,
404578 const int input_size,
405579 const int num_refinements[testu::params::N],
406580 const bool pad_output,
581+ // const int num_zero_tiles_u,
582+ // hls::stream<ap_uint<testu::NumTuBits> >& uz_idx_port,
407583 hls::stream<typename testu::params::VectTuAxiPacketType>& x_port,
408584 hls::stream<typename testu::params::VectTuAxiPacketType>& u_port,
409585 hls::stream<typename testu::params::VectG_AxiPacketType>& xu_port);
0 commit comments