# General Matrix-Matrix Multiplication with Tile Library

<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>

:::{warning}
:class: myclass1 myclass2
:name: a-tip-reference

This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::

TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:

* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.

* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you write tile-level operations and let the compiler handle layout inference, pipelining, and so on.

* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly or explicit thread behavior.

```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center

Figure 1: High-level overview of the TileLang compilation flow.
```

In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.

## Why Another GPU DSL?

TileLang emerged from the need for a DSL that:

1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed.
2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPUs, etc.).
3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning.

While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints.

## Matrix Multiplication Example

```{figure} ../_static/img/MatmulExample.png
:alt: Matmul Example
:align: center

```

### Basic Structure

Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:

* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.).
* **`T.alloc_shared(...)`** to allocate GPU shared memory.
* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation.
* **`T.Pipelined(...)`** to express software pipelining across the K dimension.
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs).

```python
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Optional layout hints (commented out by default)
            # T.annotate_layout({
            #     A_shared: make_mma_swizzle_layout(A_shared),
            #     B_shared: make_mma_swizzle_layout(B_shared),
            # })

            # Optional: enable swizzle-based rasterization
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A from global to shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Parallel copy tile of B from global to shared memory
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                # Perform a tile-level GEMM, accumulating into C_local
                T.gemm(A_shared, B_shared, C_local)

            # Copy result from local (register fragment) to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Create the TileLang function
func = matmul(1024, 1024, 1024, 128, 128, 32)

# 2. JIT-compile the kernel for NVIDIA GPU
#    (out_idx=[2] marks C, the third buffer, as the output returned by the call)
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")

import torch

# 3. Prepare input tensors in PyTorch
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# 4. Invoke the JIT-compiled kernel
c = jit_kernel(a, b)
ref_c = a @ b

# 5. Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 6. Inspect generated CUDA code (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 7. Profile performance
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```

### Key Concepts

1. **Kernel Context**:

   ```python
   with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
       ...
   ```

   - This sets up the grid of thread blocks: `ceildiv(N, block_N)` blocks along N and `ceildiv(M, block_M)` blocks along M. For example, with `M = N = 1024` and `block_M = block_N = 128`, the kernel launches an 8 x 8 grid, and each `(bx, by)` pair computes one 128 x 128 tile of C.
   - `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads.

   ```{figure} ../_static/img/Parallel.png
   :alt: Parallel
   :align: center

   ```

2. **Shared & Fragment Memory**:

   ```python
   A_shared = T.alloc_shared((block_M, block_K), dtype)
   B_shared = T.alloc_shared((block_K, block_N), dtype)
   C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
   ```

   - `T.alloc_shared` allocates shared memory that is visible to the entire thread block.
   - `T.alloc_fragment` allocates register space for local accumulation. Although it is written with the full tile shape `(block_M, block_N)`, the compiler's layout inference assigns slices of this space to each thread.

3. **Software Pipelining**:

   ```python
   for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
       ...
   ```

   - `T.Pipelined` automatically arranges asynchronous copy and compute instructions so that memory transfers overlap with arithmetic.
   - The argument `num_stages=3` indicates the pipeline depth: with three stages, copies for upcoming iterations can proceed while the current tile is being computed.

   ```{figure} ../_static/img/software_pipeline_inference.png
   :alt: Software Pipeline Inference
   :align: center

   ```

4. **Parallel Copy**:

   ```python
   for k, j in T.Parallel(block_K, block_N):
       B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
   ```

   - `T.Parallel` marks the loop for thread-level parallelization.
   - The compiler maps these loop iterations onto the available threads in the block. The `T.copy` call used for the A tile performs the same kind of tile copy in a single call (see the sketch after this list).

5. **Tile-Level GEMM**:

   ```python
   T.gemm(A_shared, B_shared, C_local)
   ```

   - A single call that performs a tile-level matrix multiplication using the specified buffers.
   - Under the hood, it maps to CUTLASS/CuTe or WMMA-based instructions on NVIDIA targets; on AMD GPUs, TileLang uses a separate HIP-based or Composable Kernel approach.

6. **Copying Back Results**:

   ```python
   T.copy(C_local, C[by * block_M, bx * block_N])
   ```

   - After computation, the data in the local register fragment is written back to global memory.

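To make the relationship between `T.copy` and `T.Parallel` concrete, the copy of the A tile in the example above could also be written as an explicit parallel loop, mirroring the B-tile copy (a sketch only; the single `T.copy` call is the more concise equivalent):

```python
# Hypothetical explicit form of T.copy(A[by * block_M, ko * block_K], A_shared),
# placed inside the T.Pipelined loop in the same style as the B-tile copy above.
for i, k in T.Parallel(block_M, block_K):
    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
```
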
## Comparison with Other DSLs

TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization while letting the compiler handle many low-level details. However, TileLang also:

- Allows explicit memory layout annotations (e.g., `make_mma_swizzle_layout`; see the sketch below).
- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline assembly and keep the rest in Level 2.

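As an example of the first point, the layout hints that are commented out in the matmul code above can be enabled to request swizzled shared-memory layouts (a minimal sketch; the annotation is optional, and the compiler infers a layout when it is omitted):

```python
# Inside the kernel body, right after the shared-memory allocations:
# request swizzled layouts for A_shared and B_shared (typically used to
# reduce shared-memory bank conflicts).
T.annotate_layout({
    A_shared: make_mma_swizzle_layout(A_shared),
    B_shared: make_mma_swizzle_layout(B_shared),
})
```
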
## Performance on Different Platforms

```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png
:alt: Performance on Different Platforms
:align: center

```

When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks of FP16 matrix multiplication on several devices (RTX 4090, A100, H100, MI300X), TileLang has shown:

- ~1.1x speedup over cuBLAS on RTX 4090
- ~0.97x on A100 (on par with cuBLAS)
- ~1.0x on H100
- ~1.04x on MI300X
- Compared to Triton, speedups range from 1.08x to 1.25x, depending on the hardware.

These measurements will vary with tile sizes, pipeline stages, and the hardware's capabilities.

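To put such numbers in context on your own hardware, the latency reported by the profiler in the example above can be converted into throughput (a minimal sketch; it assumes `do_bench` reports latency in milliseconds, as the example's print statement suggests):

```python
# Convert the profiled latency of the 1024 x 1024 x 1024 matmul into TFLOPS.
M = N = K = 1024
latency_ms = profiler.do_bench()   # kernel latency in milliseconds (assumed)
flops = 2 * M * N * K              # one multiply and one add per inner-product term
tflops = flops / (latency_ms * 1e-3) / 1e12
print(f"{latency_ms:.3f} ms  ->  {tflops:.2f} TFLOPS")
```
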
## Conclusion

This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:

1. We allocated shared memory and register fragments.
2. We pipelined the loading and computation along the K dimension.
3. We used parallel copying to efficiently load tiles from global memory.
4. We invoked `T.gemm` to dispatch a tile-level matrix multiply.
5. We verified correctness against PyTorch and examined performance.

By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs.

For more advanced usage—including partial lowering, explicit control of thread primitives, or inline assembly—you can explore Level 3. Meanwhile, for purely functional expressions with automatic scheduling and tuning, consider Level 1.

## Further Resources

* [TileLang GitHub](https://github.com/tile-ai/tilelang)
* [BitBLAS](https://github.com/tile-ai/bitblas)
* [Triton](https://github.com/openai/triton)
* [CUTLASS](https://github.com/NVIDIA/cutlass)
* [PyCUDA](https://documen.tician.de/pycuda/)
