@@ -29,7 +29,7 @@ __pack_half2(const half x, const half y) {
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
   assert(false);
 #else
   static constexpr uint32_t ZERO = 0x0;
@@ -191,6 +191,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
       }
     }
     for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+      {
+        __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+      }
+
+      {
+        __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+      }
+
+      {
+        __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+      }
+
+      {
+        __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+      }
+#else
       {
         __asm__ __volatile__(
             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
@@ -206,6 +239,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
             : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
             : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
       }
+
+#endif
     }
   }
 }
@@ -226,7 +261,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
   assert(false);
 #else
   static constexpr uint32_t ZERO = 0x0;
@@ -392,7 +427,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 
   for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
   {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+    {
+      __asm__ __volatile__(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+          : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+    }
 
+    {
+      __asm__ __volatile__(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+          : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+
+    {
+      __asm__ __volatile__(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+          : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+    }
+
+    {
+      __asm__ __volatile__(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+          : "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+#else
     {
       __asm__ __volatile__(
           "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
@@ -408,6 +475,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
           : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
           : "r"(((unsigned*)(A_shared_warp + 0))[0]), "r"(((unsigned*)(A_shared_warp + 0))[1]), "r"(((unsigned*)(A_shared_warp + 0))[2]), "r"(((unsigned*)(A_shared_warp + 0))[3]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
     }
+#endif
   }
  }
 }
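
Why the guard moves: with f16 operands, mma.sync.aligned.m16n8k16 requires sm_80 (Ampere), while mma.sync.aligned.m16n8k8 is available from sm_75 (Turing). So the arch check drops from __CUDA_ARCH__ < 800 to < 750, and a dedicated __CUDA_ARCH__ == 750 path splits each K=16 tensor-core step into two K=8 steps that accumulate into the same C fragment: the first consumes registers [0..1] of the A fragment and register [0] of the B fragment, the second consumes A [2..3] and B [1]. The following is a minimal sketch of that decomposition, not code from the patch; the helper name mma_16x8x16 is illustrative:

    // Sketch: one m16n8k16 f16 MMA with f32 accumulate, expressed per-arch.
    __device__ __forceinline__ void mma_16x8x16(float c[4],          // C/D fragment (4 f32)
                                                const unsigned a[4], // A fragment (8 halves)
                                                const unsigned b[2]) // B fragment (4 halves)
    {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
      // Turing has no k16 shape, so split K=16 into two k8 MMAs on the same accumulator.
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
          : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(b[0]));
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
          : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
          : "r"(a[2]), "r"(a[3]), "r"(b[1]));
    #else
      // Ampere and newer: a single k16 MMA covers the whole K=16 slice.
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
          : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
    #endif
    }

Because both K=8 steps land on the same f32 accumulator registers, the Turing path computes the same result as the single k16 instruction up to the order in which the two K-halves are added.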