Skip to content

Commit 6d21fa1

Browse files
authored
[Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136)
1 parent b35be54 commit 6d21fa1

File tree

1 file changed

+8
-4
lines changed
  • csrc/quantization/marlin/sparse/common

1 file changed

+8
-4
lines changed

csrc/quantization/marlin/sparse/common/mma.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
3232
float* c = reinterpret_cast<float*>(&frag_c);
3333
if (psel == 0) {
3434
asm volatile(
35-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
35+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
36+
"f32 "
3637
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
3738
"{%12,%13,%14,%15}, %16, 0x0;\n"
3839
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
3940
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
4041
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
4142
"r"(e[0]));
4243
asm volatile(
43-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
44+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
45+
"f32 "
4446
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
4547
"{%12,%13,%14,%15}, %16, 0x0;\n"
4648
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
@@ -49,15 +51,17 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
4951
"r"(e[0]));
5052
} else {
5153
asm volatile(
52-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
54+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
55+
"f32 "
5356
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
5457
"{%12,%13,%14,%15}, %16, 0x1;\n"
5558
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
5659
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
5760
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
5861
"r"(e[0]));
5962
asm volatile(
60-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
63+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
64+
"f32 "
6165
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
6266
"{%12,%13,%14,%15}, %16, 0x1;\n"
6367
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])

0 commit comments

Comments
 (0)