
Commit b448309

[Autotune][Conv] optimize convolution examples to use autotune (#866)
1 parent 9cbbbbc commit b448309

File tree

1 file changed: +38 -155 lines changed


examples/convolution/example_convolution_autotune.py

Lines changed: 38 additions & 155 deletions
@@ -3,11 +3,6 @@
 import itertools
 import tilelang
 import tilelang.language as T
-from tilelang.autotuner import AutoTuner
-from tilelang.carver.template import ConvTemplate
-from tilelang.carver.arch import CUDA
-from tilelang.carver.arch import CDNA
-from tilelang.carver.roller.rasterization import NoRasterization
 
 
 def check_hopper():
@@ -30,149 +25,36 @@ def main(A, B):
     return main
 
 
-def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
-    if with_roller:
-        arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
-        carve_template = ConvTemplate(
-            N=N,
-            C=C,
-            H=H,
-            W=W,
-            F=F,
-            K=K,
-            S=S,
-            D=D,
-            P=P,
-            in_dtype="float16",
-            out_dtype="float16",
-            accum_dtype="float",
-        ).with_arch(arch)
-
-        func = carve_template.equivalent_function()
-        assert func is not None, "Function is None"
-        roller_hints = carve_template.recommend_hints(topk=topk)
-        if roller_hints is None:
-            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
-        configs = []
-        for hint in roller_hints:
-            config = {}
-            block_m, block_n = hint.block
-            warp_m, warp_n = hint.warp
-            # block_rows, block_cols represents warp partitioning
-            block_rows, block_cols = block_m // warp_m, block_n // warp_n
-            config["block_M"] = block_m
-            config["block_N"] = block_n
-            config["block_K"] = hint.rstep[0]
-            config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
-            config["thread_num"] = block_rows * block_cols * 32
-            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
-            configs.append(config)
-    else:
-        block_M = [64, 128, 256]
-        block_N = [64, 128, 256]
-        block_K = [32, 64]
-        num_stages = [0, 1, 2, 3]
-        thread_num = [128, 256]
-        enable_rasterization = [True, False]
-        _configs = list(
-            itertools.product(
-                block_M,
-                block_N,
-                block_K,
-                num_stages,
-                thread_num,
-                enable_rasterization,
-            ))
-
-        configs = [
-            {
-                "block_M": c[0],
-                "block_N": c[1],
-                "block_K": c[2],
-                "num_stages": c[3],
-                "thread_num": c[4],
-                "enable_rasteration": c[5],  # keep param name for backward-compat
-            } for c in _configs
-        ]
+def get_configs():
+    block_M = [64, 128, 256]
+    block_N = [64, 128, 256]
+    block_K = [32, 64]
+    num_stages = [0, 1, 2, 3]
+    thread_num = [128, 256]
+    enable_rasterization = [True, False]
+    _configs = list(
+        itertools.product(
+            block_M,
+            block_N,
+            block_K,
+            num_stages,
+            thread_num,
+            enable_rasterization,
+        ))
+
+    configs = [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "block_K": c[2],
+            "num_stages": c[3],
+            "thread_num": c[4],
+            "enable_rasteration": c[5],  # keep param name for backward-compat
+        } for c in _configs
+    ]
     return configs
 
 
-def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
-
-    @tilelang.jit(out_idx=[2])
-    def kernel(
-        block_M=None,
-        block_N=None,
-        block_K=None,
-        num_stages=None,
-        thread_num=None,
-        enable_rasteration=None,
-    ):
-        dtype = "float16"
-        accum_dtype = "float"
-        KH, KW = K, K
-        OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-        OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-        is_hopper = check_hopper()
-
-        @T.prim_func
-        def main(
-                data: T.Tensor((N, H, W, C), dtype),
-                kernel: T.Tensor((KH, KW, C, F), dtype),
-                out: T.Tensor((N, OH, OW, F), dtype),
-        ):
-            with T.Kernel(
-                    T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                    threads=thread_num) as (bx, by):
-                data_shared = T.alloc_shared((block_M, block_K), dtype)
-                kernel_shared = T.alloc_shared((block_K, block_N), dtype)
-                out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-                out_shared = T.alloc_shared((block_M, block_N), dtype)
-
-                kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
-                out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
-
-                T.annotate_layout({
-                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-                })
-
-                T.clear(out_local)
-                for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
-                    if is_hopper:
-                        T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
-                    else:
-                        for i, j in T.Parallel(block_M, block_K):
-                            k = k_iter * block_K + j
-                            m = by * block_M + i
-                            access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
-                            access_w = m % OW * S + k // C % KW * D - P
-                            in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                        (access_w < W))
-                            data_shared[i, j] = T.if_then_else(
-                                in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
-                    T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
-                    T.gemm(data_shared, kernel_shared, out_local)
-
-                T.copy(out_local, out_shared)
-                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
-
-        return main
-
-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
-                                           with_roller)).set_compile_args(
-            out_idx=[2],
-            target="auto",
-        ).set_profile_args(
-            supply_type=tilelang.TensorSupplyType.Integer,
-            ref_prog=ref_prog,
-            skip_check=False,
-        )
-    return autotuner.run(warmup=3, rep=20)
-
-
 def get_heuristic_config() -> dict:
     # Get CUDA device properties
     if not torch.cuda.is_available():
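Note on this hunk: the carver/roller path (ConvTemplate, CUDA/CDNA, NoRasterization) and the hand-rolled get_best_config driver built on AutoTuner.from_kernel are removed; the example now exposes a plain grid of candidate configurations that the @tilelang.autotune decorator (added in the next hunk) searches directly. The grid is the Cartesian product of the six lists, i.e. 3 x 3 x 2 x 4 x 2 x 2 = 288 candidate configs. For illustration only, the same grid can be built more compactly; get_configs_compact below is a hypothetical helper, not part of the example:

import itertools

def get_configs_compact():
    # Same search space as the new get_configs(): 3*3*2*4*2*2 = 288 dicts.
    space = {
        "block_M": [64, 128, 256],
        "block_N": [64, 128, 256],
        "block_K": [32, 64],
        "num_stages": [0, 1, 2, 3],
        "thread_num": [128, 256],
        "enable_rasteration": [True, False],  # spelling kept for backward-compat
    }
    keys = list(space)
    return [dict(zip(keys, vals)) for vals in itertools.product(*space.values())]

assert len(get_configs_compact()) == 288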
@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict:
     }
 
 
+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(out_idx=[2])
 def convolution(N,
                 C,
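Note on this hunk: stacking @tilelang.autotune(configs=get_configs()) on top of @tilelang.jit(out_idx=[2]) means the same convolution(...) call site can either run the autotuner or compile one fixed configuration, which is exactly how the main entry point (last hunk) uses it. A sketch of the two call forms; the shape values and the fixed tile sizes below are illustrative, only the keyword names are taken from the config dicts:

# 1) Autotune: no tile-size kwargs, so the decorator searches get_configs()
#    and returns the best compiled kernel.
kernel = convolution(128, 128, 64, 64, 256, 3, 1, 1, 1)

# 2) Fixed config: pass the tile parameters explicitly (as main does with
#    **get_heuristic_config()) and compile just that variant.
kernel = convolution(128, 128, 64, 64, 256, 3, 1, 1, 1,
                     block_M=128, block_N=128, block_K=32,
                     num_stages=2, thread_num=256, enable_rasteration=True)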
@@ -252,11 +135,10 @@ def main(
             kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
             out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
 
-            T.annotate_layout({
-                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-            })
+            if is_hopper:
+                T.annotate_layout({
+                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                })
 
             T.clear(out_local)
             for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
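Note on this hunk: the swizzled-layout annotation is narrowed to out_shared and applied only when check_hopper() is true; on other GPUs the default layouts are kept. The body of check_hopper() is not part of this diff, so the following is only one plausible way to write such a check, not necessarily what the example does:

import torch

def is_hopper_device() -> bool:
    # Hopper-class GPUs (e.g. H100) report CUDA compute capability 9.x.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9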
@@ -275,8 +157,11 @@ def main(
                 T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                 T.gemm(data_shared, kernel_shared, out_local)
 
-            T.copy(out_local, out_shared)
-            T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            if is_hopper:
+                T.copy(out_local, out_shared)
+                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            else:
+                T.copy(out_local, out_flat[by * block_M, bx * block_N])
 
     return main
 
@@ -296,9 +181,7 @@ def main(n: int = 128,
     ref_prog = ref_program(S, P, D)
 
     if use_autotune:
-        result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
-        print(result.config)
-        kernel = result.kernel
+        kernel = convolution(N, C, H, W, F, K, S, D, P)
     else:
         config = get_heuristic_config()
         kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
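Note on this hunk: with use_autotune set, main no longer calls get_best_config; it calls the decorated convolution(N, C, H, W, F, K, S, D, P) and gets the tuned kernel back, while ref_prog = ref_program(S, P, D) stays in place (its definition is earlier in the file and not shown here). For the kernel's NHWC data (N, H, W, C), weights (KH, KW, C, F), and output (N, OH, OW, F) layout, a reference of that shape could be written as the sketch below; this is an assumption for illustration, not necessarily the file's ref_program:

import torch
import torch.nn.functional as F

def ref_program_sketch(stride, padding, dilation):

    def ref(data, weights):
        # NHWC -> NCHW input and (KH, KW, C, F) -> (F, C, KH, KW) weights for conv2d,
        # then permute the result back to NHWC to match the kernel's output layout.
        x = data.permute(0, 3, 1, 2)
        w = weights.permute(3, 2, 0, 1)
        y = F.conv2d(x, w, stride=stride, padding=padding, dilation=dilation)
        return y.permute(0, 2, 3, 1).contiguous()

    return ref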
