Skip to content

Commit 0e4200e

Browse files
committed
generalize for int8, NT=8
1 parent 55354eb commit 0e4200e

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/regression/sgemm_tcu/tensor_generic.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <algorithm>
66
#include <cstring>
77

8+
#define ENABLE_SPARSITY true
89
// Include random header only when sparsity is enabled
910
#ifdef ENABLE_SPARSITY
1011
#include <random>
@@ -20,7 +21,7 @@ using float32_t = float;
2021
// Configuration Macros
2122
// ============================================================================
2223
#ifndef NUM_THREADS
23-
#define NUM_THREADS 16 // Should be 32 for paper accuracy
24+
#define NUM_THREADS 8 // Should be 32 for paper accuracy
2425
#endif
2526

2627
#ifndef XLENB
@@ -372,7 +373,7 @@ class WMMA {
372373

373374
FragA fragA_compressed_; // Compressed matrix A (50% storage)
374375
FragA_meta fragA_meta_; // Metadata: 1 = non-zero, 0 = pruned
375-
vector_t<uint32_t, 8> packed_bit_meta_; // Packed bitmap metadata
376+
vector_t<uint32_t, NRA> packed_bit_meta_; // Packed bitmap metadata, 32 * 32bit Reg in RISC-V
376377
#endif
377378

378379
FragD fragRef_;
@@ -441,7 +442,7 @@ class WMMA {
441442

442443
// Pack metadata for this chunk
443444
for (uint32_t r_i = 0; r_i < ROWS_PER_CHUNK; ++r_i) {
444-
for (uint32_t c_i = 0; c_i < ELEMENTS_PER_ROb_W; ++c_i) {
445+
for (uint32_t c_i = 0; c_i < ELEMENTS_PER_ROW; ++c_i) {
445446
uint32_t row = r_i + chunk * ROWS_PER_CHUNK + m * tcM;
446447
uint32_t col = c_i + k * ELEMENTS_PER_ROW;
447448

@@ -470,7 +471,7 @@ class WMMA {
470471

471472
// Gather B column elements based on A's sparsity pattern
472473
void gather_sparse_B_column(
473-
uint8_t *b_collected,
474+
It *b_collected,
474475
const Xt *b_col_0,
475476
const Xt *b_col_1,
476477
uint16_t a_row_meta) const {
@@ -485,7 +486,7 @@ class WMMA {
485486
uint32_t element_idx = bit_idx / i_ratio;
486487
uint32_t byte_pos = (bit_idx % i_ratio) * 8;
487488
b_collected[collect_idx++] =
488-
static_cast<uint8_t>((b_col_0[element_idx] >> byte_pos) & 0xFF);
489+
static_cast<It>((b_col_0[element_idx] >> byte_pos) & 0xFF);
489490
}
490491
}
491492

@@ -694,7 +695,7 @@ class WMMA {
694695
uint32_t b_off = (n % b_sub_blocks) * b_block_size;
695696

696697
Vreg vd;
697-
uint8_t b_col_collected[tcK * i_ratio];
698+
It b_col_collected[tcK * i_ratio];
698699

699700
for (uint32_t i = 0; i < tcM; ++i) {
700701
for (uint32_t j = 0; j < tcN; ++j) {

0 commit comments

Comments
 (0)