
Commit 4e2660b

Author: GitHub Actions
Commit message: Update docs
1 parent: c8bbeca

33 files changed: +573 −38 lines

_images/Parallel.png

Binary image files changed: 253 KB, 295 KB, 263 KB.

Lines changed: 266 additions & 2 deletions
@@ -1,2 +1,266 @@
-General Matrix-Matrix Multiplication
-====================================

=======================================================
General Matrix-Matrix Multiplication with Tile Library
=======================================================

.. raw:: html

   <div style="text-align: left;">
      <em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
   </div>

.. warning::

   This document is still **experimental** and may be incomplete.
   Suggestions and improvements are highly encouraged—please submit a PR!

TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:

* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.

* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is comparable to Triton’s programming model, where you write tile-level operations and let the compiler handle layout inference, pipelining, and similar concerns.

* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly and explicit per-thread behavior.

.. _fig-overview:

.. figure:: ../_static/img/overview.png
   :align: center
   :width: 50%
   :alt: Overview

   High-level overview of the TileLang compilation flow.

In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.

----------------------------
Why Another GPU DSL?
----------------------------

TileLang emerged from the need for a DSL that:

1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed.
2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPUs, etc.).
3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, and `T.gemm`), yet retains options for expert-level tuning.

While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel-fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable, reasonably concise code while still expressing important architectural hints.

------------------------------
Matrix Multiplication Example
------------------------------

In this section, we demonstrate how to write a 2D-tiled matrix multiplication kernel at Level 2 in TileLang.

.. figure:: ../_static/img/MatmulExample.png
   :align: center
   :alt: Matmul Example

Basic Structure
^^^^^^^^^^^^^^^

Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:

* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.).
* **`T.alloc_shared(...)`** to allocate GPU shared memory.
* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation.
* **`T.Pipelined(...)`** to express software pipelining across the K dimension.
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backend, such as MMA instructions on NVIDIA GPUs).

.. code-block:: python

    import tilelang
    import tilelang.language as T
    import torch
    from tilelang.intrinsics import make_mma_swizzle_layout


    def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

        @T.prim_func
        def main(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((K, N), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            # Initialize the kernel context: grid and thread-block configuration
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                # Optional layout hints (commented out by default)
                # T.annotate_layout({
                #     A_shared: make_mma_swizzle_layout(A_shared),
                #     B_shared: make_mma_swizzle_layout(B_shared),
                # })

                # Optional: enable swizzle-based rasterization
                # T.use_swizzle(panel_size=10, enable=True)

                # Clear the local accumulation buffer
                T.clear(C_local)

                for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    # Copy a tile of A from global to shared memory
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                    # Copy a tile of B from global to shared memory in parallel
                    for k, j in T.Parallel(block_K, block_N):
                        B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                    # Perform a tile-level GEMM and accumulate into C_local
                    T.gemm(A_shared, B_shared, C_local)

                # Copy the result from the register fragment back to global memory
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main


    # 1. Create the TileLang function
    func = matmul(1024, 1024, 1024, 128, 128, 32)

    # 2. JIT-compile the kernel for an NVIDIA GPU
    jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")

    # 3. Prepare input tensors in PyTorch
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

    # 4. Invoke the JIT-compiled kernel
    c = jit_kernel(a, b)
    ref_c = a @ b

    # 5. Validate correctness
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    # 6. Inspect the generated CUDA code (optional)
    cuda_source = jit_kernel.get_kernel_source()
    print("Generated CUDA kernel:\n", cuda_source)

    # 7. Profile performance
    profiler = jit_kernel.get_profiler()
    latency = profiler.do_bench()
    print(f"Latency: {latency} ms")

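A note on the calling convention, as the listing suggests in steps 2 and 4: because `C` appears in `out_idx`, the JIT wrapper treats the third buffer as an output, so the compiled kernel is called with only the two inputs and the result tensor is returned to you:

.. code-block:: python

    # `out_idx=[2]` marks the third argument (C) as an output managed by the wrapper.
    jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
    c = jit_kernel(a, b)  # C is allocated on the GPU and returned here
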
Key Concepts
^^^^^^^^^^^^

1. **Kernel Context**:

   .. code-block:: python

       with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
           ...

   - This sets up the block grid dimensions based on :math:`\lceil N / block\_N \rceil` and :math:`\lceil M / block\_M \rceil`.
   - `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads.

   .. figure:: ../_static/img/Parallel.png
      :align: center
      :alt: Parallel

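   For the concrete sizes used in this tutorial (M = N = K = 1024 with block_M = block_N = 128 and block_K = 32), the launch geometry can be checked with plain Python arithmetic. This is a sketch for intuition only, not TileLang code:

   .. code-block:: python

       # Plain-Python check of the grid and pipeline extents used above.
       M = N = K = 1024
       block_M = block_N = 128
       block_K = 32

       def ceildiv(a, b):
           return -(-a // b)

       grid_x = ceildiv(N, block_N)   # 8 thread blocks along N
       grid_y = ceildiv(M, block_M)   # 8 thread blocks along M
       k_iters = ceildiv(K, block_K)  # 32 pipelined iterations over K

       print(grid_x, grid_y, k_iters)  # -> 8 8 32
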
2. **Shared & Fragment Memory**:

   .. code-block:: python

       A_shared = T.alloc_shared((block_M, block_K), dtype)
       B_shared = T.alloc_shared((block_K, block_N), dtype)
       C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

   - `T.alloc_shared` allocates shared memory across the entire thread block.
   - `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread.

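   With the tile sizes chosen above and float16 inputs, the per-block shared-memory footprint is easy to estimate. This is a back-of-the-envelope sketch; the compiler may add padding or multi-buffering that changes the exact figure:

   .. code-block:: python

       # Approximate shared-memory usage per thread block for this kernel.
       block_M, block_N, block_K = 128, 128, 32
       bytes_per_fp16 = 2

       a_tile = block_M * block_K * bytes_per_fp16  # 8 KiB for A_shared
       b_tile = block_K * block_N * bytes_per_fp16  # 8 KiB for B_shared
       print((a_tile + b_tile) / 1024, "KiB per stage")  # -> 16.0 KiB

       # With num_stages=3, the pipeline keeps several stages in flight,
       # so the real allocation is roughly num_stages times this amount.
       print(3 * (a_tile + b_tile) / 1024, "KiB total (approximate)")  # -> 48.0 KiB
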
3. **Software Pipelining**:

   .. code-block:: python

       for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
           ...

   - `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic.
   - The argument `num_stages=3` sets the pipeline depth: while one tile is being processed, the next tiles are already being fetched.

   .. figure:: ../_static/img/software_pipeline_inference.png
      :align: center
      :alt: Software Pipeline Inference

4. **Parallel Copy**:

   .. code-block:: python

       for k, j in T.Parallel(block_K, block_N):
           B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

   - `T.Parallel` marks the loop for thread-level parallelization.
   - The compiler will map these loops onto the available threads in the block.

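   This element-wise loop spells out the same index mapping that the shorthand `T.copy` performs for the A tile; the two forms below should be interchangeable here (a comparison sketch, not an additional copy):

   .. code-block:: python

       # Element-wise form: indices written out explicitly.
       for k, j in T.Parallel(block_K, block_N):
           B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

       # Shorthand form, mirroring the A-tile copy in the main listing.
       T.copy(B[ko * block_K, bx * block_N], B_shared)
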
5. **Tile-Level GEMM**:

   .. code-block:: python

       T.gemm(A_shared, B_shared, C_local)

   - A single call that performs a tile-level matrix multiplication using the specified buffers.
   - Under the hood, on NVIDIA targets it can use CUTLASS/CuTe templates or WMMA instructions; on AMD GPUs, TileLang uses a separate HIP or Composable Kernel based approach.

6. **Copying Back Results**:

   .. code-block:: python

       T.copy(C_local, C[by * block_M, bx * block_N])

   - After computation, the data in the local register fragment is written back to global memory.

----------------------------
Comparison with Other DSLs
----------------------------

TileLang Level 2 is conceptually similar to Triton in that the user controls tiling and parallelization while letting the compiler handle many low-level details. However, TileLang also:

- Allows explicit memory layout annotations (e.g., `make_mma_swizzle_layout`), as shown in the sketch below.
- Supports a flexible pipelining pass (`T.Pipelined`) whose schedule can be inferred automatically or defined manually.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel at Level 3 (thread primitives) for fine-grained PTX/inline assembly and keep the rest at Level 2.

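For instance, the layout hints that are commented out in the kernel above can be enabled to request swizzled shared-memory layouts explicitly (a sketch reusing the calls from the earlier listing; whether it helps depends on the target GPU):

.. code-block:: python

    # Inside the T.Kernel body, after allocating the shared buffers:
    T.annotate_layout({
        A_shared: make_mma_swizzle_layout(A_shared),
        B_shared: make_mma_swizzle_layout(B_shared),
    })

    # Optional: swizzle-based block rasterization, intended to improve L2 locality.
    T.use_swizzle(panel_size=10, enable=True)
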
-----------------------------------
Performance on Different Platforms
-----------------------------------

.. figure:: ../_static/img/op_benchmark_consistent_gemm_fp16.png
   :align: center
   :alt: Performance on Different Platforms

When appropriately tuned (e.g., with an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on a range of GPUs. In internal benchmarks of FP16 matrix multiplication on the RTX 4090, A100, H100, and MI300X, TileLang has shown:

- ~1.1x speedup over cuBLAS on the RTX 4090
- ~0.97x on the A100 (on par with cuBLAS)
- ~1.0x on the H100
- ~1.04x on the MI300X
- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware.

These measurements vary with tile sizes, pipeline depth, and the hardware’s capabilities, so treat them as indicative; a small sweep over tile configurations (sketched below) is usually enough to find a good operating point.

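As a starting point for tuning, the JIT path shown earlier can be reused to benchmark a few candidate tile shapes (a minimal sketch using only the APIs demonstrated above; the candidate configurations are arbitrary examples):

.. code-block:: python

    # Benchmark a few tile-shape candidates with the profiler shown earlier.
    # (num_stages is fixed inside `matmul`; expose it as a parameter to sweep it too.)
    configs = [
        (128, 128, 32),
        (128, 256, 32),
        (64, 128, 64),
    ]

    results = {}
    for block_M, block_N, block_K in configs:
        func = matmul(1024, 1024, 1024, block_M, block_N, block_K)
        kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
        latency = kernel.get_profiler().do_bench()
        results[(block_M, block_N, block_K)] = latency
        print(f"block=({block_M}, {block_N}, {block_K}): {latency:.4f} ms")

    best = min(results, key=results.get)
    print("Best configuration:", best)
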
----------------------------
Conclusion
----------------------------

This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:

1. We allocated shared memory and register fragments.
2. We pipelined the loading and computation along the K dimension.
3. We used parallel copying to efficiently load tiles from global memory.
4. We invoked `T.gemm` to dispatch a tile-level matrix multiply.
5. We verified correctness against PyTorch and examined performance.

By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, and `T.gemm`) with the ability to annotate layouts or drop down to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs.

For more advanced usage—including partial lowering, explicit control of thread primitives, or inline assembly—you can explore Level 3. Conversely, for purely functional expressions with automatic scheduling and tuning, consider Level 1.

----------------------------
Further Resources
----------------------------

* `TileLang GitHub <https://github.com/tile-ai/tilelang>`_
* `BitBLAS <https://github.com/tile-ai/bitblas>`_
* `Triton <https://github.com/openai/triton>`_
* `CUTLASS <https://github.com/NVIDIA/cutlass>`_
* `PyCUDA <https://documen.tician.de/pycuda/>`_

_static/img/Parallel.png

Binary image files changed: 253 KB, 295 KB, 263 KB.

deeplearning_operators/convolution.html

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@
 <ul class="current">
 <li class="toctree-l1"><a class="reference internal" href="elementwise.html">ElementWise Operators</a></li>
 <li class="toctree-l1"><a class="reference internal" href="gemv.html">General Matrix-Vector Multiplication (GEMV)</a></li>
-<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication</a></li>
+<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication with Tile Library</a></li>
 <li class="toctree-l1"><a class="reference internal" href="matmul_dequant.html">General Matrix-Matrix Multiplication with Dequantization</a></li>
 <li class="toctree-l1"><a class="reference internal" href="flash_attention.html">Flash Attention</a></li>
 <li class="toctree-l1"><a class="reference internal" href="flash_linear_attention.html">Flash Linear Attention</a></li>

deeplearning_operators/elementwise.html

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@
 <ul class="current">
 <li class="toctree-l1 current current-page"><a class="current reference internal" href="#">ElementWise Operators</a></li>
 <li class="toctree-l1"><a class="reference internal" href="gemv.html">General Matrix-Vector Multiplication (GEMV)</a></li>
-<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication</a></li>
+<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication with Tile Library</a></li>
 <li class="toctree-l1"><a class="reference internal" href="matmul_dequant.html">General Matrix-Matrix Multiplication with Dequantization</a></li>
 <li class="toctree-l1"><a class="reference internal" href="flash_attention.html">Flash Attention</a></li>
 <li class="toctree-l1"><a class="reference internal" href="flash_linear_attention.html">Flash Linear Attention</a></li>

deeplearning_operators/flash_attention.html

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@
 <ul class="current">
 <li class="toctree-l1"><a class="reference internal" href="elementwise.html">ElementWise Operators</a></li>
 <li class="toctree-l1"><a class="reference internal" href="gemv.html">General Matrix-Vector Multiplication (GEMV)</a></li>
-<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication</a></li>
+<li class="toctree-l1"><a class="reference internal" href="matmul.html">General Matrix-Matrix Multiplication with Tile Library</a></li>
 <li class="toctree-l1"><a class="reference internal" href="matmul_dequant.html">General Matrix-Matrix Multiplication with Dequantization</a></li>
 <li class="toctree-l1 current current-page"><a class="current reference internal" href="#">Flash Attention</a></li>
 <li class="toctree-l1"><a class="reference internal" href="flash_linear_attention.html">Flash Linear Attention</a></li>
