55#include < algorithm>
66#include < cstring>
77
8+ #define ENABLE_SPARSITY true
89// Include random header only when sparsity is enabled
910#ifdef ENABLE_SPARSITY
1011#include < random>
@@ -20,7 +21,7 @@ using float32_t = float;
2021// Configuration Macros
2122// ============================================================================
2223#ifndef NUM_THREADS
23- #define NUM_THREADS 16 // Should be 32 for paper accuracy
24+ #define NUM_THREADS 8 // Should be 32 for paper accuracy
2425#endif
2526
2627#ifndef XLENB
@@ -372,7 +373,7 @@ class WMMA {
372373
373374 FragA fragA_compressed_; // Compressed matrix A (50% storage)
374375 FragA_meta fragA_meta_; // Metadata: 1 = non-zero, 0 = pruned
375- vector_t <uint32_t , 8 > packed_bit_meta_; // Packed bitmap metadata
376+ vector_t <uint32_t , NRA > packed_bit_meta_; // Packed bitmap metadata, 32 * 32bit Reg in RISC-V
376377#endif
377378
378379 FragD fragRef_;
@@ -441,7 +442,7 @@ class WMMA {
441442
442443 // Pack metadata for this chunk
443444 for (uint32_t r_i = 0 ; r_i < ROWS_PER_CHUNK; ++r_i) {
444- for (uint32_t c_i = 0 ; c_i < ELEMENTS_PER_ROb_W ; ++c_i) {
445+ for (uint32_t c_i = 0 ; c_i < ELEMENTS_PER_ROW ; ++c_i) {
445446 uint32_t row = r_i + chunk * ROWS_PER_CHUNK + m * tcM;
446447 uint32_t col = c_i + k * ELEMENTS_PER_ROW;
447448
@@ -470,7 +471,7 @@ class WMMA {
470471
471472 // Gather B column elements based on A's sparsity pattern
472473 void gather_sparse_B_column (
473- uint8_t *b_collected,
474+ It *b_collected,
474475 const Xt *b_col_0,
475476 const Xt *b_col_1,
476477 uint16_t a_row_meta) const {
@@ -485,7 +486,7 @@ class WMMA {
485486 uint32_t element_idx = bit_idx / i_ratio;
486487 uint32_t byte_pos = (bit_idx % i_ratio) * 8 ;
487488 b_collected[collect_idx++] =
488- static_cast <uint8_t >((b_col_0[element_idx] >> byte_pos) & 0xFF );
489+ static_cast <It >((b_col_0[element_idx] >> byte_pos) & 0xFF );
489490 }
490491 }
491492
@@ -694,7 +695,7 @@ class WMMA {
694695 uint32_t b_off = (n % b_sub_blocks) * b_block_size;
695696
696697 Vreg vd;
697- uint8_t b_col_collected[tcK * i_ratio];
698+ It b_col_collected[tcK * i_ratio];
698699
699700 for (uint32_t i = 0 ; i < tcM; ++i) {
700701 for (uint32_t j = 0 ; j < tcN; ++j) {
0 commit comments