Skip to content

Commit e2e178b

Browse files
authored
[Wave] Fix failing jupyter test (iree-org#854)
This PR fixes the failing jupyter notebook issue and adds a test for it. Signed-off-by: Harsh Menon <[email protected]>
1 parent bee5064 commit e2e178b

File tree

2 files changed

+122
-7
lines changed

2 files changed

+122
-7
lines changed

examples/jupyter/wave_gemm_example.ipynb

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@
5353
"BLOCK_N = sym.BLOCK_N\n",
5454
"BLOCK_K = sym.BLOCK_K\n",
5555
"\n",
56-
"# Define the address space for our memory\n",
57-
"ADDRESS_SPACE = sym.ADDRESS_SPACE\n",
58-
"GLOBAL_ADDRESS_SPACE = sym.GLOBAL_ADDRESS_SPACE"
56+
"# Define the address space for our memory buffers\n",
57+
"ADDRESS_SPACE_A = sym.ADDRESS_SPACE_A\n",
58+
"ADDRESS_SPACE_B = sym.ADDRESS_SPACE_B\n",
59+
"ADDRESS_SPACE_C = sym.ADDRESS_SPACE_C"
5960
]
6061
},
6162
{
@@ -94,9 +95,9 @@
9495
"\n",
9596
"@tkw.wave(constraints)\n",
9697
"def gemm(\n",
97-
" a: Memory[M, K, ADDRESS_SPACE, f16], # Input matrix A\n",
98-
" b: Memory[N, K, ADDRESS_SPACE, f16], # Input matrix B\n",
99-
" c: Memory[M, N, GLOBAL_ADDRESS_SPACE, f32], # Output matrix C\n",
98+
" a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A\n",
99+
" b: Memory[N, K, ADDRESS_SPACE_B, f16], # Input matrix B\n",
100+
" c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C\n",
100101
"):\n",
101102
" # Initialize the accumulator register with zeros\n",
102103
" c_reg = Register[M, N, f32](0.0)\n",
@@ -151,7 +152,9 @@
151152
"\n",
152153
" # Set hyperparameters for compilation\n",
153154
" hyperparams = {\n",
154-
" ADDRESS_SPACE: SHARED_ADDRESS_SPACE,\n",
155+
" ADDRESS_SPACE_A: SHARED_ADDRESS_SPACE,\n",
156+
" ADDRESS_SPACE_B: SHARED_ADDRESS_SPACE,\n",
157+
" ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,\n",
155158
" BLOCK_M: 64,\n",
156159
" BLOCK_N: 64,\n",
157160
" BLOCK_K: 32,\n",

tests/kernel/wave/jupyter_test.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from iree.turbine.kernel._support.indexing import sym
8+
from iree.turbine.kernel._support.dtype import f16, f32
9+
from iree.turbine.kernel.lang.wave_types import *
10+
from iree.turbine.kernel.lang.global_symbols import *
11+
from iree.turbine.kernel.wave.utils.run_utils import set_default_run_config
12+
import iree.turbine.kernel as tkl
13+
import iree.turbine.kernel.wave as tkw
14+
from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile
15+
from .common.utils import require_e2e
16+
import torch
17+
18+
# Define symbolic dimensions for our matrices
19+
M = sym.M # Rows of A and C
20+
N = sym.N # Rows of B and columns of C
21+
K = sym.K # Columns of A and B
22+
23+
# Define workgroup tile sizes
24+
BLOCK_M = sym.BLOCK_M
25+
BLOCK_N = sym.BLOCK_N
26+
BLOCK_K = sym.BLOCK_K
27+
28+
# Define the address space for our memory
29+
ADDRESS_SPACE_A = sym.ADDRESS_SPACE_A
30+
ADDRESS_SPACE_B = sym.ADDRESS_SPACE_B
31+
ADDRESS_SPACE_C = sym.ADDRESS_SPACE_C
32+
33+
# Define constraints for the kernel
34+
constraints = [
35+
tkw.WorkgroupConstraint(M, BLOCK_M, 0),
36+
tkw.WorkgroupConstraint(N, BLOCK_N, 1),
37+
tkw.TilingConstraint(K, BLOCK_K),
38+
tkw.WaveConstraint(M, BLOCK_M / 2),
39+
tkw.WaveConstraint(N, BLOCK_N / 2),
40+
tkw.HardwareConstraint(
41+
threads_per_wave=64,
42+
waves_per_block=(2, 2, 1),
43+
mma_type=tkw.MMAType.F32_16x16x16_F16,
44+
),
45+
]
46+
47+
48+
@tkw.wave(constraints)
49+
def gemm(
50+
a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A
51+
b: Memory[N, K, ADDRESS_SPACE_B, f16], # Input matrix B
52+
c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C
53+
):
54+
# Initialize the accumulator register with zeros
55+
c_reg = Register[M, N, f32](0.0)
56+
57+
# Iterate over the K dimension to compute the dot product
58+
@tkw.iterate(K, init_args=[c_reg])
59+
def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]:
60+
# Load elements from A and B
61+
a_reg = tkw.read(a)
62+
b_reg = tkw.read(b)
63+
64+
# Compute matrix multiplication and accumulate
65+
acc = tkw.mma(a_reg, b_reg, acc)
66+
return acc
67+
68+
# Store the final result to C
69+
tkw.write(repeat, c)
70+
71+
72+
@require_e2e
73+
def test_gemm():
74+
# Create test matrices
75+
m, n, k = 128, 256, 128 # Small dimensions for testing
76+
77+
# Initialize input matrices with random values
78+
torch.manual_seed(0)
79+
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
80+
b = torch.randn(n, k, dtype=torch.float16, device="cuda")
81+
c = torch.zeros(m, n, dtype=torch.float32, device="cuda")
82+
83+
# Set hyperparameters for compilation
84+
hyperparams = {
85+
ADDRESS_SPACE_A: SHARED_ADDRESS_SPACE,
86+
ADDRESS_SPACE_B: SHARED_ADDRESS_SPACE,
87+
ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
88+
BLOCK_M: 64,
89+
BLOCK_N: 64,
90+
BLOCK_K: 32,
91+
M: m,
92+
N: n,
93+
K: k,
94+
}
95+
96+
# Compile the kernel
97+
options = WaveCompileOptions(subs=hyperparams, canonicalize=True)
98+
options = set_default_run_config(options)
99+
compiled_gemm = wave_compile(options, gemm)
100+
101+
# Run the GEMM kernel
102+
compiled_gemm(a, b, c)
103+
104+
# Verify the result using PyTorch's matmul
105+
expected = torch.matmul(a, b.t())
106+
107+
# Check if results are close (accounting for floating-point precision)
108+
assert torch.allclose(
109+
c.to(torch.float16), expected, rtol=1e-2, atol=1e-2
110+
), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}"
111+
112+
print("GEMM test passed!")

0 commit comments

Comments
 (0)