@@ -730,8 +730,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
           (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
            lane_smem_b_n) * sizeof(half)
         );
+        // TODO: may use .x4.trans to load 4 matrices for reg double buffers at once?
         LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
                       lane_smem_b_ptr);
+        // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+        // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+        // uint32_t lane_smem_b_ptr = (
+        //   smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
+        //   (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+        //    lane_smem_b_n) * sizeof(half)
+        // );
+        // // TRICK: I use .x4.trans to load 4 matrices for reg double buffers at once.
+        // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+        //               RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
+        //               lane_smem_b_ptr);
       }
     }

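For context on the `.x4.trans` TODO above: `LDMATRIX_X2_T` loads two transposed 8x8 b16 tiles per call, so filling both slots of the `RB` register double buffer (`reg_store_idx` and `reg_load_idx`) takes two calls per `j`, while the `.x4` variant could fill both in one instruction. The macros themselves are defined outside this diff; a minimal sketch of typical inline-PTX wrappers over `ldmatrix` (the exact macro bodies here are an assumption, not taken from this diff) looks like:

```cuda
// Minimal sketch, assuming the usual inline-PTX wrappers (not part of this diff).
// .x2.trans loads 2 transposed 8x8 b16 tiles into 2 registers per lane;
// .x4.trans loads 4 tiles, enough to fill both slots of the RB register
// double buffer (reg_store_idx and reg_load_idx) with a single instruction.
#define LDMATRIX_X2_T(R0, R1, addr)                                          \
  asm volatile(                                                              \
      "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"     \
      : "=r"(R0), "=r"(R1)                                                   \
      : "r"(addr))

#define LDMATRIX_X4_T(R0, R1, R2, R3, addr)                                  \
  asm volatile(                                                              \
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 "                      \
      "{%0, %1, %2, %3}, [%4];\n"                                            \
      : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3)                               \
      : "r"(addr))
```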
@@ -805,6 +817,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
           (smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
            lane_smem_b_n) * sizeof(half)
         );
+        // TODO: may use .x4.trans to load 4 matrices for reg double buffers at once?
         LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
                       lane_smem_b_ptr);
       }
@@ -841,7 +854,6 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
       }
     }

-
     CP_ASYNC_WAIT_GROUP(K_STAGE-2);
     __syncthreads();

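For reference, `CP_ASYNC_WAIT_GROUP(K_STAGE-2)` waits until at most `K_STAGE-2` committed `cp.async` groups are still pending, so the MMA loop can compute on the oldest ready stage while the newer tile prefetches stay in flight. A minimal sketch of the wrapper over the PTX instruction (assumed definition, not part of this diff):

```cuda
// Minimal sketch, assuming the usual inline-PTX wrapper (not part of this diff).
// cp.async.wait_group N blocks until at most N committed cp.async groups
// remain in flight; N must be a compile-time immediate, hence the "n" constraint.
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
```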
@@ -874,8 +886,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
            lane_smem_b_k * (BN + B_PAD) +
            lane_smem_b_n) * sizeof(half)
         );
+        // TODO: may use .x4.trans to load 4 matrices for reg double buffers at once?
         LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
                       lane_smem_b_ptr);
+        // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+        // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+        // uint32_t lane_smem_b_ptr = (
+        //   smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
+        //   (smem_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+        //    lane_smem_b_n) * sizeof(half)
+        // );
+        // // may use .x4.trans to load 4 matrices for reg double buffers at once?
+        // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+        //               RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
+        //               lane_smem_b_ptr);
       }
     }

@@ -920,6 +944,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
           (stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
            lane_smem_b_n) * sizeof(half)
         );
+        // TODO: may use .x4.trans to load 4 matrices for reg double buffers at once?
         LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
                       lane_smem_b_ptr);
       }
@@ -988,6 +1013,17 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
         );
         LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
                       lane_smem_b_ptr);
+        // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+        // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+        // uint32_t lane_smem_b_ptr = (
+        //   smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
+        //   (stage_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+        //    lane_smem_b_n) * sizeof(half)
+        // );
+        // // may use .x4.trans to load 4 matrices for reg double buffers at once?
+        // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+        //               RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
+        //               lane_smem_b_ptr);
       }
     }
   }