
Commit 7dede34

xintin authored and efric committed
[Wave] Implement C = A @ B GEMM example (iree-org#881)
Signed-off-by: xintin <[email protected]>
1 parent 0677c69 commit 7dede34

File tree

1 file changed: +139 -0 lines


tests/kernel/wave/wave_gemm_test.py

Lines changed: 139 additions & 0 deletions
@@ -214,6 +214,145 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     assert_close(c, iree_ref, check_device=False)
 
 
+@require_e2e
+@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
+@pytest.mark.parametrize(
+    "enable_scheduling",
+    [SchedulingType.NONE, SchedulingType.PREFETCH, SchedulingType.MODULO],
+)
+@param_bool("dynamic_dims", "dyn")
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        MMAType.F32_16x16x16_F16,
+        MMAType.F32_32x32x8_F16,
+    ],
+)
+def testNonTransposeGemm(
+    shape: tuple[int],
+    enable_scheduling: SchedulingType,
+    dynamic_dims: bool,
+    mfma_variant: MMAType,
+    request,
+):
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    # Input sizes
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K = tkl.sym.K
+    # Workgroup tile sizes
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K = tkl.sym.BLOCK_K
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant
+        )
+    ]
+
+    if dynamic_dims:
+        constraints += [tkw.Assumption(K > BLOCK_K * 4)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    # Transpose during read for expected shape: (M, K) @ (N, K) -> (M, N)
+    b_mapping = tkw.IndexMapping(
+        num_iterators=2, inputs={N: i, K: j}, outputs={N: i, K: j}
+    )
+
+    @tkw.wave(constraints)
+    def gemm(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[K, N, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+        @tkw.iterate(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+            # a_reg: tkw.Register[M, K, tkl.f16]
+            a_reg = tkw.read(a)
+            # b_reg: tkw.Register[N, K, tkl.f16]; data is transposed [K, N] -> [N, K] via b_mapping
+            b_reg = tkw.read(b, mapping=b_mapping)
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, c)
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K: 32,
+        M: shape[0],
+        N: shape[1],
+        K: shape[2],
+    }
+    hyperparams.update(get_default_scheduling_params())
+
+    dynamic_symbols = []
+    dynamic_symbols_map = {}
+    if dynamic_dims:
+        dynamic_symbols_map[M] = hyperparams[M]
+        dynamic_symbols_map[N] = hyperparams[N]
+        dynamic_symbols_map[K] = hyperparams[K]
+        dynamic_symbols.append(M)
+        dynamic_symbols.append(N)
+        dynamic_symbols.append(K)
+        del hyperparams[M]
+        del hyperparams[N]
+        del hyperparams[K]
+
+    perf_filename = request.node.name + ".json"
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=run_bench,
+        schedule=enable_scheduling,
+        use_scheduling_barriers=enable_scheduling_barriers,
+        dynamic_symbols=dynamic_symbols,
+        dynamic_symbols_map=dynamic_symbols_map,
+        benchmark_batch_size=10,
+        benchmark_repetitions=3,
+        benchmark_results_file=(
+            os.path.join(dump_perf, "tk_" + perf_filename) if dump_perf else None
+        ),
+    )
+    options = set_default_run_config(options)
+    gemm = wave_compile(options, gemm)
+    a = device_randn(shape[0], shape[2], dtype=torch.float16)
+    b = device_randn(shape[2], shape[1], dtype=torch.float16)
+    c = device_zeros(shape[0], shape[1], dtype=torch.float32)
+    asm = gemm(a, b, c)
+
+    if dump_generated_mlir:
+        filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(asm)
+
+    if run_bench:
+        if dump_perf is not None:
+            options.benchmark_results_file = os.path.join(
+                dump_perf, "iree_" + perf_filename
+            )
+    # TODO: switch to comparison against generated iree_ref
+    torch_ref = torch.matmul(a, b)
+    assert_close(
+        c.to(torch.float16), torch_ref, atol=1e-2, rtol=1e-2, check_device=False
+    )
+
+
 @require_e2e
 @pytest.mark.parametrize("shape", [(4096, 4096, 4096)])
 @pytest.mark.parametrize(
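
A note on the index mapping for readers new to wave: b is laid out [K, N] in memory, and b_mapping reads it into registers shaped [N, K], so tkw.mma contracts the shared K dimension and the kernel computes C = A @ B with f32 accumulation. A rough PyTorch analogue of that data movement (a sketch with made-up sizes, not the wave API):

import torch

# Illustrative sizes only; the real test takes shapes from get_test_shapes.
M, N, K = 128, 256, 64
a = torch.randn(M, K, dtype=torch.float16)   # mirrors tkl.Memory[M, K, ...]
b = torch.randn(K, N, dtype=torch.float16)   # mirrors tkl.Memory[K, N, ...]

# b_mapping's effect: read b ([K, N] in memory) as registers shaped [N, K].
b_reg = b.transpose(0, 1)                    # [N, K]

# tkw.mma(a_reg, b_reg, acc) contracts K: (M, K) x (N, K) -> (M, N).
c = a.float() @ b_reg.transpose(0, 1).float()
torch.testing.assert_close(c, a.float() @ b.float())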

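The constraints also pin down the tiling scheme: each workgroup owns a BLOCK_M x BLOCK_N = 64x64 tile of C, split 2x2 across the four waves per block, so each 64-lane wave accumulates a 32x32 tile (from WaveConstraint(M, BLOCK_M / 2) and WaveConstraint(N, BLOCK_N / 2)), stepping over K by BLOCK_K = 32. A back-of-the-envelope sketch of that arithmetic (the helper is hypothetical, not part of the wave API):

import math

def tile_summary(M, N, K, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, waves=(2, 2)):
    # WorkgroupConstraint: one workgroup per BLOCK_M x BLOCK_N tile of C.
    grid = (math.ceil(M / BLOCK_M), math.ceil(N / BLOCK_N))
    # WaveConstraint(M, BLOCK_M / 2) and WaveConstraint(N, BLOCK_N / 2).
    wave_tile = (BLOCK_M // waves[0], BLOCK_N // waves[1])
    # TilingConstraint(K, BLOCK_K): iterations of the K loop.
    k_steps = math.ceil(K / BLOCK_K)
    return grid, wave_tile, k_steps

print(tile_summary(1024, 1024, 1024))
# ((16, 16), (32, 32), 32): a 16x16 grid of workgroups, each wave
# accumulating a 32x32 tile of C over 32 K-steps.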
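On the tolerance choice in the final assertion: the kernel accumulates in f32 and the test downcasts that result to f16 before comparing against torch.matmul on the f16 inputs, so atol=1e-2 / rtol=1e-2 mostly absorbs the downcast rounding rather than accumulation error. A minimal standalone illustration (a sketch independent of the wave runtime; sizes are arbitrary):

import torch
from torch.testing import assert_close

torch.manual_seed(0)
a = torch.randn(128, 64, dtype=torch.float16)
b = torch.randn(64, 256, dtype=torch.float16)

c_f32 = a.float() @ b.float()        # stands in for the kernel's f32 output
c_f16 = c_f32.to(torch.float16)      # what the test actually compares
# The f32 -> f16 downcast costs at most ~2^-11 relative precision, which is
# well inside the atol + rtol * |ref| bound the test allows.
assert_close(c_f16.float(), c_f32, atol=1e-2, rtol=1e-2)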