Skip to content

Commit 7c71d86

Browse files
committed
Merge branch 'bug_fixes' of https://github.com/vortexgpgpu/vortex into bug_fixes
2 parents 6255078 + 8c19f62 commit 7c71d86

File tree

12 files changed

+548
-33
lines changed

12 files changed

+548
-33
lines changed

ci/regression.sh.in

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,9 @@ tensor()
435435
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=uint4 -DOTYPE=int32" make -C tests/regression/sgemm_tcu
436436
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
437437

438+
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=mxint8 -DOTYPE=int32" make -C tests/regression/sgemm_tcu
439+
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
440+
438441
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=8 -DITYPE=fp16 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
439442
CONFIGS="-DNUM_THREADS=8 -DEXT_TCU_ENABLE -DISSUE_WIDTH=2" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
440443

@@ -447,13 +450,22 @@ tensor()
447450
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=16 -DITYPE=bf8 -DOTYPE=bf8" make -C tests/regression/sgemm_tcu
448451
CONFIGS="-DNUM_THREADS=16 -DEXT_TCU_ENABLE" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
449452

453+
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=mxfp8 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
454+
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
455+
456+
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=nvfp4 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
457+
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE" ./ci/blackbox.sh --driver=simx --app=sgemm_tcu
458+
450459
# rtlsim tests
451460
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=int8 -DOTYPE=int32" make -C tests/regression/sgemm_tcu
452461
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
453462

454463
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=uint4 -DOTYPE=int32" make -C tests/regression/sgemm_tcu
455464
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
456465

466+
#make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=mxint8 -DOTYPE=int32" make -C tests/regression/sgemm_tcu
467+
#CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
468+
457469
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=fp16 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
458470
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
459471

@@ -466,6 +478,12 @@ tensor()
466478
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=bf8 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
467479
CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
468480

481+
#make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=mxfp8 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
482+
#CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
483+
484+
#make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=4 -DITYPE=nvfp4 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
485+
#CONFIGS="-DNUM_THREADS=4 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
486+
469487
make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=8 -DITYPE=fp16 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
470488
CONFIGS="-DNUM_THREADS=8 -DEXT_TCU_ENABLE -DTCU_TYPE_DPI -DISSUE_WIDTH=2" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
471489

@@ -476,7 +494,7 @@ tensor()
476494
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" NUM_REGS=1 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=1 --no-fused
477495
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_BHF" NUM_REGS=1 LATENCY=10 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=1 --no-fused
478496
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DSP" NUM_REGS=1 LATENCY=31 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=1 --no-fused --ulp=3
479-
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL -DUSE_FEDP" NUM_REGS=2 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=1 --ulp=2
497+
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL -DUSE_FEDP" NUM_REGS=2 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=1 --no-zeros --no-subnormals --no-infinities --no-nans
480498

481499
# test bf16
482500
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" NUM_REGS=1 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=2 --no-fused
@@ -494,6 +512,16 @@ tensor()
494512
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_BHF" NUM_REGS=1 LATENCY=10 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=4 --no-fused --ulp=4
495513
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL -DUSE_FEDP" NUM_REGS=2 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=4 --no-zeros --no-subnormals --no-infinities --no-nans
496514

515+
# test mxfp8
516+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" NUM_REGS=1 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=5 --no-fused
517+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_BHF" NUM_REGS=1 LATENCY=10 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=5 --no-fused
518+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL -DUSE_FEDP" NUM_REGS=2 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=5
519+
520+
# test nvfp4
521+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" NUM_REGS=1 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=7 --no-fused
522+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_BHF" NUM_REGS=1 LATENCY=10 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=7 --no-fused
523+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL -DUSE_FEDP" NUM_REGS=2 LATENCY=4 make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=7
524+
497525
# test int8
498526
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=9
499527
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=9
@@ -510,6 +538,10 @@ tensor()
510538
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=12
511539
make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=12
512540

541+
# test mxint8
542+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DPI" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=13
543+
#make -C hw/unittest/tcu_fedp clean && CONFIGS="-DTCU_TYPE_DRL" make -C hw/unittest/tcu_fedp && hw/unittest/tcu_fedp/tcu_fedp --fmt=13
544+
513545
echo "tensor tests done!"
514546
}
515547

hw/rtl/tcu/drl/VX_tcu_fedp_drl.sv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
module VX_tcu_fedp_drl #(
1717
parameter LATENCY = 0,
1818
parameter N = 2,
19-
parameter W = 53
19+
parameter W = 25
2020
) (
2121
input wire clk,
2222
input wire reset,

hw/unittest/tcu_fedp/fedp.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <unordered_map>
2323
#include <vector>
2424

25-
#if FEDP_TRACE
25+
#ifdef FEDP_TRACE
2626
#include <cstdio>
2727
#define LOG(...) std::fprintf(stderr, __VA_ARGS__);
2828
#else

hw/unittest/tcu_fedp/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
#include <bitmanip.h>
4444
#include "softfloat_ext.h"
4545

46-
#ifdef FEDP_EMUL
46+
#ifdef USE_FEDP
4747
#include "fedp.h"
4848
#endif
4949

@@ -280,7 +280,7 @@ static void pack_elements(const std::vector<uint32_t> &elements, int element_bit
280280
}
281281
}
282282

283-
#ifndef FEDP_EMUL
283+
#ifndef USE_FEDP
284284
// Calculate expected fp dot product
285285
static float dot_product(const uint32_t* A, const uint32_t* B, uint32_t C, int n, int eb, int sb, bool fused) {
286286
auto to_float = [&](uint32_t x, int ebits, int sbits) -> long double {
@@ -679,7 +679,7 @@ class Testbench {
679679
const uint32_t NF = features_to_test.size();
680680
const uint32_t tests_per_feature = (NT + NF - 1) / NF;
681681

682-
#ifdef FEDP_EMUL
682+
#ifdef USE_FEDP
683683
FEDP fedp(config_.exp_bits, config_.sig_bits, NUM_REGS * 2, (int)config_.frm, config_.W, config_.renorm);
684684
#endif
685685

@@ -746,7 +746,7 @@ class Testbench {
746746
std::memcpy(&dut_result, &dut_result_bits, sizeof(float));
747747

748748
// Calculate expected result
749-
#ifdef FEDP_EMUL
749+
#ifdef USE_FEDP
750750
float expected = fedp(a_packed.data(), b_packed.data(), c_value_float, NUM_REGS);
751751
#else
752752
float expected = dot_product(a_value_hex.data(), b_value_hex.data(),

sim/common/rvfloats.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,48 @@ uint8_t rv_ftoe5m2_s(uint32_t a, uint32_t frm, uint32_t* fflags) {
597597
return bf8.v;
598598
}
599599

600+
uint32_t rv_mxfp8tof_s(uint8_t a, uint8_t sf, uint32_t frm, uint32_t* fflags) {
601+
rv_init(frm);
602+
mxfloat8_t mxfp8;
603+
mxfp8.v = a;
604+
mxfp8.sf = sf;
605+
float32_t f32 = mxfp8_to_f32(mxfp8);
606+
if (fflags) { *fflags = softfloat_exceptionFlags; }
607+
return f32.v;
608+
}
609+
610+
uint8_t rv_ftomxfp8_s(uint32_t a, uint8_t sf, uint32_t frm, uint32_t* fflags) {
611+
rv_init(frm);
612+
float32_t f32;
613+
f32.v = a;
614+
sfexp8_t scale_factor;
615+
scale_factor.sf = sf;
616+
mxfloat8_t mxfp8 = f32_to_mxfp8(f32, scale_factor);
617+
if (fflags) { *fflags = softfloat_exceptionFlags; }
618+
return mxfp8.v;
619+
}
620+
621+
uint32_t rv_nvfp4tof_s(uint8_t a, uint8_t sf, uint32_t frm, uint32_t* fflags) {
622+
rv_init(frm);
623+
nvfloat4_t nvfp4;
624+
nvfp4.v = a;
625+
nvfp4.sf = sf;
626+
float32_t f32 = nvfp4_to_f32(nvfp4);
627+
if (fflags) { *fflags = softfloat_exceptionFlags; }
628+
return f32.v;
629+
}
630+
631+
uint8_t rv_ftonvfp4_s(uint32_t a, uint8_t sf, uint32_t frm, uint32_t* fflags) {
632+
rv_init(frm);
633+
float32_t f32;
634+
f32.v = a;
635+
sffloat8_t scale_factor;
636+
scale_factor.sf = sf;
637+
nvfloat4_t nvfp4 = f32_to_nvfp4(f32, scale_factor);
638+
if (fflags) { *fflags = softfloat_exceptionFlags; }
639+
return nvfp4.v;
640+
}
641+
600642
#ifdef __cplusplus
601643
}
602644
#endif

sim/common/rvfloats.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ uint8_t rv_ftoe4m3_s(uint32_t a, uint32_t frm, uint32_t* fflags);
107107
uint32_t rv_e5m2tof_s(uint8_t a, uint32_t frm, uint32_t* fflags);
108108
uint8_t rv_ftoe5m2_s(uint32_t a, uint32_t frm, uint32_t* fflags);
109109

110+
// mxfp8 <--> fp32 conversions
111+
uint32_t rv_mxfp8tof_s(uint8_t a, uint8_t sf, uint32_t frm, uint32_t* fflags);
112+
uint8_t rv_ftomxfp8_s(uint32_t a, uint8_t sf, uint32_t frm, uint32_t* fflags);
113+
114+
// nvfp4 <--> fp32 conversions
115+
uint32_t rv_nvfp4tof_s(uint8_t a, uint8_t sf, uint32_t frm, uint32_t* fflags);
116+
uint8_t rv_ftonvfp4_s(uint32_t a, uint8_t sf, uint32_t frm, uint32_t* fflags);
117+
110118
#ifdef __cplusplus
111119
}
112120
#endif

sim/common/softfloat_ext.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,64 @@ bfloat8_t f32_to_f8e5m2(float32_t a) {
905905
return res;
906906
}
907907

908+
float32_t mxfp8_to_f32(mxfloat8_t a) {
909+
//convert e4m3 value to f32
910+
uint32_t fflags = 0;
911+
auto base_value = cvt_custom_to_f32(a.v, 4, 3, softfloat_roundingMode, &fflags);
912+
//convert e8m0 scale factor to f32 (bias = 127)
913+
int32_t scale_exp = (int32_t)a.sf - 127;
914+
float scale_factor = std::ldexp(1.0f, scale_exp);
915+
float out = base_value * scale_factor;
916+
softfloat_exceptionFlags |= fflags;
917+
float32_t res;
918+
res.v = vortex::bit_cast<uint32_t>(out);
919+
return res;
920+
}
921+
922+
mxfloat8_t f32_to_mxfp8(float32_t a, sfexp8_t scale_factor) {
923+
//extract e8m0 scale factor
924+
int32_t scale_exp = (int32_t)scale_factor.sf - 127;
925+
float scale = std::ldexp(1.0f, scale_exp);
926+
//divide input by scale factor
927+
float scaled_value = vortex::bit_cast<float>(a.v) / scale;
928+
//convert scaled value to e4m3
929+
uint32_t fflags = 0;
930+
auto out = cvt_f32_to_custom(scaled_value, 4, 3, softfloat_roundingMode, &fflags);
931+
softfloat_exceptionFlags |= fflags;
932+
mxfloat8_t res;
933+
res.v = out & 0xff;
934+
res.sf = scale_factor.sf;
935+
return res;
936+
}
937+
938+
float32_t nvfp4_to_f32(nvfloat4_t a) {
939+
//convert e2m1 value to f32
940+
uint32_t fflags = 0;
941+
auto base_value = cvt_custom_to_f32(a.v, 2, 1, softfloat_roundingMode, &fflags);
942+
//convert e4m3 scale factor to f32
943+
auto scale_factor = cvt_custom_to_f32(a.sf, 4, 3, softfloat_roundingMode, &fflags);
944+
float out = base_value * scale_factor;
945+
softfloat_exceptionFlags |= fflags;
946+
float32_t res;
947+
res.v = vortex::bit_cast<uint32_t>(out);
948+
return res;
949+
}
950+
951+
nvfloat4_t f32_to_nvfp4(float32_t a, sffloat8_t scale_factor) {
952+
//extract e4m3 scale factor
953+
uint32_t fflags = 0;
954+
float scale = cvt_custom_to_f32(scale_factor.sf, 4, 3, softfloat_roundingMode, &fflags);
955+
//divide input by scale factor
956+
float scaled_value = vortex::bit_cast<float>(a.v) / scale;
957+
//conver scaled value to e2m1
958+
auto out = cvt_f32_to_custom(scaled_value, 2, 1, softfloat_roundingMode, &fflags);
959+
softfloat_exceptionFlags |= fflags;
960+
nvfloat4_t res;
961+
res.v = out & 0x0f;
962+
res.sf = scale_factor.sf;
963+
return res;
964+
}
965+
908966
#ifdef __cplusplus
909967
}
910968
#endif

sim/common/softfloat_ext.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ extern "C" {
77

88
typedef struct { uint8_t v; } float8_t; // e4m3
99
typedef struct { uint8_t v; } bfloat8_t; // e5m2
10+
typedef struct { uint8_t v, sf; } mxfloat8_t; // e4m3 with e8m0 scale
11+
typedef struct { uint8_t v, sf; } nvfloat4_t; // e2m1 with e4m3 scale
12+
typedef struct { uint8_t sf; } sfexp8_t; // e8m0 scale factor
13+
typedef struct { uint8_t sf; } sffloat8_t; // e4m3 scale factor
1014

1115
uint_fast16_t f16_classify(float16_t);
1216
float16_t f16_rsqrte7(float16_t);
@@ -26,6 +30,12 @@ float32_t f8e4m3_to_f32(float8_t);
2630
bfloat8_t f32_to_f8e5m2(float32_t);
2731
float32_t f8e5m2_to_f32(bfloat8_t);
2832

33+
mxfloat8_t f32_to_mxfp8(float32_t, sfexp8_t);
34+
float32_t mxfp8_to_f32(mxfloat8_t);
35+
36+
nvfloat4_t f32_to_nvfp4(float32_t, sffloat8_t);
37+
float32_t nvfp4_to_f32(nvfloat4_t);
38+
2939
uint32_t cvt_f32_to_custom(float value, uint32_t exp_bits, uint32_t sig_bits,
3040
uint32_t frm, uint32_t *fflags);
3141

sim/common/tensor_cfg.h

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,38 @@ struct bf16 {
4343
static constexpr const char* name = "bf16";
4444
};
4545

46-
// e4m3 (use for forward pass)
4746
struct fp8 {
4847
using dtype = uint8_t;
4948
static constexpr uint32_t id = 3;
5049
static constexpr uint32_t bits = 8;
5150
static constexpr const char* name = "fp8";
5251
};
5352

54-
// e5m2 (use for backprop)
5553
struct bf8 {
5654
using dtype = uint8_t;
5755
static constexpr uint32_t id = 4;
5856
static constexpr uint32_t bits = 8;
5957
static constexpr const char* name = "bf8";
6058
};
6159

60+
struct mxfp8 {
61+
using dtype = uint8_t;
62+
static constexpr uint32_t id = 5;
63+
static constexpr uint32_t bits = 8;
64+
static constexpr uint32_t scale_bits = 8;
65+
static constexpr uint32_t ele_block = 32; //elements per block
66+
static constexpr const char* name = "mxfp8";
67+
};
68+
69+
struct nvfp4 {
70+
using dtype = uint8_t;
71+
static constexpr uint32_t id = 7;
72+
static constexpr uint32_t bits = 4;
73+
static constexpr uint32_t scale_bits = 8;
74+
static constexpr uint32_t ele_block = 16;
75+
static constexpr const char* name = "nvfp4";
76+
};
77+
6278
struct int32 {
6379
using dtype = int32_t;
6480
static constexpr uint32_t id = 8;
@@ -94,19 +110,31 @@ struct uint4 {
94110
static constexpr const char* name = "u4";
95111
};
96112

113+
struct mxint8 {
114+
using dtype = int8_t;
115+
static constexpr uint32_t id = 13;
116+
static constexpr uint32_t bits = 8;
117+
static constexpr uint32_t scale_bits = 8;
118+
static constexpr uint32_t ele_blcok = 32;
119+
static constexpr const char* name = "mxi8";
120+
};
121+
97122
inline const char* fmt_string(uint32_t fmt) {
98123
switch (fmt) {
99-
case fp32::id: return fp32::name;
100-
case fp16::id: return fp16::name;
101-
case bf16::id: return bf16::name;
102-
case fp8::id: return fp8::name;
103-
case bf8::id: return bf8::name;
104-
case int32::id: return int32::name;
105-
case int8::id: return int8::name;
106-
case uint8::id: return uint8::name;
107-
case int4::id: return int4::name;
108-
case uint4::id: return uint4::name;
109-
default: return "";
124+
case fp32::id: return fp32::name;
125+
case fp16::id: return fp16::name;
126+
case bf16::id: return bf16::name;
127+
case fp8::id: return fp8::name;
128+
case bf8::id: return bf8::name;
129+
case mxfp8::id: return mxfp8::name;
130+
case nvfp4::id: return nvfp4::name;
131+
case int32::id: return int32::name;
132+
case int8::id: return int8::name;
133+
case uint8::id: return uint8::name;
134+
case int4::id: return int4::name;
135+
case uint4::id: return uint4::name;
136+
case mxint8::id: return mxint8::name;
137+
default: return "";
110138
}
111139
}
112140

0 commit comments

Comments
 (0)