@@ -123,35 +123,24 @@ Below is an example that demonstrates more advanced features: layout annotation,
 ```python
 import tilelang
 import tilelang.language as T
-# `make_mma_swizzle_layout` is a python defined layout function
-# specifically designed for for MMA operations
-# which ensures the consistency with the nvidia CUTLASS Library.
-# to avoid bank conflicts and maximize the performance.
-from tilelang.intrinsics import (
-    make_mma_swizzle_layout as make_swizzle_layout,)
-
-# add decorator @tilelang.jit if you want to return a torch function
-# @tilelang.jit
+
+# @tilelang.jit(target="cuda")
+# target currently can be "cuda" or "hip" or "cpu".
+# if not specified, it will be inferred from the input tensors during compile time
+@tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
-    def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+    def matmul_relu_kernel(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            # Apply layout optimizations or define your own layout (Optional)
-            # If not specified, we will deduce the layout automatically
-            # T.annotate_layout({
-            #     A_shared: make_swizzle_layout(A_shared),
-            #     B_shared: make_swizzle_layout(B_shared),
-            # })
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

             # Enable rasterization for better L2 cache locality (Optional)
             # T.use_swizzle(panel_size=10, enable=True)
@@ -164,53 +153,58 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
                 # This is a sugar syntax for parallelized copy
                 T.copy(A[by * block_M, ko * block_K], A_shared)

-                # Demonstrate parallelized copy from global to shared for B
-                for k, j in T.Parallel(block_K, block_N):
-                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
+                # Copy tile of B
+                T.copy(B[ko * block_K, bx * block_N], B_shared)

                 # Perform a tile-level GEMM on the shared buffers
                 # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                 T.gemm(A_shared, B_shared, C_local)
+
+            # Apply ReLU on the accumulator
+            for i, j in T.Parallel(block_M, block_N):
+                C_local[i, j] = T.max(C_local[i, j], 0)

             # Copy result back to global memory
             T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return matmul_relu_kernel


-# 1. Define the kernel (matmul) with the desired dimensions
-func = matmul(1024, 1024, 1024, 128, 128, 32)
+M = 1024  # M = T.symbolic("m") if you want to use dynamic shape
+N = 1024
+K = 1024
+block_M = 128
+block_N = 128
+block_K = 32

-# 2. Compile the kernel into a torch function
-# out_idx specifies the index of the output buffer in the argument list
-# if out_idx is specified, the tensor will be created during runtime
-# target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
+# 1. Define the kernel (matmul) and compile/lower it into an executable module
+matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

 # 3. Test the kernel in Python with PyTorch data
 import torch

 # Create random input tensors on the GPU
-a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
-b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
-
+a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+c = torch.empty(M, N, device="cuda", dtype=torch.float16)

-# Run the kernel through the JIT-compiled function
-c = jit_kernel(a, b)
+# Run the kernel and write the result into c
+matmul_relu_kernel(a, b, c)

+print(c)
 # Reference multiplication using PyTorch
-ref_c = a @ b
+ref_c = torch.relu(a @ b)

 # Validate correctness
 torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
 print("Kernel output matches PyTorch reference.")

 # 4. Retrieve and inspect the generated CUDA source (optional)
-cuda_source = jit_kernel.get_kernel_source()
-print("Generated CUDA kernel:\n", cuda_source)
+# cuda_source = jit_kernel.get_kernel_source()
+# print("Generated CUDA kernel:\n", cuda_source)

-# 5. Pofile latency with the profiler
-profiler = jit_kernel.get_profiler()
+# 5. Profile latency with the kernel
+profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

 latency = profiler.do_bench()

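For readers curious about the dynamic-shape option mentioned in the added `M = T.symbolic("m")` comment, here is a minimal, untested sketch. It assumes the `matmul` factory, `tilelang`, `T`, and `torch` from the example above are already in scope, and that a kernel compiled with a symbolic M dimension resolves that dimension from the runtime tensor shapes; treat it as an illustration, not the library's documented recipe.

```python
# Sketch only: reuses `matmul`, `tilelang`, `T`, and `torch` defined in the example above.
# Assumption: a kernel compiled with a symbolic M accepts any row count at call time.
M_dyn = T.symbolic("m")                    # leave M dynamic; N and K stay static
dyn_kernel = matmul(M_dyn, 1024, 1024, 128, 128, 32)

# Call it with M = 512; the symbolic dimension is inferred from `a`'s shape.
a = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = torch.empty(512, 1024, device="cuda", dtype=torch.float16)
dyn_kernel(a, b, c)                        # c = relu(a @ b)
```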