Commit a10bcb4

[SGEMM] SGEMM TF32 Thread Block Swizzle (#84)
* Update sgemm.py
* Update sgemm_wmma_tf32_stage.cu
* Update sgemm_wmma_tf32_stage.cu
* Update sgemm.py
* Update README.md
* Update hgemm_wmma_stage.cu
* Update README.md
* Update README.md
* Update README.md
* Update sgemm.cu
1 parent bc3d78e commit a10bcb4

6 files changed: 879 additions, 313 deletions

README.md

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,7 @@
   <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
 </div>
 
-🎉 **Modern CUDA Learn Notes with PyTorch** for **Beginners**: **fp32/tf32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, rope, embedding, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, gelu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, cp.async, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts reduce, MMA, etc).
+🎉 **Modern CUDA Learn Notes with PyTorch** for **Beginners**: **fp32/tf32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, rope, embedding, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, gelu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, cp.async, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts reduce, WMMA/MMA, block/warp swizzle, etc).
 
 <img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">
 
@@ -119,6 +119,7 @@
 | ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k16...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_wmma_m16n16k8...stage2/3*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
 | ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
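
The README hunks above add "block/warp swizzle" to the technique list and a `swizzle*` kernel row, but this part of the diff does not show the index remapping itself. The general idea of a thread block swizzle can be sketched on the host: consecutive launch indices are mapped onto a narrow strip of output tile columns, so blocks that run at about the same time tend to touch the same B tiles and may hit in L2 more often. The sketch below is a generic illustration in plain C++, not the code from `sgemm_wmma_tf32_stage.cu`. The names `Tile`, `swizzle_tile`, `is_bijection`, and `strip_cols` are hypothetical, and how the commit's `swizzle_stride` relates to a strip width is an assumption left open here.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: integer ceiling division.
inline int div_ceil(int a, int b) { return (a + b - 1) / b; }

// Coordinates of one output tile (in units of tiles, not elements).
struct Tile {
  int row;
  int col;
};

// Map a linear launch index to an output tile.
// Tiles are visited strip by strip: each strip is `strip_cols` tile-columns
// wide (the last strip may be narrower) and is walked row-major inside.
inline Tile swizzle_tile(int linear, int tiles_m, int tiles_n, int strip_cols) {
  assert(strip_cols > 0 && linear >= 0 && linear < tiles_m * tiles_n);
  const int tiles_per_strip = tiles_m * strip_cols;  // size of a full strip
  const int strip = linear / tiles_per_strip;        // which strip we are in
  const int first_col = strip * strip_cols;          // leftmost column of strip
  const int width = std::min(strip_cols, tiles_n - first_col);
  const int in_strip = linear - strip * tiles_per_strip;
  return Tile{in_strip / width, first_col + in_strip % width};
}

// Check that every tile is produced exactly once, i.e. the remapping only
// changes the order in which tiles are visited, never the set of tiles.
inline bool is_bijection(int tiles_m, int tiles_n, int strip_cols) {
  std::vector<bool> seen(static_cast<size_t>(tiles_m) * tiles_n, false);
  for (int i = 0; i < tiles_m * tiles_n; ++i) {
    const Tile t = swizzle_tile(i, tiles_m, tiles_n, strip_cols);
    if (t.row < 0 || t.row >= tiles_m || t.col < 0 || t.col >= tiles_n) return false;
    const size_t idx = static_cast<size_t>(t.row) * tiles_n + t.col;
    if (seen[idx]) return false;
    seen[idx] = true;
  }
  return true;
}
```

With 3 x 5 tiles and a strip width of 2, launch indices 0 to 5 cover columns 0 and 1 for all three rows before index 6 moves on to column 2. In a kernel, the same arithmetic would be applied to the block indices; the best strip width depends on the matrix shape and the L2 size, so it is usually tuned by measurement.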

hgemm/hgemm_wmma_stage.cu

Lines changed: 4 additions & 4 deletions
@@ -875,16 +875,16 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(
     switch (stages)
     {
     case 2: // ~21KB
-      LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(2, swizzle_stride);
+      LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride);
       break;
     case 3: // ~31KB
-      LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(3, swizzle_stride);
+      LAUNCH_161616_STAGE_SWIZZLE_KERNEL(3, swizzle_stride);
       break;
     case 4: // ~41K
-      LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(4, swizzle_stride);
+      LAUNCH_161616_STAGE_SWIZZLE_KERNEL(4, swizzle_stride);
       break;
     default:
-      LAUNCH_161616_STAGE_SWIZZLE_DSMEM_KERNEL(2, swizzle_stride);
+      LAUNCH_161616_STAGE_SWIZZLE_KERNEL(2, swizzle_stride);
       break;
     }
   } else {
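
The `// ~21KB`, `// ~31KB`, and `// ~41K` comments in the hunk above suggest that shared memory grows linearly with the stage count. A rough host-side estimate can be sketched as follows. Everything in `TileConfig` is an assumption, not something this diff shows: the 128 x 128 x 16 block tile is inferred from the `m16n16k16_mma4x2_warp2x4` name, and the padding of 8 half elements on each tile is a guess chosen only because it reproduces the figures in the comments. The real kernel may use different values. `TileConfig` and `smem_bytes` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Assumed block tile configuration (not confirmed by the diff).
struct TileConfig {
  int bm = 128;                // rows of the A tile: 16 * 4 * 2 (assumed)
  int bn = 128;                // cols of the B tile: 16 * 2 * 4 (assumed)
  int bk = 16;                 // K depth of one stage
  int a_pad = 8;               // padding per A row, in elements (a guess)
  int b_pad = 8;               // padding per B row, in elements (a guess)
  std::size_t elem_bytes = 2;  // sizeof(half)
};

// Shared memory needed for `stages` buffered copies of the A and B tiles.
inline std::size_t smem_bytes(const TileConfig& c, int stages) {
  assert(stages >= 1);
  const std::size_t a_elems = static_cast<std::size_t>(c.bm) * (c.bk + c.a_pad);
  const std::size_t b_elems = static_cast<std::size_t>(c.bk) * (c.bn + c.b_pad);
  return static_cast<std::size_t>(stages) * (a_elems + b_elems) * c.elem_bytes;
}
```

Under these assumptions, one stage takes 10496 bytes, so 2, 3, and 4 stages come to 20992, 31488, and 41984 bytes (about 20.5, 30.75, and 41 KiB). All three stay under the 48 KB per-block limit that applies without opting in to larger dynamic shared memory via `cudaFuncSetAttribute`.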
