@@ -138,6 +138,8 @@ struct wmma_context {
138138 static constexpr frag_use_t Use = U;
139139 static constexpr uint32_t NR = N;
140140 std::array<vreg_t , N> data;
141+ using metadata_array_t = std::conditional_t <U == matrix_a, std::array<uint32_t , N>, std::array<uint32_t , 0 >>;
142+ metadata_array_t metadata{};
141143 };
142144
143145public:
@@ -175,7 +177,7 @@ struct wmma_context {
175177 }
176178
177179 template <mem_layout src_layout = row_major, typename Frag>
178- static __attribute__ ((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) {
180+ static __attribute__ ((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm, const void *meta_src = nullptr ) {
179181 uint32_t lane = vx_thread_id ();
180182 if constexpr (Frag::Use == matrix_a) {
181183 // Load row-major matrix A
@@ -188,24 +190,46 @@ struct wmma_context {
188190 if constexpr (src_layout == col_major) {
189191 std::swap (block_row, block_col);
190192 }
191- auto base = reinterpret_cast <const input_t *>(src) + block_row * ldm + block_col;
193+ // For sparse format: when meta_src is provided, data stride is K/2 (not K)
194+ // because each row has K/2 values (2 per block of 4)
195+ size_t data_ldm = (meta_src != nullptr ) ? (ldm / 2 ) : ldm;
196+ auto base = reinterpret_cast <const input_t *>(src) + block_row * data_ldm + block_col;
197+ const uint8_t * meta_base = meta_src ? reinterpret_cast <const uint8_t *>(meta_src) : nullptr ;
198+ uint32_t meta_ldm = meta_src ? (ldm / 4 ) : 0 ; // Number of metadata bytes per row (K/4 blocks)
199+
192200 detail::unroll_for<Frag::NR>([&](auto r) {
193201 uint32_t block_m = r / cfg::k_steps;
194202 uint32_t block_k = r % cfg::k_steps;
195203 uint32_t elem_row = block_m * m_stride;
196204 uint32_t elem_col = block_k * k_stride;
205+ uint32_t meta_value = 0 ;
206+
207+ if (meta_base) {
208+ uint32_t matrix_row = block_row + elem_row;
209+ uint32_t k_elem_idx = elem_col / i_ratio;
210+ uint32_t meta_block_k = k_elem_idx / 4 ;
211+ if (meta_block_k < meta_ldm) {
212+ uint32_t meta_offset = matrix_row * meta_ldm + meta_block_k;
213+ meta_value = static_cast <uint32_t >(meta_base[meta_offset]);
214+ }
215+ }
216+
217+ if constexpr (Frag::Use == matrix_a) {
218+ dst.metadata [r] = meta_value;
219+ }
197220 if constexpr (src_layout == col_major) {
198221 static_assert (input_is_subbyte == false , " col_major layout is not supported for sub-byte matrix_a" );
199222 std::swap (elem_row, elem_col);
200- auto ptr = base + elem_row * ldm + elem_col;
223+ auto ptr = base + elem_row * data_ldm + elem_col;
201224 if constexpr (sizeof (vreg_t ) == sizeof (input_t ) && !input_is_subbyte) {
202225 dst.data [r] = *reinterpret_cast <const vreg_t *>(ptr);
203226 } else {
204- dst.data [r] = input_acessor_t::pack_row (ptr, ldm );
227+ dst.data [r] = input_acessor_t::pack_row (ptr, data_ldm );
205228 }
206229 } else {
207- // raw_major layout
208- auto ptr = base + elem_row * ldm + elem_col;
230+ // row_major layout
231+ // For sparse format, use data_ldm (K/2) instead of ldm (K)
232+ auto ptr = base + elem_row * data_ldm + elem_col;
209233 assert (reinterpret_cast <uintptr_t >(ptr) % alignof (vreg_t ) == 0 && " pointer must be aligned to 4 bytes" );
210234 dst.data [r] = *reinterpret_cast <const vreg_t *>(ptr);
211235 }
@@ -310,6 +334,24 @@ struct wmma_context {
310334 static_assert (FragC::Use == accumulator, " C must be accumulator" );
311335 static_assert (FragD::Use == accumulator, " D must be accumulator" );
312336
337+ auto meta_value = [&](uint32_t idx) -> uint32_t {
338+ if constexpr (FragA::Use == matrix_a) {
339+ if (idx < FragA::NR) {
340+ return fragA.metadata [idx];
341+ }
342+ }
343+ return 0u ;
344+ };
345+
346+ register uint32_t ma0 __asm__ (" a0" ) = meta_value (0 );
347+ register uint32_t ma1 __asm__ (" a1" ) = meta_value (1 );
348+ register uint32_t ma2 __asm__ (" a2" ) = meta_value (2 );
349+ register uint32_t ma3 __asm__ (" a3" ) = meta_value (3 );
350+ register uint32_t ma4 __asm__ (" a4" ) = meta_value (4 );
351+ register uint32_t ma5 __asm__ (" a5" ) = meta_value (5 );
352+ register uint32_t ma6 __asm__ (" a6" ) = meta_value (6 );
353+ register uint32_t ma7 __asm__ (" a7" ) = meta_value (7 );
354+
313355 // fragA: caller-saved registers (f0-f7)
314356 register float fa0 __asm__ (" f0" ) = fragA.data [0 ];
315357 register float fa1 __asm__ (" f1" ) = fragA.data [1 ];
@@ -348,15 +390,16 @@ struct wmma_context {
348390 register float fd3 __asm__ (" f27" );
349391 register float fd4 __asm__ (" f28" );
350392 register float fd5 __asm__ (" f29" );
351- register float fd6 __asm__ (" f30" );
393+ register float fd6 __asm__ (" f30" );
352394 register float fd7 __asm__ (" f31" );
353395
354- __asm__ volatile (" .insn r %[insn], 0, 2 , x%[fmd], x%[fms], x0"
396+ __asm__ volatile (" .insn r %[insn], 0, 3 , x%[fmd], x%[fms], x0"
355397 : " =f" (fd0), " =f" (fd1), " =f" (fd2), " =f" (fd3), " =f" (fd4), " =f" (fd5), " =f" (fd6), " =f" (fd7)
356398 : [insn]" i" (RISCV_CUSTOM0), [fmd]" i" (Ot::id), [fms]" i" (It::id),
357399 " f" (fa0), " f" (fa1), " f" (fa2), " f" (fa3), " f" (fa4), " f" (fa5), " f" (fa6), " f" (fa7),
358400 " f" (fb0), " f" (fb1), " f" (fb2), " f" (fb3), " f" (fb4), " f" (fb5), " f" (fb6), " f" (fb7),
359- " f" (fc0), " f" (fc1), " f" (fc2), " f" (fc3), " f" (fc4), " f" (fc5), " f" (fc6), " f" (fc7)
401+ " f" (fc0), " f" (fc1), " f" (fc2), " f" (fc3), " f" (fc4), " f" (fc5), " f" (fc6), " f" (fc7),
402+ " r" (ma0), " r" (ma1), " r" (ma2), " r" (ma3), " r" (ma4), " r" (ma5), " r" (ma6), " r" (ma7)
360403 );
361404
362405 // Write results to fragD
@@ -389,12 +432,13 @@ struct wmma_context {
389432 register float fd6 __asm__ (" f16" );
390433 register float fd7 __asm__ (" f17" );
391434
392- __asm__ volatile (" .insn r %[insn], 0, 2 , x%[fmd], x%[fms], x0"
435+ __asm__ volatile (" .insn r %[insn], 0, 3 , x%[fmd], x%[fms], x0"
393436 : " =f" (fd0), " =f" (fd1), " =f" (fd2), " =f" (fd3), " =f" (fd4), " =f" (fd5), " =f" (fd6), " =f" (fd7)
394437 : [insn]" i" (RISCV_CUSTOM0), [fmd]" i" (Ot::id), [fms]" i" (It::id),
395438 " f" (fa0), " f" (fa1), " f" (fa2), " f" (fa3), " f" (fa4), " f" (fa5), " f" (fa6), " f" (fa7),
396439 " f" (fb0), " f" (fb1), " f" (fb2), " f" (fb3),
397- " f" (fc0), " f" (fc1), " f" (fc2), " f" (fc3), " f" (fc4), " f" (fc5), " f" (fc6), " f" (fc7)
440+ " f" (fc0), " f" (fc1), " f" (fc2), " f" (fc3), " f" (fc4), " f" (fc5), " f" (fc6), " f" (fc7),
441+ " r" (ma0), " r" (ma1), " r" (ma2), " r" (ma3), " r" (ma4), " r" (ma5), " r" (ma6), " r" (ma7)
398442 );
399443
400444 // Write results to fragD
0 commit comments