Auto-Tuning Techniques for Performance Optimization
===================================================

<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/yyttt6">yyttt6</a>
</div>

## Overview

Auto-tuning a Tile Language program involves three main steps:

1. Implement the target program in Tile Language, reserving the parameters to be optimized
2. Provide candidate configurations, either manually or through [auto-generation with Carver](#using-carver-to-auto-generate-candidate-configurations)
3. Compile and benchmark the candidate configurations in parallel to identify the best-performing one

## Matrix Multiplication Example

The following example demonstrates auto-tuning matrix multiplication. The code has been simplified for readability; see `examples/gemm/example_gemm.py` for the complete implementation.

### Step 1: Implement with Reserved Parameters

Users can implement matrix multiplication in Tile Language while leaving the tuning parameters as function arguments:

```python
# Reserved parameters for optimization
def kernel(
    block_M=None,
    block_N=None,
    block_K=None,
    num_stages=None,
    thread_num=None,
    enable_rasteration=None,
):
    dtype = "float16"
    accum_dtype = "float"

    # Matrix multiplication implementation
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # ...existing code...

    return main
```
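
The elided body follows the standard tiled-GEMM pattern. Below is a minimal sketch modeled on the full example in `examples/gemm/example_gemm.py` (the name `kernel_sketch` is illustrative; treat the body as a sketch rather than the authoritative implementation):

```python
import tilelang.language as T

# Sketch only: assumed to mirror the pattern of examples/gemm/example_gemm.py.
def kernel_sketch(M, N, K, block_M, block_N, block_K,
                  num_stages, thread_num, enable_rasteration):
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # One thread block computes one (block_M, block_N) tile of C.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Optional thread-block swizzling for better L2 reuse.
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)

            # Software-pipelined reduction over the K dimension.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```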

### Step 2: Generate Candidate Configurations

Candidate configurations can be written out manually:

```python
configs = [
    {
        "block_M": 128,
        "block_N": 128,
        "block_K": 128,
        "num_stages": 3,
        "thread_num": 128,
        "enable_rasteration": True
    },
    {
        "block_M": 32,
        "block_N": 32,
        "block_K": 32,
        "num_stages": 0,
        "thread_num": 32,
        "enable_rasteration": False
    },
    # ...additional configurations...
]
```

Alternatively, configurations can be generated by a combinatorial traversal of the parameter ranges:

```python
import itertools

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
    itertools.product(
        block_M,
        block_N,
        block_K,
        num_stages,
        thread_num,
        enable_rasterization,
    ))

configs = [
    {
        "block_M": c[0],
        "block_N": c[1],
        "block_K": c[2],
        "num_stages": c[3],
        "thread_num": c[4],
        "enable_rasteration": c[5]
    } for c in _configs
]
```
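
Step 3 below consumes the candidates through a `get_configs(M, N, K, with_roller)` helper. A minimal sketch of such a wrapper, assuming a hypothetical `get_roller_configs` that builds candidates with Carver (see the last section):

```python
def get_configs(M, N, K, with_roller=False):
    # Hypothetical wrapper matching the call in Step 3: dispatch to
    # Carver-generated hints when with_roller is set, otherwise return
    # the hand-written or combinatorial candidates defined above.
    if with_roller:
        return get_roller_configs(M, N, K)  # hypothetical; see the Carver section
    return configs
```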

### Step 3: Compile and Benchmark

Configure the JIT compilation and benchmarking settings:

```python
autotuner = AutoTuner.from_kernel(
    kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
result = autotuner.run(warmup=3, rep=20)
out_c = result.kernel(a, b)
```

The returned result object contains the best-performing compiled kernel, which can be used directly.

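For example, the tuned kernel can be called on framework tensors. A short usage sketch, assuming PyTorch CUDA tensors shaped to match the buffers above and the `config`/`latency` fields the tuner reports in the full example:

```python
import torch

# Inputs matching the kernel signature: A is (M, K), B is stored as (N, K).
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(N, K, dtype=torch.float16, device="cuda")

# out_idx=[-1] means the output buffer C is allocated and returned.
out_c = result.kernel(a, b)

# Assumed result fields, mirroring the full example's reporting.
print(result.config)   # best configuration found
print(result.latency)  # measured latency of the best kernel
```
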
## Using Carver to Auto-Generate Candidate Configurations

Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.

For common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`):


```python
# Configure Matmul template
arch = CUDA("cuda")
carve_template = MatmulTemplate(
    M=M,
    N=N,
    K=K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float",
).with_arch(arch)

# Generate top-k optimization hints (topk=10 recommended)
roller_hints = carve_template.recommend_hints(topk=10)

# Configure candidate parameters
for hint in roller_hints:
    # ...existing code...
    config["block_M"] = block_m
    config["block_N"] = block_n
    config["block_K"] = hint.rstep[0]
    config["num_stages"] = hint.pipeline_stage
    config["thread_num"] = block_rows * block_cols * 32
    config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
```
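
The full example derives `block_m`, `block_n`, and the warp layout from each hint before filling in the dictionary. A sketch of that glue code, assuming the hint exposes `block` and `warp` tile shapes (attribute names inferred from the full example; verify against your Carver version):

```python
configs = []
for hint in roller_hints:
    config = {}
    # Assumed hint attributes: block = (block_m, block_n), warp = (warp_m, warp_n).
    block_m, block_n = hint.block
    warp_m, warp_n = hint.warp
    # One warp (32 threads) handles each (warp_m, warp_n) sub-tile of the block.
    block_rows, block_cols = block_m // warp_m, block_n // warp_n
    config["block_M"] = block_m
    config["block_N"] = block_n
    config["block_K"] = hint.rstep[0]
    config["num_stages"] = hint.pipeline_stage
    config["thread_num"] = block_rows * block_cols * 32
    config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
    configs.append(config)
```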
