
Commit 497b8c3

feat: add torch.compile blogs (#350)

* fix comments
* fix comments
* fix comments
* fix comments
* add torch.compile blogs
* add torch.compile blogs
* add torch.compile blogs

1 parent 0ba5390 commit 497b8c3

27 files changed: +46 -115 lines

.pre-commit-config.yaml
Lines changed: 0 additions & 2 deletions

@@ -10,8 +10,6 @@ repos:
       - id: end-of-file-fixer
       - id: check-yaml
         args: [--allow-multiple-documents]
-      - id: check-toml
-      - id: check-ast
       - id: check-added-large-files
       - id: check-merge-conflict
       - id: check-shebang-scripts-are-executable

README.md
Lines changed: 13 additions & 41 deletions

@@ -1,15 +1,3 @@
-<!---
-<img src='https://github.com/user-attachments/assets/9306862b-2a30-4a87-bb33-0fde9e9d7cea' width=250 >
-<a href="#cuda-kernel">📚200+ CUDA Kernels</a> | <a href="#my-blogs-part-1"> 📚100+ LLM/CUDA Blogs</a> | <a href="#HGEMM-bench"> ⚡️HGEMM MMA</a> | <a href="#fa-mma-bench"> ⚡️FA-2 MMA </a> <p>
-<img src='https://github.com/user-attachments/assets/b2578723-b7a7-4d8f-bcd1-5008947b808a' >
-<div align="center">
-  <p align="center">
-    <a href="#contribute">愿以青衿涉险苦,为君先踏棘荆途。他年若览通衢阔,莫忘初逢问路吾。</a>
-  </p>
-</div>
--->
-
-
 <div align="center">
   <p align="center">
     <h2>📚 LeetCUDA: Modern CUDA Learn Notes with PyTorch for Beginners 🐑</h2>

@@ -24,14 +12,14 @@
   <img src=https://img.shields.io/github/stars/xlite-dev/LeetCUDA.svg?style=social >
   <img src=https://img.shields.io/badge/Release-v3.0.6-brightgreen.svg >
   <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
-</div>
+</div>
 </div>

 📚 **LeetCUDA**: It includes **Tensor/CUDA Cores, TF32/F16/BF16/F8**, [📖200+ CUDA Kernels🔥](#cuda-kernel) with PyTorch, [📖100+ LLM/CUDA🔥](#my-blogs-part-1) blogs, [📖HGEMM⚡️](./kernels/hgemm) which can achieve `98%~100%` TFLOPS of **cuBLAS**, and [📖flash-attn⚡️](./kernels/flash-attn) using Tensor Cores with pure MMA PTX. ♥️ Please consider to leave a ⭐️ Star to support me, my bro ~ ♥️

 <div align="center">
   <p align="center">
-    <a href="#contribute">🔥🔥 PR Welcome: Add Your Kernel to LeetCUDA! Let's make it Awesome together! 🎉🎉</a>
+    <a href="#contribute">🔥🔥 PR Welcome: Add Your Kernel to LeetCUDA! Let's make it Awesome together! 🎉🎉</a> <br>
     <a href=https://github.com/xlite-dev/LeetCUDA/graphs/contributors > <img src=https://opencollective.com/leetcuda/contributors.svg height=40px > </a>
   </p>
 </div>

@@ -52,7 +40,7 @@
 ## 📖 News 🔥🔥
 <div id="news"></div>

-- [2025-06-16]: [🤗CacheDiT](https://github.com/vipshop/cache-dit) is release! A **Training-free** and **Easy-to-use** Cache Acceleration Toolbox for Diffusion Transformers (**DBCache**, **DBPrune**, **FBCache**, etc)🔥. Feel free to take a try!
+- [2025-06-16]: [🤗CacheDiT](https://github.com/vipshop/cache-dit) is release! A **Training-free** and **Easy-to-use** Cache Acceleration Toolbox for Diffusion Transformers (**DBCache**, **DBPrune**, **FBCache**, etc)🔥. Feel free to take a try!

 <div align='center'>
   <img src='https://github.com/user-attachments/assets/a5ec4320-d2f9-4254-888a-170b2d9e3784' height=170px>

@@ -77,31 +65,6 @@

 ## 📖 Contents
 <div id="contents"></div>
-<!---
-- [📖 HGEMM-MMA 🎉🎉](#HGEMM-bench)
-  - [📚 CUDA/Tensor Cores](#HGEMM-bench)
-  - [📚 Tile Block(Br, Bc)](#HGEMM-bench)
-  - [📚 Tile MMAs/Warps](#HGEMM-bench)
-  - [📚 Pack LDST(128 bits)](#HGEMM-bench)
-  - [📚 Multi Stages(2~4)](#HGEMM-bench)
-  - [📚 Block/Warp Swizzle](#HGEMM-bench)
-  - [📚 SMEM Swizzle](#HGEMM-bench)
-  - [📚 Register Double Buffers](#HGEMM-bench)
-  - [📚 Collective Store(Shfl)](#HGEMM-bench)
-  - [📚 Layout NN/TN](#HGEMM-bench)
-- [📖 FlashAttention-MMA 🎉🎉](#fa-mma-bench)
-- [📖 200+ CUDA Kernels 🔥🔥](#cuda-kernel)
-- [📖 100+ 高性能计算文章 💡💡](#my-blogs-part-1)
-  - [📚 大模型推理优化原理](#my-blogs-part-1)
-  - [📚 大模型分布式训推原理](#my-blogs-part-1)
-  - [📚 CV/C++/模型部署优化](#my-blogs-part-1)
-  - [📚 CUDA优化入门与实践](#other-blogs)
-  - [📚 Tensor Cores入门教程](#other-blogs)
-  - [📚 CuTe系列详解与实践](#other-blogs)
-  - [📚 GPU指令集架构精解](#other-blogs)
-  - [📚 GPU通信架构精解](#other-blogs)
-- [📖 How to Contribute 👀👇](#contribute)
--->

 - [📖 HGEMM-MMA 🎉🎉](#HGEMM-bench)
 - [📖 FlashAttention-MMA 🎉🎉](#fa-mma-bench)

@@ -521,7 +484,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi

 |📖 类型-标题|📖 作者| 📖 推荐 |
 |:---|:---|:---|
-| [[Diffusion推理]📖DiT推理加速综述: Caching](https://zhuanlan.zhihu.com/p/711223667)|@DefTruth|⭐️⭐️⭐|
+| [[Diffusion推理]📖DiT推理加速综述: Caching](https://zhuanlan.zhihu.com/p/711223667)|@DefTruth|⭐️⭐️⭐|
 | [[Triton编程][基础]📖Triton极简入门: Triton Vector Add](https://zhuanlan.zhihu.com/p/1902778199261291694)|@DefTruth|⭐️⭐️⭐|
 | [[Triton编程][基础]📖Triton Fused Softmax Kernel详解: 从Python源码到PTX](https://zhuanlan.zhihu.com/p/1899562146477609112)|@DefTruth|⭐️⭐️⭐|
 | [[Triton编程][基础]📖vLLM Triton Merge Attention States Kernel详解](https://zhuanlan.zhihu.com/p/1904937907703243110)|@DefTruth|⭐️⭐️⭐|

@@ -665,6 +628,15 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | [[Tensor Cores]📖Nvidia Tensor Core-MMA PTX编程入门](https://zhuanlan.zhihu.com/p/621855199)|@木子知|⭐️⭐️⭐️|
 | [[Tensor Cores]📖CUDA Ampere Tensor Core HGEMM 矩阵乘法优化](https://zhuanlan.zhihu.com/p/555339335)|@nicholaswilde|⭐️⭐️⭐️|
 | [[GPU通信架构][精解]📖NVIDIA GPGPU(四)- 通信架构](https://zhuanlan.zhihu.com/p/680262016)|@Bruce|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖Torch.compile流程解析: 介绍](https://zhuanlan.zhihu.com/p/9418379234)|@StarCap|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖Torch.compile流程解析: TorchDynamo](https://zhuanlan.zhihu.com/p/9640728231)|@StarCap|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖Torch.compile流程解析: AOTAutograd](https://zhuanlan.zhihu.com/p/9997263922)|@StarCap|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖Torch.compile流程解析: TorchInductor](https://zhuanlan.zhihu.com/p/11224299472)|@StarCap|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖Torch.compile流程解析: 算子融合](https://zhuanlan.zhihu.com/p/21053905491)|@StarCap|⭐️⭐️⭐️|
+| [[torch.compile][实践]📖Torch.compile使用指南](https://zhuanlan.zhihu.com/p/620163218)|@jhang|⭐️⭐️⭐️|
+| [[torch.compile][实践]📖Torch.compile详细示例解析教程](https://zhuanlan.zhihu.com/p/855291863)|@Bbuf|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖一文搞懂TorchDynamo原理](https://zhuanlan.zhihu.com/p/630933479)|@吾乃阿尔法|⭐️⭐️⭐️|
+| [[torch.compile][原理]📖理解torch.compile基本原理和使用方式](https://zhuanlan.zhihu.com/p/12712224407)|@俯仰|⭐️⭐️⭐️|

 ## ©️License ([©️back👆🏻](#contents))
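The nine rows added above cover the torch.compile stack: TorchDynamo graph capture, AOTAutograd, TorchInductor code generation, and operator fusion, plus two usage guides. For orientation, here is a minimal sketch of the standard torch.compile API that those posts walk through (plain PyTorch 2.x usage, not code from this commit; the example function and shapes are illustrative):

```python
# Minimal torch.compile sketch (standard PyTorch 2.x API; not part of this commit).
# TorchDynamo captures the Python function as an FX graph, AOTAutograd builds the
# forward/backward graphs, and TorchInductor emits fused Triton/C++ kernels.
import torch

def gelu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x) * y

compiled = torch.compile(gelu_mul, mode="reduce-overhead")

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
y = torch.randn_like(x)

out = compiled(x, y)  # first call triggers compilation; later calls reuse the cached kernels
torch.testing.assert_close(out, gelu_mul(x, y), rtol=1e-3, atol=1e-3)
```

On recent PyTorch releases, running such a script with `TORCH_LOGS="output_code"` dumps the Triton kernels that TorchInductor generated, which pairs well with the TorchInductor and operator-fusion posts listed above.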

kernels/dot-product/dot_product.cu
Lines changed: 4 additions & 6 deletions

@@ -17,8 +17,8 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

-// -------------------------------------- FP32
-// -------------------------------------- Warp Reduce Sum
+// FP32
+// Warp Reduce Sum
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ float warp_reduce_sum_f32(float val) {
 #pragma unroll

@@ -87,8 +87,8 @@ __global__ void dot_prod_f32x4_f32_kernel(float *a, float *b, float *y, int N) {
   atomicAdd(y, prod);
 }

-// -------------------------------------- FP16
-// -------------------------------------- Warp Reduce Sum: Half
+// FP16
+// Warp Reduce Sum: Half
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ half warp_reduce_sum_f16_f16(half val) {
 #pragma unroll

@@ -199,8 +199,6 @@ __global__ void dot_prod_f16x8_pack_f32_kernel(half *a, half *b, float *y,
   atomicAdd(y, prod);
 }

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
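The `STRINGFY` / `TORCH_BINDING_COMMON_EXTENSION` macros at the end of this file (and of the hgemm kernels below) are the repository's PyTorch C++ extension registration helpers. As a rough sketch (an assumption, not code from this commit), such a `.cu` file can be JIT-built and inspected from Python like this; `torch.utils.cpp_extension.load` is the standard PyTorch API, and the extension name below is made up:

```python
# Hedged sketch: JIT-build dot_product.cu as a PyTorch extension and list the
# functions that its TORCH_BINDING_COMMON_EXTENSION(...) calls registered.
# torch.utils.cpp_extension.load is standard PyTorch; the extension name is
# arbitrary, and the source path matches the file shown in this diff.
import torch
from torch.utils.cpp_extension import load

lib = load(
    name="dot_product_lib",                        # arbitrary module name
    sources=["kernels/dot-product/dot_product.cu"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

# Each m.def(...) expanded from TORCH_BINDING_COMMON_EXTENSION shows up as an
# attribute on the loaded module; listing them avoids guessing exact signatures.
print([name for name in dir(lib) if not name.startswith("_")])
```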

kernels/hardswish/hardswish.cu
Lines changed: 4 additions & 8 deletions

@@ -19,20 +19,17 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

-// 定义 CHECK_TORCH_TENSOR_DTYPE 宏
 #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
   if (((T).options().dtype() != (th_type))) { \
     std::cout << "Tensor Info:" << (T).options() << std::endl; \
     throw std::runtime_error("Tensor dtype must be " #th_type); \
   }

-// 定义 TORCH_BINDING_COMMON_EXTENSION 宏
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

-// HARDSWISH 计算函数
-// FP32
+// FP32
 __device__ __forceinline__ float hardswish(float x) {
   if (x >= THRESHOLD_A) {
     return x;

@@ -43,7 +40,7 @@ __device__ __forceinline__ float hardswish(float x) {
   }
 }

-// FP16
+// FP16
 __device__ __forceinline__ half hardswish_half(half x) {
   if (x > __float2half(THRESHOLD_A)) {
     return x;

@@ -54,8 +51,7 @@ __device__ __forceinline__ half hardswish_half(half x) {
   }
 }

-// CUDA 核函数
-// FP32
+// FP32
 __global__ void hardswish_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)

@@ -75,7 +71,7 @@ __global__ void hardswish_f32x4_kernel(float *x, float *y, int N) {
   }
 }

-// FP16
+// FP16
 __global__ void hardswish_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
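The hardswish device functions above branch on `THRESHOLD_A`, consistent with the standard definition hardswish(x) = x * relu6(x + 3) / 6 and thresholds 3 and -3; the exact constants in `hardswish.cu` are not visible in this diff, so treat them as an assumption. A small host-side reference one might use to sanity-check such kernels:

```python
# Hedged reference for the hardswish kernels (assumes the standard thresholds
# 3 and -3; the actual constants in hardswish.cu are not shown in this diff).
import torch
import torch.nn.functional as F

def hardswish_ref(x: torch.Tensor) -> torch.Tensor:
    # Piecewise form: x for x >= 3, 0 for x <= -3, otherwise x * (x + 3) / 6.
    return x * F.relu6(x + 3.0) / 6.0

x = torch.linspace(-6.0, 6.0, steps=25)
torch.testing.assert_close(hardswish_ref(x), F.hardswish(x))
```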

kernels/hgemm/cublas/hgemm_cublas.cu
Lines changed: 1 addition & 2 deletions

@@ -173,8 +173,7 @@ int main(int argc, char *argv[]) {
 }
 // build torch python binding
 #else
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
+
 #include <torch/extension.h>
 #include <torch/types.h>
kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu
Lines changed: 1 addition & 2 deletions

@@ -469,8 +469,7 @@ int main() {

 #include <torch/extension.h>
 #include <torch/types.h>
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
+
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/hgemm/mma/basic/hgemm_mma.cu
Lines changed: 0 additions & 2 deletions

@@ -288,8 +288,6 @@ __global__ void __launch_bounds__(256)
   }
 }

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/hgemm/mma/basic/hgemm_mma_stage.cu
Lines changed: 0 additions & 2 deletions

@@ -2039,8 +2039,6 @@ int main(int argc, char *argv[]) {

 #else

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #include <torch/extension.h>
 #include <torch/types.h>
 #define STRINGFY(str) #str

kernels/hgemm/mma/basic/hgemm_mma_stage_tn.cu
Lines changed: 0 additions & 2 deletions

@@ -492,8 +492,6 @@ int main(int argc, char *argv[]) {

 #else

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #include <torch/extension.h>
 #include <torch/types.h>
 #define STRINGFY(str) #str

kernels/hgemm/mma/swizzle/hgemm_mma_stage_swizzle.cu
Lines changed: 0 additions & 2 deletions

@@ -738,8 +738,6 @@ int main(int argc, char *argv[]) {

 #else

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #include <torch/extension.h>
 #include <torch/types.h>
 #define STRINGFY(str) #str
