
Commit 37f1554

[HGEMM] Update toy-hgemm library 0.1.0 (#152)
* Update hgemm.py
* Update hgemm_cublas.cu
* Update hgemm_mma_stage_tn_cute.cu
* Update hgemm_mma_stage.cu
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent edf80bb commit 37f1554

File tree

5 files changed: +27, -18 lines


README.md

Lines changed: 10 additions & 10 deletions
@@ -16,28 +16,28 @@
 
 <div id="contents"></div>
 
-📚 **Modern CUDA Learn Notes with PyTorch** for Beginners: It includes **Tensor/CUDA Cores, TF32/F16/BF16/F8**, [📖150+ CUDA Kernels🔥🔥](#cuda-kernel) with PyTorch bindings, [📖30+ LLM/VLM🔥](#my-blogs-part-1), [📖40+ CV/C++...🔥](#my-blogs-part-2), [📖50+ CUDA/CuTe...🔥](#other-blogs) Blogs and [📖toy-hgemm library🔥🔥](./kernels/hgemm) which can achieve the performance of **cuBLAS**, check [📖HGEMM Supported Matrix👇](#hgemm-sgemm) for more details. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
+📚 **Modern CUDA Learn Notes with PyTorch** for Beginners: It includes **Tensor/CUDA Cores, TF32/F16/BF16/F8**, [📖150+ CUDA Kernels🔥🔥](#cuda-kernel) with PyTorch bindings, [📖30+ LLM/VLM🔥](#my-blogs-part-1), [📖40+ CV/C++...🔥](#my-blogs-part-2), [📖50+ CUDA/CuTe...🔥](#other-blogs) Blogs and the [📖toy-hgemm library⚡️⚡️](./kernels/hgemm), which can achieve `98%~100%` of **cuBLAS** performance; check the [📖HGEMM Supported Matrix👇](#hgemm-sgemm) for technical details. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
 
 <div id="hgemm-sgemm"></div>
 
 <div align='center'>
-<img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="150px" width="267px">
-<img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="150px" width="267px">
-<img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="150px" width="267px">
+<img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="170px" width="270px">
+<img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="170px" width="270px">
+<img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="170px" width="270px">
 </div>
 
-Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `99%~100%+` of its (`orange`🟠) performance. Please check [toy-hgemm library🔥🔥](./kernels/hgemm) for more details.
+Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA/CuTe)` implemented in this repo (`blue`🔵) can achieve `98%~100%` of its (`orange`🟠) performance. Please check the [toy-hgemm library⚡️⚡️](./kernels/hgemm) for more details.
 
-|CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
+|CUDA Cores|Sliced K (Loop over K)|Tile Block (BMxBK)|Tile Thread (t 8x8)|
 |:---:|:---:|:---:|:---:|
 |✔️|✔️|✔️|✔️|
-|WMMA(m16n16k16)|MMA(m16n8k16)|Pack LDST(128 bits)|SMEM Padding|
+|WMMA (m16n16k16)|MMA (m16n8k16)|Pack LDST (128 bits)|SMEM Padding|
 |✔️|✔️|✔️|✔️|
-|Copy Async|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages|
+|Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
 |✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle(CuTe)|
+|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
 |✔️|✔️|✔️|✔️|
-|Collective Store(Warp Shfl)|Row Major(NN)|Col Major(TN)|SGEMM F32/TF32|
+|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)|SGEMM FP32/TF32|
 |✔️|✔️|✔️|✔️|
 
 ## ©️Citations🎉🎉
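For context on the techniques named in the README table above: the "WMMA (m16n16k16)" entry refers to the warp-level `nvcuda::wmma` API, which multiplies 16x16x16 half-precision tiles on Tensor Cores. The sketch below is a minimal, unoptimized illustration of that primitive plus the "Sliced K" loop — it is not a kernel from the toy-hgemm library (which adds shared-memory tiling, cp.async multi-stage pipelines, swizzling and collective stores), and the kernel name and launch shape are hypothetical. It assumes row-major half-precision A, B, C with M, N, K divisible by 16.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp (32 threads) computes one 16x16 tile of C = A * B (sm_70+).
// Launch sketch: hgemm_wmma_naive<<<dim3(N / 16, M / 16), 32>>>(A, B, C, M, N, K);
__global__ void hgemm_wmma_naive(const half *A, const half *B, half *C,
                                 int M, int N, int K) {
  const int tile_m = blockIdx.y;  // row index of this 16x16 C tile
  const int tile_n = blockIdx.x;  // col index of this 16x16 C tile

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  // Sliced K: march over the K dimension in steps of 16, accumulating into c_frag.
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);  // 16x16 tile of A
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n * 16, N);  // 16x16 tile of B
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);              // Tensor Core HMMA
  }
  wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag, N,
                          wmma::mem_row_major);
}
```

Roughly speaking, the remaining table entries are layered on top of this core loop in the repo's optimized kernels: packed 128-bit loads and cp.async copies into padded or swizzled shared memory, several K slices kept in flight as multi-stage buffers, multiple fragments per warp (Tile MMA / Tile Warp), and a collective store for the final write-back.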

kernels/hgemm/cublas/hgemm_cublas.cu

Lines changed: 2 additions & 0 deletions
@@ -184,6 +184,8 @@ int main(int argc, char *argv[]) {
 total_sec += this_sec;
 }
 
+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
 double avg_sec = total_sec / outer_repeat;
 double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
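The added comment pins down the unit convention behind the throughput line below it: an MxNxK GEMM performs 2*M*N*K floating-point operations (one multiply plus one add per multiply-accumulate), so dividing by 10^12 and by the average runtime in seconds yields TFLOPS. As a worked example (my numbers, not from the repo's benchmarks): for M = N = K = 4096, 2*M*N*K ≈ 1.37x10^11 FLOPs, so an avg_sec of 1 ms would report roughly 137 TFLOPS.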

kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu

Lines changed: 2 additions & 0 deletions
@@ -410,6 +410,8 @@ int main() {
 total_sec += this_sec;
 }
 
+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
 double avg_sec = total_sec / outer_repeat;
 double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;

kernels/hgemm/hgemm.py

Lines changed: 10 additions & 7 deletions
@@ -136,17 +136,20 @@ def run_benchmark(perf_func: callable,
 torch.cuda.synchronize()
 
 end = time.time()
-total_time = (end - start) * 1000 # ms
-mean_time = total_time / iters
+total_time_secs = (end - start) # secs
+mean_time_secs = total_time_secs / iters
 out_info = f"{tag}"
 out_flat = out.flatten()
 out_val_first = out_flat[:2].detach().cpu().numpy().tolist()
 out_val_last = out_flat[-2:].detach().cpu().numpy().tolist()
 out_val = [out_val_first[0], out_val_last[-1]]
 out_val = [round(v, 8) for v in out_val]
 out_val = [f"{v:<12}"[:10] for v in out_val]
-TFLOPS = (2 * M * N * K) * 1e-9 / (mean_time)
-mean_time = str(f"{mean_time:<12}")[:8]
+# 1 TFLOPS = 10^12 FLOPS
+# ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
+TFLOPS = (2 * M * N * K) * 1e-12 / (mean_time_secs)
+mean_time_ms = mean_time_secs * 1000
+mean_time_ms = str(f"{mean_time_ms:<12}")[:8] # ms
 swizzle_stride = 'NOOP' if swizzle_stride == 1 else swizzle_stride
 
 # caculate TFLOPS improved.
@@ -157,11 +160,11 @@ def run_benchmark(perf_func: callable,
 else:
 improve = 0
 MAX_TFLOPS = TFLOPS
-print(f"{out_info:>50}: {out_val}, time:{mean_time}ms, "
+print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
 f"swizzle<block>: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
 else:
 if not only_show_improved or "cublas" in tag:
-print(f"{out_info:>50}: {out_val}, time:{mean_time}ms, "
+print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
 f"swizzle<block>: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
 if show_matrix: print(out)
 if args.plot_flops:
@@ -186,7 +189,7 @@ def run_benchmark(perf_func: callable,
 gc.collect()
 torch.cuda.empty_cache()
 time.sleep(args.sleep_duration)
-return out, mean_time
+return out, mean_time_ms
 
 
 def get_topk_tflops():
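The hgemm.py change keeps the raw timing in seconds and converts to milliseconds only for printing, so the TFLOPS expression can use the 10^12 factor directly. Numerically nothing moves: since t_ms = 1000 * t_s, the old (2*M*N*K) * 1e-9 / t_ms equals the new (2*M*N*K) * 1e-12 / t_s; the refactor just makes the units explicit and renames the variables accordingly.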

kernels/hgemm/mma/hgemm_mma_stage.cu

Lines changed: 3 additions & 1 deletion
@@ -2032,7 +2032,9 @@ int main() {
 min_sec = min(min_sec, this_sec);
 total_sec += this_sec;
 }
-
+
+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
 double avg_sec = total_sec / outer_repeat;
 double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
