Skip to content

Commit fda5ccf

Browse files
Added support for loading sparse matrix (#301)
1 parent a9cd0bd commit fda5ccf

File tree

16 files changed

+1038
-125
lines changed

16 files changed

+1038
-125
lines changed

kernel/include/vx_intrinsics.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,70 @@ inline __attribute__((const)) int vx_shfl_idx(size_t value, int bval, int cval,
281281
return ret;
282282
}
283283

284+
// -----------------------------------------------------------------------------
285+
// VEGETA tile memory operations (Load/Store)
286+
// -----------------------------------------------------------------------------
287+
288+
// TILE LOAD T: Load 1KB from ptr[TILE] to tile register index 'dst_treg'.
// Encoded as a custom I-type instruction (opcode RISCV_CUSTOM1, funct3 = 0):
//   rd  = dst_treg (tile register index, emitted as x<dst_treg>)
//   rs1 = general-purpose register holding the value of 'src_gpr'
//   imm = ptr_imm  (I-type immediate field)
// 'dst_treg' and 'ptr_imm' are bound with the "i" constraint, i.e. they are
// encoded into the instruction word and must be compile-time constants at the
// call site. The function is force-inlined so the compiler never needs an
// out-of-line copy, in which the parameters would no longer be constants and
// the constraint could not be satisfied.
inline __attribute__((always_inline)) void vx_lt(int dst_treg, int src_gpr, size_t ptr_imm) {
  __asm__ volatile (".insn i %0, 0, x%1, %2, %3"
    :: "i"(RISCV_CUSTOM1), "i"(dst_treg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}
294+
295+
// TILE LOAD U: Load 1KB from ptr[TILE] to ureg index 'dst_ureg'.
// Custom I-type instruction (opcode RISCV_CUSTOM1, funct3 = 1):
//   rd  = dst_ureg (register index, emitted as x<dst_ureg>)
//   rs1 = general-purpose register holding the value of 'src_gpr'
//   imm = ptr_imm  (I-type immediate field)
// 'dst_ureg' and 'ptr_imm' use the "i" constraint and therefore must be
// compile-time constants at the call site; force-inlining guarantees no
// out-of-line copy (where they would be runtime values) is ever required.
inline __attribute__((always_inline)) void vx_lu(int dst_ureg, int src_gpr, size_t ptr_imm) {
  __asm__ volatile (".insn i %0, 1, x%1, %2, %3"
    :: "i"(RISCV_CUSTOM1), "i"(dst_ureg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}
300+
301+
// TILE LOAD V: Load 1KB from ptr[TILE] to vreg index 'dst_vreg'.
// Custom I-type instruction (opcode RISCV_CUSTOM1, funct3 = 2):
//   rd  = dst_vreg (register index, emitted as x<dst_vreg>)
//   rs1 = general-purpose register holding the value of 'src_gpr'
//   imm = ptr_imm  (I-type immediate field)
// 'dst_vreg' and 'ptr_imm' use the "i" constraint and therefore must be
// compile-time constants at the call site; force-inlining guarantees no
// out-of-line copy (where they would be runtime values) is ever required.
inline __attribute__((always_inline)) void vx_lv(int dst_vreg, int src_gpr, size_t ptr_imm) {
  __asm__ volatile (".insn i %0, 2, x%1, %2, %3"
    :: "i"(RISCV_CUSTOM1), "i"(dst_vreg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}
306+
307+
// TILE LOAD M: Load 1KB from ptr[TILE] to mreg index 'dst_mreg'.
// Custom I-type instruction (opcode RISCV_CUSTOM1, funct3 = 3):
//   rd  = dst_mreg (register index, emitted as x<dst_mreg>)
//   rs1 = general-purpose register holding the value of 'src_gpr'
//   imm = ptr_imm  (I-type immediate field)
// 'dst_mreg' and 'ptr_imm' use the "i" constraint and therefore must be
// compile-time constants at the call site; force-inlining guarantees no
// out-of-line copy (where they would be runtime values) is ever required.
inline __attribute__((always_inline)) void vx_lm(int dst_mreg, int src_gpr, size_t ptr_imm) {
  __asm__ volatile (".insn i %0, 3, x%1, %2, %3"
    :: "i"(RISCV_CUSTOM1), "i"(dst_mreg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}
312+
313+
// TILE STORE T: Store 1KB from treg index 'src_treg' to ptr[TILE].
// Custom S-type instruction (opcode RISCV_CUSTOM2, funct3 = 0):
//   rs1 = general-purpose register holding the value of 'src_gpr'
//   rs2 = src_treg (tile register index, emitted as x<src_treg>)
//   imm = ptr_imm  (S-type immediate field)
// The assembler's S-type syntax is ".insn s opcode, func3, rs2, simm12(rs1)",
// so the operands are written as "x<src_treg>, <ptr_imm>(<src_gpr reg>)";
// a comma-separated "rs1, rs2, imm" operand list is not an accepted S-type form.
// 'src_treg' and 'ptr_imm' use the "i" constraint and therefore must be
// compile-time constants at the call site; force-inlining guarantees no
// out-of-line copy (where they would be runtime values) is ever required.
inline __attribute__((always_inline)) void vx_st(int src_gpr, size_t ptr_imm, int src_treg) {
  __asm__ volatile (".insn s %0, 0, x%2, %3(%1)"
    :: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(src_treg), "i"(ptr_imm) : "memory");
}
319+
320+
// -----------------------------------------------------------------------------
321+
// VEGETA tile compute (GEMM variants)
322+
// -----------------------------------------------------------------------------
323+
324+
// TGEMM: Multiply dense tile src1 with dense tile src2, accumulate into dst
325+
inline void vx_tgemm(int dst_treg, int src1_treg, int src2_treg) {
326+
__asm__ volatile (".insn r %0, 0, 0, x%1, x%2, x%3"
327+
:: "i"(RISCV_CUSTOM3), "i"(dst_treg), "i"(src1_treg), "i"(src2_treg));
328+
}
329+
330+
// UGEMM: Multiply sparse (2:4) tile src1 with dense tile src2, accumulate into dst
331+
inline void vx_ugemm(int dst_treg, int src1_treg, int src2_ureg) {
332+
__asm__ volatile (".insn r %0, 0, 1, x%1, x%2, x%3"
333+
:: "i"(RISCV_CUSTOM3), "i"(dst_treg), "i"(src1_treg), "i"(src2_ureg));
334+
}
335+
336+
// VGEMM: Multiply sparse (1:4) tile src1 with dense tile src2, accumulate into dst
337+
inline void vx_vgemm(int dst_treg, int src1_treg, int src2_vreg) {
338+
__asm__ volatile (".insn r %0, 0, 2, x%1, x%2, x%3"
339+
:: "i"(RISCV_CUSTOM3), "i"(dst_treg), "i"(src1_treg), "i"(src2_vreg));
340+
}
341+
342+
// RGEMM: Multiply sparse (row-wise N:4) tile src1 with dense tile src2, accumulate into dst
343+
inline void vx_rgemm(int dst_ureg, int src1_treg, int src2_ureg) {
344+
__asm__ volatile (".insn r %0, 0, 3, x%1, x%2, x%3"
345+
:: "i"(RISCV_CUSTOM3), "i"(dst_ureg), "i"(src1_treg), "i"(src2_ureg));
346+
}
347+
284348
#ifdef __cplusplus
285349
}
286350
#endif

kernel/include/vx_sparse.h

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ struct wmma_context {
138138
static constexpr frag_use_t Use = U;
139139
static constexpr uint32_t NR = N;
140140
std::array<vreg_t, N> data;
141+
using metadata_array_t = std::conditional_t<U == matrix_a, std::array<uint32_t, N>, std::array<uint32_t, 0>>;
142+
metadata_array_t metadata{};
141143
};
142144

143145
public:
@@ -175,7 +177,7 @@ struct wmma_context {
175177
}
176178

177179
template <mem_layout src_layout = row_major, typename Frag>
178-
static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) {
180+
static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm, const void *meta_src = nullptr) {
179181
uint32_t lane = vx_thread_id();
180182
if constexpr (Frag::Use == matrix_a) {
181183
// Load row-major matrix A
@@ -188,24 +190,46 @@ struct wmma_context {
188190
if constexpr (src_layout == col_major) {
189191
std::swap(block_row, block_col);
190192
}
191-
auto base = reinterpret_cast<const input_t*>(src) + block_row * ldm + block_col;
193+
// For sparse format: when meta_src is provided, data stride is K/2 (not K)
194+
// because each row has K/2 values (2 per block of 4)
195+
size_t data_ldm = (meta_src != nullptr) ? (ldm / 2) : ldm;
196+
auto base = reinterpret_cast<const input_t*>(src) + block_row * data_ldm + block_col;
197+
const uint8_t* meta_base = meta_src ? reinterpret_cast<const uint8_t*>(meta_src) : nullptr;
198+
uint32_t meta_ldm = meta_src ? (ldm / 4) : 0; // Number of metadata bytes per row (K/4 blocks)
199+
192200
detail::unroll_for<Frag::NR>([&](auto r) {
193201
uint32_t block_m = r / cfg::k_steps;
194202
uint32_t block_k = r % cfg::k_steps;
195203
uint32_t elem_row = block_m * m_stride;
196204
uint32_t elem_col = block_k * k_stride;
205+
uint32_t meta_value = 0;
206+
207+
if (meta_base) {
208+
uint32_t matrix_row = block_row + elem_row;
209+
uint32_t k_elem_idx = elem_col / i_ratio;
210+
uint32_t meta_block_k = k_elem_idx / 4;
211+
if (meta_block_k < meta_ldm) {
212+
uint32_t meta_offset = matrix_row * meta_ldm + meta_block_k;
213+
meta_value = static_cast<uint32_t>(meta_base[meta_offset]);
214+
}
215+
}
216+
217+
if constexpr (Frag::Use == matrix_a) {
218+
dst.metadata[r] = meta_value;
219+
}
197220
if constexpr (src_layout == col_major) {
198221
static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a");
199222
std::swap(elem_row, elem_col);
200-
auto ptr = base + elem_row * ldm + elem_col;
223+
auto ptr = base + elem_row * data_ldm + elem_col;
201224
if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) {
202225
dst.data[r] = *reinterpret_cast<const vreg_t*>(ptr);
203226
} else {
204-
dst.data[r] = input_acessor_t::pack_row(ptr, ldm);
227+
dst.data[r] = input_acessor_t::pack_row(ptr, data_ldm);
205228
}
206229
} else {
207-
// raw_major layout
208-
auto ptr = base + elem_row * ldm + elem_col;
230+
// row_major layout
231+
// For sparse format, use data_ldm (K/2) instead of ldm (K)
232+
auto ptr = base + elem_row * data_ldm + elem_col;
209233
assert(reinterpret_cast<uintptr_t>(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes");
210234
dst.data[r] = *reinterpret_cast<const vreg_t *>(ptr);
211235
}
@@ -310,6 +334,24 @@ struct wmma_context {
310334
static_assert(FragC::Use == accumulator, "C must be accumulator");
311335
static_assert(FragD::Use == accumulator, "D must be accumulator");
312336

337+
auto meta_value = [&](uint32_t idx) -> uint32_t {
338+
if constexpr (FragA::Use == matrix_a) {
339+
if (idx < FragA::NR) {
340+
return fragA.metadata[idx];
341+
}
342+
}
343+
return 0u;
344+
};
345+
346+
register uint32_t ma0 __asm__("a0") = meta_value(0);
347+
register uint32_t ma1 __asm__("a1") = meta_value(1);
348+
register uint32_t ma2 __asm__("a2") = meta_value(2);
349+
register uint32_t ma3 __asm__("a3") = meta_value(3);
350+
register uint32_t ma4 __asm__("a4") = meta_value(4);
351+
register uint32_t ma5 __asm__("a5") = meta_value(5);
352+
register uint32_t ma6 __asm__("a6") = meta_value(6);
353+
register uint32_t ma7 __asm__("a7") = meta_value(7);
354+
313355
// fragA: caller-saved registers (f0-f7)
314356
register float fa0 __asm__("f0") = fragA.data[0];
315357
register float fa1 __asm__("f1") = fragA.data[1];
@@ -348,15 +390,16 @@ struct wmma_context {
348390
register float fd3 __asm__("f27");
349391
register float fd4 __asm__("f28");
350392
register float fd5 __asm__("f29");
351-
register float fd6 __asm__("f30");
393+
register float fd6 __asm__("f30");
352394
register float fd7 __asm__("f31");
353395

354-
__asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0"
396+
__asm__ volatile (".insn r %[insn], 0, 3, x%[fmd], x%[fms], x0"
355397
: "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7)
356398
: [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id),
357399
"f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7),
358400
"f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), "f"(fb7),
359-
"f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7)
401+
"f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7),
402+
"r"(ma0), "r"(ma1), "r"(ma2), "r"(ma3), "r"(ma4), "r"(ma5), "r"(ma6), "r"(ma7)
360403
);
361404

362405
// Write results to fragD
@@ -389,12 +432,13 @@ struct wmma_context {
389432
register float fd6 __asm__("f16");
390433
register float fd7 __asm__("f17");
391434

392-
__asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0"
435+
__asm__ volatile (".insn r %[insn], 0, 3, x%[fmd], x%[fms], x0"
393436
: "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7)
394437
: [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id),
395438
"f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7),
396439
"f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3),
397-
"f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7)
440+
"f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7),
441+
"r"(ma0), "r"(ma1), "r"(ma2), "r"(ma3), "r"(ma4), "r"(ma5), "r"(ma6), "r"(ma7)
398442
);
399443

400444
// Write results to fragD

sim/simx/core.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@
2323
#include "debug.h"
2424
#include "constants.h"
2525

26+
// Fallback sizing for the VEGETA dispatcher, active only when
// EXT_VEGETA_ENABLE is defined. Each macro is left untouched if the build
// already defines it; otherwise:
//   NUM_VEGETA_BLOCKS -> NUM_TCU_BLOCKS if defined, else ISSUE_WIDTH
//   NUM_VEGETA_LANES  -> NUM_TCU_LANES  if defined, else NUM_THREADS
#ifdef EXT_VEGETA_ENABLE
27+
#ifndef NUM_VEGETA_BLOCKS
28+
#ifdef NUM_TCU_BLOCKS
29+
#define NUM_VEGETA_BLOCKS NUM_TCU_BLOCKS
30+
#else
31+
#define NUM_VEGETA_BLOCKS ISSUE_WIDTH
32+
#endif
33+
#endif
34+
#ifndef NUM_VEGETA_LANES
35+
#ifdef NUM_TCU_LANES
36+
#define NUM_VEGETA_LANES NUM_TCU_LANES
37+
#else
38+
#define NUM_VEGETA_LANES NUM_THREADS
39+
#endif
40+
#endif
41+
#endif
42+
2643
using namespace vortex;
2744

2845
Core::Core(const SimContext& ctx,
@@ -44,6 +61,9 @@ Core::Core(const SimContext& ctx,
4461
#endif
4562
#ifdef EXT_V_ENABLE
4663
, vec_unit_(VecUnit::Create("vpu", arch, this))
64+
#endif
65+
#ifdef EXT_VEGETA_ENABLE
66+
, sparse_unit_(SparseUnit::Create("spu", arch, this))
4767
#endif
4868
, emulator_(arch, dcrs, this)
4969
, ibuffers_(arch.num_warps(), IBUF_SIZE)
@@ -133,7 +153,7 @@ Core::Core(const SimContext& ctx,
133153
dcache_rsp_ports.at(p).bind(&lsu_dcache_adapter.at(b)->RspOut.at(c));
134154
}
135155
}
136-
156+
137157
// initialize dispatchers
138158
dispatchers_.at((int)FUType::ALU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_ALU_BLOCKS, NUM_ALU_LANES);
139159
dispatchers_.at((int)FUType::FPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_FPU_BLOCKS, NUM_FPU_LANES);
@@ -145,6 +165,9 @@ Core::Core(const SimContext& ctx,
145165
#ifdef EXT_TCU_ENABLE
146166
dispatchers_.at((int)FUType::TCU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_TCU_BLOCKS, NUM_TCU_LANES);
147167
#endif
168+
#ifdef EXT_VEGETA_ENABLE
169+
dispatchers_.at((int)FUType::VEGETA) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_VEGETA_BLOCKS, NUM_VEGETA_LANES);
170+
#endif
148171

149172
// initialize execute units
150173
func_units_.at((int)FUType::ALU) = SimPlatform::instance().create_object<AluUnit>(this);
@@ -157,7 +180,9 @@ Core::Core(const SimContext& ctx,
157180
#ifdef EXT_TCU_ENABLE
158181
func_units_.at((int)FUType::TCU) = SimPlatform::instance().create_object<TcuUnit>(this);
159182
#endif
160-
183+
#ifdef EXT_VEGETA_ENABLE
184+
func_units_.at((int)FUType::VEGETA) = SimPlatform::instance().create_object<VegetaUnit>(this);
185+
#endif
161186
// bind commit arbiters
162187
for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) {
163188
snprintf(sname, 100, "%s-commit-arb%d", this->name().c_str(), iw);
@@ -223,7 +248,6 @@ void Core::schedule() {
223248

224249
// suspend warp until decode
225250
emulator_.suspend(trace->wid);
226-
227251
DT(3, "pipeline-schedule: " << *trace);
228252

229253
// advance to fetch stage

sim/simx/core.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
#else
3030
#include "operands.h"
3131
#endif
32+
#ifdef EXT_VEGETA_ENABLE
33+
#include "sparse_unit.h"
34+
#endif
3235

3336
#include "dispatcher.h"
3437
#include "func_unit.h"
@@ -171,6 +174,12 @@ class Core : public SimObject<Core> {
171174
}
172175
#endif
173176

177+
#ifdef EXT_VEGETA_ENABLE
178+
// Returns a mutable reference to this core's SparseUnit handle (sparse_unit_).
// Declared inside an EXT_VEGETA_ENABLE guard, so it exists only in builds
// that define that macro.
SparseUnit::Ptr& sparse_unit() {
179+
return sparse_unit_;
180+
}
181+
#endif
182+
174183
auto& trace_pool() {
175184
return trace_pool_;
176185
}
@@ -200,6 +209,10 @@ class Core : public SimObject<Core> {
200209
VecUnit::Ptr vec_unit_;
201210
#endif
202211

212+
#ifdef EXT_VEGETA_ENABLE
213+
SparseUnit::Ptr sparse_unit_;
214+
#endif
215+
203216
Emulator emulator_;
204217

205218
std::vector<IBuffer> ibuffers_;

0 commit comments

Comments
 (0)