@@ -659,18 +659,35 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(
659659 LDST64BITS (r_load_a[0 ]) = LDST64BITS (a[load_a_gmem_addr]);
660660 LDST64BITS (r_load_b[0 ]) = LDST64BITS (b[load_b_gmem_addr]);
661661
662+ // s_a[8][128] write: 4-way bank conflicts
662663 s_a[load_a_smem_k ][load_a_smem_m] = r_load_a[0 ];
663664 s_a[load_a_smem_k + 1 ][load_a_smem_m] = r_load_a[1 ];
664665 s_a[load_a_smem_k + 2 ][load_a_smem_m] = r_load_a[2 ];
665666 s_a[load_a_smem_k + 3 ][load_a_smem_m] = r_load_a[3 ];
667+ // s_b[8][128] write: 2-way bank conflicts
666668 LDST64BITS (s_b[load_b_smem_k][load_b_smem_n]) = LDST64BITS (r_load_b[0 ]);
667669
668670 __syncthreads ();
669671
670672 #pragma unroll
671673 for (int tk = 0 ; tk < BK; tk++) {
674+ // bank conflicts analysis, tx/ty 0~15, 0~7 bank 4*8=32 bytes
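675+ // Within a given thread, its per-thread values (e.g. tid, tx, ty) can be treated as fixed.
676+ // So this loop should be read as iterating over tk while tid, tx, ty stay fixed at some value.
677+ // Bank-conflict analysis, however, must consider the concurrent behavior of the threads in a warp,
678+ // so what should be examined is which banks different threads access at the same point in time.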
679+ // s_a[8][128] load: 16-way bank conflicts
680+ // tid 0~15, tk 0~7 -> ty 0 -> [0~7][0+0~7] bank 0~3 layers_0~15
681+ // tid 16~31, tk 0~7 -> ty 1 -> [0~7][0+8~15] bank 4~7 layers_0~15
672682 LDST128BITS (r_comp_a[0 ]) = LDST128BITS (s_a[tk][ty * TM]);
683+ // s_b[8][128] load: 4-way bank conflicts
684+ // tid 0, tk 0~7 -> tx 0 -> [0~7][0+0~7] bank 0~3
685+ // tid 1, tk 0~7 -> tx 1 -> [0~7][0+8~15] bank 4~7
686+ // tid 7, tk 0~7 -> tx 7 -> [0~7][0+56~63] bank 28~31
687+ // tid 0/8/16/24, tk 0~7 -> tx 0/8/16/24 -> [0~7][0+...] bank 0~3
673688 LDST128BITS (r_comp_b[0 ]) = LDST128BITS (s_b[tk][tx * TN]);
689+ // TODO: hand-implement swizzling by XOR-ing the row and column indices
690+ // https://zhuanlan.zhihu.com/p/722286440
674691
675692 #pragma unroll
676693 for (int tm = 0 ; tm < TM; tm++) {
0 commit comments