import itertools
import tilelang
import tilelang.language as T
- from tilelang.autotuner import AutoTuner
- from tilelang.carver.template import ConvTemplate
- from tilelang.carver.arch import CUDA
- from tilelang.carver.arch import CDNA
- from tilelang.carver.roller.rasterization import NoRasterization


def check_hopper():
@@ -30,149 +25,36 @@ def main(A, B):
    return main


- def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
-     if with_roller:
-         arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
-         carve_template = ConvTemplate(
-             N=N,
-             C=C,
-             H=H,
-             W=W,
-             F=F,
-             K=K,
-             S=S,
-             D=D,
-             P=P,
-             in_dtype="float16",
-             out_dtype="float16",
-             accum_dtype="float",
-         ).with_arch(arch)
-
-         func = carve_template.equivalent_function()
-         assert func is not None, "Function is None"
-         roller_hints = carve_template.recommend_hints(topk=topk)
-         if roller_hints is None:
-             raise ValueError("No Roller Hints Found for TensorCore Scheduling")
-         configs = []
-         for hint in roller_hints:
-             config = {}
-             block_m, block_n = hint.block
-             warp_m, warp_n = hint.warp
-             # block_rows, block_cols represents warp partitioning
-             block_rows, block_cols = block_m // warp_m, block_n // warp_n
-             config["block_M"] = block_m
-             config["block_N"] = block_n
-             config["block_K"] = hint.rstep[0]
-             config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
-             config["thread_num"] = block_rows * block_cols * 32
-             config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
-             configs.append(config)
-     else:
-         block_M = [64, 128, 256]
-         block_N = [64, 128, 256]
-         block_K = [32, 64]
-         num_stages = [0, 1, 2, 3]
-         thread_num = [128, 256]
-         enable_rasterization = [True, False]
-         _configs = list(
-             itertools.product(
-                 block_M,
-                 block_N,
-                 block_K,
-                 num_stages,
-                 thread_num,
-                 enable_rasterization,
-             ))
-
-         configs = [
-             {
-                 "block_M": c[0],
-                 "block_N": c[1],
-                 "block_K": c[2],
-                 "num_stages": c[3],
-                 "thread_num": c[4],
-                 "enable_rasteration": c[5],  # keep param name for backward-compat
-             } for c in _configs
-         ]
+ def get_configs():
+     block_M = [64, 128, 256]
+     block_N = [64, 128, 256]
+     block_K = [32, 64]
+     num_stages = [0, 1, 2, 3]
+     thread_num = [128, 256]
+     enable_rasterization = [True, False]
+     _configs = list(
+         itertools.product(
+             block_M,
+             block_N,
+             block_K,
+             num_stages,
+             thread_num,
+             enable_rasterization,
+         ))
+
+     configs = [
+         {
+             "block_M": c[0],
+             "block_N": c[1],
+             "block_K": c[2],
+             "num_stages": c[3],
+             "thread_num": c[4],
+             "enable_rasteration": c[5],  # keep param name for backward-compat
+         } for c in _configs
+     ]
    return configs


- def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
-
-     @tilelang.jit(out_idx=[2])
-     def kernel(
-         block_M=None,
-         block_N=None,
-         block_K=None,
-         num_stages=None,
-         thread_num=None,
-         enable_rasteration=None,
-     ):
-         dtype = "float16"
-         accum_dtype = "float"
-         KH, KW = K, K
-         OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-         OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-         is_hopper = check_hopper()
-
-         @T.prim_func
-         def main(
-                 data: T.Tensor((N, H, W, C), dtype),
-                 kernel: T.Tensor((KH, KW, C, F), dtype),
-                 out: T.Tensor((N, OH, OW, F), dtype),
-         ):
-             with T.Kernel(
-                     T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                     threads=thread_num) as (bx, by):
-                 data_shared = T.alloc_shared((block_M, block_K), dtype)
-                 kernel_shared = T.alloc_shared((block_K, block_N), dtype)
-                 out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-                 out_shared = T.alloc_shared((block_M, block_N), dtype)
-
-                 kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
-                 out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
-
-                 T.annotate_layout({
-                     out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                     data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                     kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-                 })
-
-                 T.clear(out_local)
-                 for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
-                     if is_hopper:
-                         T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
-                     else:
-                         for i, j in T.Parallel(block_M, block_K):
-                             k = k_iter * block_K + j
-                             m = by * block_M + i
-                             access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
-                             access_w = m % OW * S + k // C % KW * D - P
-                             in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                         (access_w < W))
-                             data_shared[i, j] = T.if_then_else(
-                                 in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
-                     T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
-                     T.gemm(data_shared, kernel_shared, out_local)
-
-                 T.copy(out_local, out_shared)
-                 T.copy(out_shared, out_flat[by * block_M, bx * block_N])
-
-         return main
-
-     autotuner = AutoTuner.from_kernel(
-         kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
-                                            with_roller)).set_compile_args(
-             out_idx=[2],
-             target="auto",
-         ).set_profile_args(
-             supply_type=tilelang.TensorSupplyType.Integer,
-             ref_prog=ref_prog,
-             skip_check=False,
-         )
-     return autotuner.run(warmup=3, rep=20)
-
-
def get_heuristic_config() -> dict:
    # Get CUDA device properties
    if not torch.cuda.is_available():
@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict:
    }


+ @tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(N,
                C,
@@ -252,11 +135,10 @@ def main(
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

-             T.annotate_layout({
-                 out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                 data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                 kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-             })
+             if is_hopper:
+                 T.annotate_layout({
+                     out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                 })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
@@ -275,8 +157,11 @@ def main(
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

-             T.copy(out_local, out_shared)
-             T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+             if is_hopper:
+                 T.copy(out_local, out_shared)
+                 T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+             else:
+                 T.copy(out_local, out_flat[by * block_M, bx * block_N])

    return main

@@ -296,9 +181,7 @@ def main(n: int = 128,
    ref_prog = ref_program(S, P, D)

    if use_autotune:
-         result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
-         print(result.config)
-         kernel = result.kernel
+         kernel = convolution(N, C, H, W, F, K, S, D, P)
    else:
        config = get_heuristic_config()
        kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
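
# --- Editor's usage sketch (not part of this commit) ----------------------
# With @tilelang.autotune(configs=get_configs()) stacked on @tilelang.jit(out_idx=[2]),
# calling convolution(...) with only the problem sizes searches the config space and
# returns a compiled kernel, exactly as the autotune branch above does. The shapes,
# dtypes, and device below are illustrative assumptions, not values from the commit.
import torch

N, C, H, W, F, K, S, D, P = 128, 64, 56, 56, 128, 3, 1, 1, 1
tuned_kernel = convolution(N, C, H, W, F, K, S, D, P)  # runs the tuner over get_configs()

data = torch.randn(N, H, W, C, device="cuda", dtype=torch.float16)    # NHWC activation
weight = torch.randn(K, K, C, F, device="cuda", dtype=torch.float16)  # KH x KW x C x F filter
out = tuned_kernel(data, weight)  # out_idx=[2]: the output tensor is allocated and returned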