hgemm/hgemm.cu: 9 additions & 2 deletions
@@ -999,7 +999,8 @@ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor
 void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 // from hgemm_cublas.cu
-void hgemm_cublas_tensor_op_row_major(torch::Tensor a, torch::Tensor b, torch::Tensor c);
+void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
+void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 // from hgemm_wmma.cu
 void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
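The rename from hgemm_cublas_tensor_op_row_major to explicit _nn/_tn variants follows BLAS convention, where the two letters encode the transpose state of the A and B operands (N = not transposed, T = transposed). Below is a minimal sketch of what the NN wrapper might look like, assuming row-major torch half tensors and the usual column-major cuBLAS trick of computing C^T = B^T * A^T; the function name and details are illustrative, not the repo's actual code:

```cuda
// Hypothetical sketch of an NN cuBLAS HGEMM wrapper (not the repo's code).
// Assumes row-major A[M,K], B[K,N], C[M,N]. cuBLAS is column-major, so we
// swap the operand order and compute C^T[N,M] = B^T[N,K] * A^T[K,M].
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

void hgemm_cublas_tensor_op_nn_sketch(torch::Tensor a, torch::Tensor b,
                                      torch::Tensor c) {
  const int M = a.size(0), K = a.size(1), N = b.size(1);
  const __half alpha = __float2half(1.0f), beta = __float2half(0.0f);

  static cublasHandle_t handle = nullptr;
  if (handle == nullptr) {
    cublasCreate(&handle);
    cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);  // enable Tensor Cores
  }
  // Both operands non-transposed in the column-major view; leading
  // dimensions are the row-major strides (N for B/C, K for A).
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
               b.data_ptr(), CUDA_R_16F, N,
               a.data_ptr(), CUDA_R_16F, K, &beta,
               c.data_ptr(), CUDA_R_16F, N,
               CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  // A TN variant would flip one op to CUBLAS_OP_T and adjust the leading
  // dimensions; which operand the repo's _tn transposes is an assumption
  // not confirmed by this diff.
}
```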
@@ -1018,6 +1019,9 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
 void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+// from hgemm_mma_stage_tn.cu
+void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 // CUDA Cores FP16
@@ -1037,7 +1041,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
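The final hunk lands inside PYBIND11_MODULE, but its body is truncated on the page; presumably it registers the renamed cuBLAS wrappers and the new TN MMA kernel. A hypothetical sketch of what those registrations could look like, assuming plain m.def bindings (the exact macro or helper the repo uses is not shown in this diff):

```cuda
// Hypothetical sketch of the pybind11 registrations the final hunk likely
// adds; names match the declarations above, everything else is assumed.
#include <torch/extension.h>

// Forward declarations from the diff above (defined in other .cu files).
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
    torch::Tensor a, torch::Tensor b, torch::Tensor c,
    int stages, bool swizzle, int swizzle_stride);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // cuBLAS Tensor Core references (renamed from *_row_major in this diff).
  m.def("hgemm_cublas_tensor_op_nn", &hgemm_cublas_tensor_op_nn);
  m.def("hgemm_cublas_tensor_op_tn", &hgemm_cublas_tensor_op_tn);
  // New TN MMA kernel from hgemm_mma_stage_tn.cu.
  m.def("hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn",
        &hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn);
}
```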