@@ -214,6 +214,145 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     assert_close(c, iree_ref, check_device=False)
 
 
+@require_e2e
+@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
+@pytest.mark.parametrize(
+    "enable_scheduling",
+    [SchedulingType.NONE, SchedulingType.PREFETCH, SchedulingType.MODULO],
+)
+@param_bool("dynamic_dims", "dyn")
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        MMAType.F32_16x16x16_F16,
+        MMAType.F32_32x32x8_F16,
+    ],
+)
+def testNonTransposeGemm(
+    shape: tuple[int],
+    enable_scheduling: SchedulingType,
+    dynamic_dims: bool,
+    mfma_variant: MMAType,
+    request,
+):
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    # Input sizes
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K = tkl.sym.K
+    # Workgroup tile sizes
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K = tkl.sym.BLOCK_K
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant
+        )
+    ]
+
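+    # When shapes are dynamic, give the compiler a lower bound on K (more than four K tiles).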
+    if dynamic_dims:
+        constraints += [tkw.Assumption(K > BLOCK_K * 4)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    # Transpose during read for expected shape: (M, K) @ (N, K) -> (M, N)
+    b_mapping = tkw.IndexMapping(
+        num_iterators=2, inputs={N: i, K: j}, outputs={N: i, K: j}
+    )
+
+    @tkw.wave(constraints)
+    def gemm(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[K, N, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+        @tkw.iterate(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+            # a_reg: tkw.Register[M, K, tkl.f16]
+            a_reg = tkw.read(a)
+            # b_reg: tkw.Register[N, K, tkl.f16]; data is transposed [K, N] -> [N, K] by b_mapping
+            b_reg = tkw.read(b, mapping=b_mapping)
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, c)
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K: 32,
+        M: shape[0],
+        N: shape[1],
+        K: shape[2],
+    }
+    hyperparams.update(get_default_scheduling_params())
+
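+    # With dynamic_dims, M/N/K are supplied at runtime via dynamic_symbols_map
+    # instead of being substituted as compile-time constants.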
+    dynamic_symbols = []
+    dynamic_symbols_map = {}
+    if dynamic_dims:
+        dynamic_symbols_map[M] = hyperparams[M]
+        dynamic_symbols_map[N] = hyperparams[N]
+        dynamic_symbols_map[K] = hyperparams[K]
+        dynamic_symbols.append(M)
+        dynamic_symbols.append(N)
+        dynamic_symbols.append(K)
+        del hyperparams[M]
+        del hyperparams[N]
+        del hyperparams[K]
+
+    perf_filename = request.node.name + ".json"
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=run_bench,
+        schedule=enable_scheduling,
+        use_scheduling_barriers=enable_scheduling_barriers,
+        dynamic_symbols=dynamic_symbols,
+        dynamic_symbols_map=dynamic_symbols_map,
+        benchmark_batch_size=10,
+        benchmark_repetitions=3,
+        benchmark_results_file=(
+            os.path.join(dump_perf, "tk_" + perf_filename) if dump_perf else None
+        ),
+    )
+    options = set_default_run_config(options)
+    gemm = wave_compile(options, gemm)
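+    # Inputs in non-transposed layout: a is (M, K), b is (K, N); c is the (M, N) fp32 output.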
+    a = device_randn(shape[0], shape[2], dtype=torch.float16)
+    b = device_randn(shape[2], shape[1], dtype=torch.float16)
+    c = device_zeros(shape[0], shape[1], dtype=torch.float32)
+    asm = gemm(a, b, c)
+
+    if dump_generated_mlir:
+        filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(asm)
+
+    if run_bench:
+        if dump_perf is not None:
+            options.benchmark_results_file = os.path.join(
+                dump_perf, "iree_" + perf_filename
+            )
+    # TODO: switch to comparison against generated iree_ref
+    torch_ref = torch.matmul(a, b)
+    assert_close(
+        c.to(torch.float16), torch_ref, atol=1e-2, rtol=1e-2, check_device=False
+    )
+
+
 @require_e2e
 @pytest.mark.parametrize("shape", [(4096, 4096, 4096)])
 @pytest.mark.parametrize(