@@ -283,6 +283,27 @@ void sgemm_wmma_m16n16k8_mma4x2_warp2x4_stage2(
283
283
constexpr int NUM_THREADS= (
284
284
WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
285
285
286
+ // constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128
287
+ // constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128
288
+ // constexpr int BK = WMMA_K; // 8
289
+ // constexpr int OFFSET=0;
290
+
291
+ // int dev_id = 0;
292
+ // cudaGetDevice(&dev_id);
293
+ // cudaDeviceProp dev_prop;
294
+ // cudaGetDeviceProperties(&dev_prop, dev_id);
295
+ // int smem_max_size = (K_STAGE * BM * (BK+OFFSET) * sizeof(float) +
296
+ // K_STAGE * BK * (BN+OFFSET) * sizeof(float));
297
+ // smem_max_size = (smem_max_size < dev_prop.sharedMemPerMultiprocessor ?
298
+ // smem_max_size : dev_prop.sharedMemPerMultiprocessor);
299
+
300
+ // cudaFuncSetAttribute(
301
+ // sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel<
302
+ // WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
303
+ // WARP_TILE_M, WARP_TILE_N, K_STAGE, 0>,
304
+ // cudaFuncAttributeMaxDynamicSharedMemorySize,
305
+ // smem_max_size);
306
+
286
307
dim3 block (NUM_THREADS);
287
308
dim3 grid (div_ceil (N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
288
309
div_ceil (M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));
0 commit comments