Skip to content

Commit ba4998d

Browse files
authored
Update sgemm_wmma_tf32_stage.cu (#79)
1 parent cf4f9d7 commit ba4998d

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

sgemm/sgemm_wmma_tf32_stage.cu

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,27 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stage2(
283 283
constexpr int NUM_THREADS= (
284 284
WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
285 285

286+
// constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128
287+
// constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128
288+
// constexpr int BK = WMMA_K; // 8
289+
// constexpr int OFFSET=0;
290+
291+
// int dev_id = 0;
292+
// cudaGetDevice(&dev_id);
293+
// cudaDeviceProp dev_prop;
294+
// cudaGetDeviceProperties(&dev_prop, dev_id);
295+
// int smem_max_size = (K_STAGE * BM * (BK+OFFSET) * sizeof(float) +
296+
// K_STAGE * BK * (BN+OFFSET) * sizeof(float));
297+
// smem_max_size = (smem_max_size < dev_prop.sharedMemPerMultiprocessor ?
298+
// smem_max_size : dev_prop.sharedMemPerMultiprocessor);
299+
300+
// cudaFuncSetAttribute(
301+
// sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel<
302+
// WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
303+
// WARP_TILE_M, WARP_TILE_N, K_STAGE, 0>,
304+
// cudaFuncAttributeMaxDynamicSharedMemorySize,
305+
// smem_max_size);
306+
286 307
dim3 block(NUM_THREADS);
287 308
dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
288 309
div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));

0 commit comments

Comments
 (0)