Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` kernels implemented in this repo (`blue`🔵) can achieve `95%~99%` of cuBLAS's (`orange`🟠) performance. Please check [hgemm benchmark](./hgemm) for more details.
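For context, the `95%~99%` figure is a throughput ratio at matched problem sizes. A small helper (illustrative only, not part of the repo's benchmark code) shows how such numbers are typically derived from kernel timings:

```python
# Hypothetical helpers (not from the repo): convert an M x N x K GEMM timing
# into achieved TFLOPS, and compare a custom kernel against a cuBLAS baseline.

def gemm_tflops(M: int, N: int, K: int, seconds: float) -> float:
    # A GEMM performs 2*M*N*K floating-point operations (one multiply and
    # one add per inner-product term).
    return 2.0 * M * N * K / seconds / 1e12

def relative_perf(ours_seconds: float, cublas_seconds: float) -> float:
    # Throughput ratio: how much of cuBLAS's performance our kernel achieves
    # on the same problem size (runtime is inversely proportional to TFLOPS).
    return cublas_seconds / ours_seconds

# Example: a 4096^3 half-precision GEMM finishing in 1 ms sustains ~137 TFLOPS.
print(gemm_tflops(4096, 4096, 4096, 0.001))   # ~137.44
print(relative_perf(0.00105, 0.001))          # ~0.952, i.e. ~95% of cuBLAS
```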
|CUDA Cores|Sliced K (Loop over K)|Tile Block|Tile Thread|
|:---:|:---:|:---:|:---:|
**hgemm/hgemm.cu** (+8 −1 lines)
```diff
@@ -1023,7 +1023,10 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch:
 void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 // from hgemm_mma_stage_tn.cu
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
-
+#ifdef ENBLE_CUTE_HGEMM
+// from hgemm_mma_stage_tn_cute.cu
+void hgemm_mma_stages_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+#endif
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // CUDA Cores FP16
@@ -1067,5 +1070,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
```