Skip to content

Commit bcd12bd

Browse files
authored
[HGEMM] ldmatrix.x4.trans with reg double buffers (#100)
* Update hgemm_mma_stage.cu
* Update hgemm_mma_stage.cu
* Update hgemm.py
1 parent 8e869ef commit bcd12bd

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

hgemm/hgemm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def run_benchmark(perf_func: callable,
217217
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True)
218218
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
219219
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
220-
if args.enable_mma_all:
221220
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
222221
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
223222
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)

hgemm/hgemm_mma_stage.cu

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
730730
(0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
731731
lane_smem_b_n) * sizeof(half)
732732
);
733+
// TODO: could a single ldmatrix .x4.trans load all 4 matrices (both register double buffers) at once?
733734
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
734735
lane_smem_b_ptr);
736+
// int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
737+
// int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
738+
// uint32_t lane_smem_b_ptr = (
739+
// smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
740+
// (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
741+
// lane_smem_b_n) * sizeof(half)
742+
// );
743+
// // TRICK: use .x4.trans to load all 4 matrices for both register double buffers at once.
744+
// LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
745+
// RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
746+
// lane_smem_b_ptr);
735747
}
736748
}
737749

@@ -805,6 +817,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
805817
(smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
806818
lane_smem_b_n) * sizeof(half)
807819
);
820+
// TODO: could a single ldmatrix .x4.trans load all 4 matrices (both register double buffers) at once?
808821
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
809822
lane_smem_b_ptr);
810823
}
@@ -841,7 +854,6 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
841854
}
842855
}
843856

844-
845857
CP_ASYNC_WAIT_GROUP(K_STAGE-2);
846858
__syncthreads();
847859

@@ -874,8 +886,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
874886
lane_smem_b_k * (BN + B_PAD) +
875887
lane_smem_b_n) * sizeof(half)
876888
);
889+
// TODO: could a single ldmatrix .x4.trans load all 4 matrices (both register double buffers) at once?
877890
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
878891
lane_smem_b_ptr);
892+
// int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
893+
// int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
894+
// uint32_t lane_smem_b_ptr = (
895+
// smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
896+
// (smem_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
897+
// lane_smem_b_n) * sizeof(half)
898+
// );
899+
// // could .x4.trans load all 4 matrices for both register double buffers at once?
900+
// LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
901+
// RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
902+
// lane_smem_b_ptr);
879903
}
880904
}
881905

@@ -920,6 +944,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
920944
(stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
921945
lane_smem_b_n) * sizeof(half)
922946
);
947+
// TODO: could a single ldmatrix .x4.trans load all 4 matrices (both register double buffers) at once?
923948
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
924949
lane_smem_b_ptr);
925950
}
@@ -988,6 +1013,17 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
9881013
);
9891014
LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
9901015
lane_smem_b_ptr);
1016+
// int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
1017+
// int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
1018+
// uint32_t lane_smem_b_ptr = (
1019+
// smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
1020+
// (stage_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
1021+
// lane_smem_b_n) * sizeof(half)
1022+
// );
1023+
// // could .x4.trans load all 4 matrices for both register double buffers at once?
1024+
// LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
1025+
// RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
1026+
// lane_smem_b_ptr);
9911027
}
9921028
}
9931029
}

0 commit comments

Comments
 (0)