Skip to content

Commit c89e4d7

Browse files
authored
Use turbine.runtime for generate_iree_ref (iree-org#861)
`generate_iree_ref` only used to get reference values for Wave tests and it also the last user of `inplace=False` flag in Wave. Signed-off-by: Ivan Butygin <[email protected]>
1 parent b3c3821 commit c89e4d7

File tree

7 files changed

+59
-51
lines changed

7 files changed

+59
-51
lines changed

iree/turbine/kernel/wave/iree_utils.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
# Copyright 2024 The IREE Authors
1+
# Copyright 2025 The IREE Authors
22
#
33
# Licensed under the Apache License v2.0 with LLVM Exceptions.
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
import torch
8-
from typing import Any
9-
from .utils.run_utils import compile_and_invoke
108
from ...support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM
11-
from .compile import WaveCompileOptions
9+
from iree.turbine.runtime.launch import Launchable
1210

1311

1412
def get_chain_mmt_asm(
1513
query_type: str, key_type: str, value_type: str, output_type: str
16-
) -> str:
14+
) -> tuple[str, str]:
1715
B, M, K1, input_dtype = query_type.split("x")
1816
B, K2, K1, input_dtype = key_type.split("x")
1917
B, N, K2, input_dtype = value_type.split("x")
@@ -22,7 +20,8 @@ def get_chain_mmt_asm(
2220
intermediate_cast_type = f"{B}x{K2}x{M}x{input_dtype}"
2321
transposed_cast_type = f"{B}x{M}x{K2}x{input_dtype}"
2422
transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
25-
return f"""
23+
return (
24+
f"""
2625
func.func @chain_mmt(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
2726
%c0 = arith.constant 0.0 : f32
2827
%init = tensor.empty() : tensor<{intermediate_output_type}>
@@ -39,12 +38,14 @@ def get_chain_mmt_asm(
3938
%init4 = tensor.empty() : tensor<{output_type}>
4039
%transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
4140
return %transpose2 : tensor<{output_type}>
42-
}}"""
41+
}}""",
42+
"chain_mmt",
43+
)
4344

4445

4546
def get_chain_mmt_f8_asm(
4647
query_type: str, key_type: str, value_type: str, output_type: str
47-
) -> str:
48+
) -> tuple[str, str]:
4849
B, M, K1, input_dtype = query_type.split("x")
4950
B, K2, K1, input_dtype = key_type.split("x")
5051
B, N, K2, input_dtype = value_type.split("x")
@@ -57,7 +58,8 @@ def get_chain_mmt_f8_asm(
5758
query_f8_type = "x".join([B, M, K1, f8_dtype])
5859
key_f8_type = "x".join([B, K2, K1, f8_dtype])
5960
value_f8_type = "x".join([B, N, K2, f8_dtype])
60-
return f"""
61+
return (
62+
f"""
6163
func.func @chain_mmt_f8(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
6264
%c0 = arith.constant 0.0 : f32
6365
%init = tensor.empty() : tensor<{intermediate_output_type}>
@@ -77,7 +79,9 @@ def get_chain_mmt_f8_asm(
7779
%init4 = tensor.empty() : tensor<{output_type}>
7880
%transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
7981
return %transpose2 : tensor<{output_type}>
80-
}}"""
82+
}}""",
83+
"chain_mmt_f8",
84+
)
8185

8286

8387
def get_mmt_asm(
@@ -86,7 +90,7 @@ def get_mmt_asm(
8690
acc_type: str,
8791
batch: bool = False,
8892
cast_fp8: bool = False,
89-
) -> str:
93+
) -> tuple[str, str]:
9094
acc_dtype = acc_type.split("x")[-1]
9195
operator = "batch_matmul_transpose_b" if batch else "matmul_transpose_b"
9296
func_name = "bmmt" if batch else "mmt"
@@ -118,14 +122,15 @@ def get_mmt_asm(
118122
outs(%inital_result: tensor<{acc_type}>) -> tensor<{acc_type}>
119123
return %result : tensor<{acc_type}>
120124
}}"""
121-
return matmul_function
125+
return matmul_function, func_name
122126

123127

124128
def get_conv_asm(
125129
conv_type: str, lhs_type: str, rhs_type: str, res_type: str, stride: int
126-
) -> str:
130+
) -> tuple[str, str]:
127131
res_dtype = res_type.split("x")[-1]
128-
return f"""
132+
return (
133+
f"""
129134
func.func @conv_{conv_type}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{res_type}> {{
130135
%c0 = arith.constant 0.0 : {res_dtype}
131136
%init = tensor.empty() : tensor<{res_type}>
@@ -135,7 +140,9 @@ def get_conv_asm(
135140
ins(%lhs, %rhs : tensor<{lhs_type}>, tensor<{rhs_type}>)
136141
outs(%inital_result : tensor<{res_type}>) -> tensor<{res_type}>
137142
return %result : tensor<{res_type}>
138-
}}"""
143+
}}""",
144+
f"conv_{conv_type}",
145+
)
139146

140147

141148
def dtype_str(dtype: torch.dtype) -> str:
@@ -153,7 +160,6 @@ def generate_iree_ref(
153160
kernel_type: str,
154161
kernel_inputs: list[torch.Tensor],
155162
kernel_outputs: list[torch.Tensor],
156-
options: WaveCompileOptions,
157163
):
158164
"""
159165
Generate a reference output for the given kernel type and arguments.
@@ -165,7 +171,7 @@ def generate_iree_ref(
165171
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
166172
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
167173
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
168-
asm = get_mmt_asm(
174+
asm, func_name = get_mmt_asm(
169175
lhs_type,
170176
rhs_type,
171177
acc_type,
@@ -176,37 +182,38 @@ def generate_iree_ref(
176182
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
177183
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
178184
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
179-
asm = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True)
185+
asm, func_name = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True)
180186
elif kernel_type == "chain_mmt":
181187
query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
182188
key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
183189
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
184190
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
185-
asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type)
191+
asm, func_name = get_chain_mmt_asm(
192+
query_type, key_type, value_type, output_type
193+
)
186194
elif kernel_type == "chain_mmt_f8":
187195
query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
188196
key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
189197
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
190198
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
191-
asm = get_chain_mmt_f8_asm(query_type, key_type, value_type, output_type)
199+
asm, func_name = get_chain_mmt_f8_asm(
200+
query_type, key_type, value_type, output_type
201+
)
192202
elif kernel_type.startswith(conv_str):
193203
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
194204
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
195205
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
196206
conv_type = kernel_type[len(conv_str) :]
197-
asm = get_conv_asm(
207+
asm, func_name = get_conv_asm(
198208
conv_type, lhs_type, rhs_type, acc_type, int(kwargs["stride"])
199209
)
200210
else:
201211
raise ValueError(f"Unknown kernel type: {kernel_type}")
202212

203-
options.func_name = kernel_type
204-
options.inplace = False
205-
options.kernel_hash = None
206-
options.dynamic_symbols_map = {}
207-
compile_and_invoke(
208-
asm,
209-
kernel_inputs,
210-
kernel_outputs,
211-
options,
212-
)
213+
launchable = Launchable.jit_compile(asm, entry_point=func_name)
214+
res = launchable(*kernel_inputs, outputs=kernel_outputs)
215+
if len(kernel_outputs) == 1:
216+
kernel_outputs[0][:] = res
217+
else:
218+
for r, k in zip(res, kernel_outputs):
219+
k[:] = r

iree/turbine/kernel/wave/nn/linear.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import iree.turbine.kernel.lang as tkl
1313
import iree.turbine.kernel.wave as tkw
1414
from iree.turbine.kernel.lang.global_symbols import *
15-
from iree.turbine.kernel.wave.iree_utils import generate_iree_ref
1615
from iree.turbine.kernel.wave.utils.general_utils import (
1716
get_default_scheduling_params,
1817
)

iree/turbine/runtime/launch.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ def _resolve_target_binary(self, turbine_device: Device) -> _TargetBinary:
199199
f"Could not load a target binary for device {turbine_device}"
200200
)
201201

202-
def __call__(self, *args, device: Optional[torch.device] = None):
202+
def __call__(
203+
self, *args, device: Optional[torch.device] = None, outputs: Sequence[Any] = ()
204+
):
203205
turbine_device: Optional[Device] = (
204206
None if device is None else get_device_from_torch(device)
205207
)
@@ -238,7 +240,13 @@ def __call__(self, *args, device: Optional[torch.device] = None):
238240

239241
vm_context, vm_function = self._resolve_target_binary(turbine_device)
240242

241-
ret_list = VmVariantList(1)
243+
ret_list = VmVariantList(len(outputs))
244+
for output in outputs:
245+
if isinstance(output, Tensor):
246+
assert output.is_contiguous(), "Outputs must be contiguous"
247+
ret_list.push_ref(turbine_device.import_torch_tensor(arg))
248+
else:
249+
raise ValueError(f"Unsupported output type: {type(output)}")
242250

243251
invoke_vm_function(
244252
turbine_device, self._is_async, vm_context, vm_function, arg_list, ret_list

tests/kernel/wave/attention/chained_gemm_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def repeat(
171171
print(f"IR dumped to {filename}")
172172

173173
iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
174-
generate_iree_ref("chain_mmt", [q, k, v], [iree_ref], options)
174+
generate_iree_ref("chain_mmt", [q, k, v], [iree_ref])
175175
assert_close(output, iree_ref, check_device=False, atol=0, rtol=0)
176176

177177
torch_qk = torch.matmul(q, k.transpose(-1, -2))
@@ -322,5 +322,5 @@ def repeat(
322322
f.write(asm)
323323

324324
iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
325-
generate_iree_ref("chain_mmt_f8", [q, k, v], [iree_ref], options)
325+
generate_iree_ref("chain_mmt_f8", [q, k, v], [iree_ref])
326326
assert_close(output, iree_ref, atol=7e-5, rtol=2e-3, check_device=False)

tests/kernel/wave/runtime/cache_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
1212

1313
import copy
14-
import tempfile
1514
import pytest
1615
import torch
1716
from torch.testing import assert_close
@@ -27,10 +26,8 @@
2726
reset_cache_manager,
2827
)
2928
from iree.turbine.kernel.lang.global_symbols import *
30-
from iree.turbine.kernel.wave.iree_utils import generate_iree_ref
3129
from iree.turbine.kernel.wave.utils.run_utils import (
3230
set_default_run_config,
33-
get_default_arch,
3431
)
3532
from iree.turbine.kernel.wave.utils.general_utils import (
3633
get_default_scheduling_params,

tests/kernel/wave/wave_e2e_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,9 +1557,6 @@ def test_igemm_conv(
15571557
"conv_2d_" + layout,
15581558
[x, we],
15591559
[iree_ref],
1560-
options,
1561-
stride=stride,
1562-
run_bench=True,
15631560
)
15641561

15651562

tests/kernel/wave/wave_gemm_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
210210
dump_perf, "iree_" + perf_filename
211211
)
212212
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
213-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
213+
generate_iree_ref("mmt", [a, b], [iree_ref])
214214
assert_close(c, iree_ref, check_device=False)
215215

216216

@@ -328,7 +328,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
328328
dump_perf, "iree_" + perf_filename
329329
)
330330
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
331-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
331+
generate_iree_ref("mmt", [a, b], [iree_ref])
332332
assert_close(c, iree_ref, check_device=False)
333333

334334

@@ -478,7 +478,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
478478
dump_perf, "iree_" + perf_filename
479479
)
480480
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
481-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
481+
generate_iree_ref("mmt", [a, b], [iree_ref])
482482
assert_close(c, iree_ref, check_device=False, atol=1e-3, rtol=1e-3)
483483

484484

@@ -621,7 +621,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
621621
dump_perf, "iree_" + request.node.name + ".json"
622622
)
623623
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
624-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
624+
generate_iree_ref("mmt", [a, b], [iree_ref])
625625
assert_close(c, iree_ref, atol=2e-4, rtol=3e-4, check_device=False)
626626

627627

@@ -766,7 +766,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
766766
dump_perf, "iree_" + request.node.name + ".json"
767767
)
768768
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.int32)
769-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
769+
generate_iree_ref("mmt", [a, b], [iree_ref])
770770
assert_close(c, iree_ref, check_device=False)
771771

772772

@@ -879,7 +879,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
879879
dump_perf, "iree_" + request.node.name + ".json"
880880
)
881881
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.int32)
882-
generate_iree_ref("mmt", [a, b], [iree_ref], options)
882+
generate_iree_ref("mmt", [a, b], [iree_ref])
883883
assert_close(c, iree_ref, check_device=False)
884884

885885

@@ -989,7 +989,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
989989
dump_perf, "iree_" + request.node.name + ".json"
990990
)
991991
iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
992-
generate_iree_ref("mmt_f8", [a, b], [iree_ref], options)
992+
generate_iree_ref("mmt_f8", [a, b], [iree_ref])
993993
assert_close(c, iree_ref, atol=3e-5, rtol=3e-4, check_device=False)
994994

995995

@@ -1094,7 +1094,7 @@ def repeat(
10941094
dump_perf, "iree_" + request.node.name + ".json"
10951095
)
10961096
iree_ref = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
1097-
generate_iree_ref("bmmt", [a, b], [iree_ref], options)
1097+
generate_iree_ref("bmmt", [a, b], [iree_ref])
10981098
assert_close(c, iree_ref, check_device=False)
10991099

11001100
torch_ref = torch.matmul(a, b.transpose(-2, -1))
@@ -1206,7 +1206,7 @@ def repeat(
12061206
dump_perf, "iree_" + request.node.name + ".json"
12071207
)
12081208
iree_ref = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
1209-
generate_iree_ref("bmmt", [a, b], [iree_ref], options)
1209+
generate_iree_ref("bmmt", [a, b], [iree_ref])
12101210
assert_close(c, iree_ref, check_device=False)
12111211

12121212
torch_ref = torch.matmul(a, b.transpose(-2, -1))

0 commit comments

Comments
 (0)