1- # Copyright 2024 The IREE Authors
1+ # Copyright 2025 The IREE Authors
22#
33# Licensed under the Apache License v2.0 with LLVM Exceptions.
44# See https://llvm.org/LICENSE.txt for license information.
55# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
77import torch
8- from typing import Any
9- from .utils .run_utils import compile_and_invoke
108from ...support .conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM
11- from . compile import WaveCompileOptions
9+ from iree . turbine . runtime . launch import Launchable
1210
1311
1412def get_chain_mmt_asm (
1513 query_type : str , key_type : str , value_type : str , output_type : str
16- ) -> str :
14+ ) -> tuple [ str , str ] :
1715 B , M , K1 , input_dtype = query_type .split ("x" )
1816 B , K2 , K1 , input_dtype = key_type .split ("x" )
1917 B , N , K2 , input_dtype = value_type .split ("x" )
@@ -22,7 +20,8 @@ def get_chain_mmt_asm(
2220 intermediate_cast_type = f"{ B } x{ K2 } x{ M } x{ input_dtype } "
2321 transposed_cast_type = f"{ B } x{ M } x{ K2 } x{ input_dtype } "
2422 transposed_output_type = f"{ B } x{ M } x{ N } x{ output_dtype } "
25- return f"""
23+ return (
24+ f"""
2625 func.func @chain_mmt(%query: tensor<{ query_type } >, %key: tensor<{ key_type } >, %value: tensor<{ value_type } >) -> tensor<{ output_type } > {{
2726 %c0 = arith.constant 0.0 : f32
2827 %init = tensor.empty() : tensor<{ intermediate_output_type } >
@@ -39,12 +38,14 @@ def get_chain_mmt_asm(
3938 %init4 = tensor.empty() : tensor<{ output_type } >
4039 %transpose2 = linalg.transpose ins(%result2: tensor<{ transposed_output_type } >) outs(%init4: tensor<{ output_type } >) permutation=[0, 2, 1]
4140 return %transpose2 : tensor<{ output_type } >
42- }}"""
41+ }}""" ,
42+ "chain_mmt" ,
43+ )
4344
4445
4546def get_chain_mmt_f8_asm (
4647 query_type : str , key_type : str , value_type : str , output_type : str
47- ) -> str :
48+ ) -> tuple [ str , str ] :
4849 B , M , K1 , input_dtype = query_type .split ("x" )
4950 B , K2 , K1 , input_dtype = key_type .split ("x" )
5051 B , N , K2 , input_dtype = value_type .split ("x" )
@@ -57,7 +58,8 @@ def get_chain_mmt_f8_asm(
5758 query_f8_type = "x" .join ([B , M , K1 , f8_dtype ])
5859 key_f8_type = "x" .join ([B , K2 , K1 , f8_dtype ])
5960 value_f8_type = "x" .join ([B , N , K2 , f8_dtype ])
60- return f"""
61+ return (
62+ f"""
6163 func.func @chain_mmt_f8(%query: tensor<{ query_type } >, %key: tensor<{ key_type } >, %value: tensor<{ value_type } >) -> tensor<{ output_type } > {{
6264 %c0 = arith.constant 0.0 : f32
6365 %init = tensor.empty() : tensor<{ intermediate_output_type } >
@@ -77,7 +79,9 @@ def get_chain_mmt_f8_asm(
7779 %init4 = tensor.empty() : tensor<{ output_type } >
7880 %transpose2 = linalg.transpose ins(%result2: tensor<{ transposed_output_type } >) outs(%init4: tensor<{ output_type } >) permutation=[0, 2, 1]
7981 return %transpose2 : tensor<{ output_type } >
80- }}"""
82+ }}""" ,
83+ "chain_mmt_f8" ,
84+ )
8185
8286
8387def get_mmt_asm (
@@ -86,7 +90,7 @@ def get_mmt_asm(
8690 acc_type : str ,
8791 batch : bool = False ,
8892 cast_fp8 : bool = False ,
89- ) -> str :
93+ ) -> tuple [ str , str ] :
9094 acc_dtype = acc_type .split ("x" )[- 1 ]
9195 operator = "batch_matmul_transpose_b" if batch else "matmul_transpose_b"
9296 func_name = "bmmt" if batch else "mmt"
@@ -118,14 +122,15 @@ def get_mmt_asm(
118122 outs(%inital_result: tensor<{ acc_type } >) -> tensor<{ acc_type } >
119123 return %result : tensor<{ acc_type } >
120124 }}"""
121- return matmul_function
125+ return matmul_function , func_name
122126
123127
124128def get_conv_asm (
125129 conv_type : str , lhs_type : str , rhs_type : str , res_type : str , stride : int
126- ) -> str :
130+ ) -> tuple [ str , str ] :
127131 res_dtype = res_type .split ("x" )[- 1 ]
128- return f"""
132+ return (
133+ f"""
129134 func.func @conv_{ conv_type } (%lhs: tensor<{ lhs_type } >, %rhs: tensor<{ rhs_type } >) -> tensor<{ res_type } > {{
130135 %c0 = arith.constant 0.0 : { res_dtype }
131136 %init = tensor.empty() : tensor<{ res_type } >
@@ -135,7 +140,9 @@ def get_conv_asm(
135140 ins(%lhs, %rhs : tensor<{ lhs_type } >, tensor<{ rhs_type } >)
136141 outs(%inital_result : tensor<{ res_type } >) -> tensor<{ res_type } >
137142 return %result : tensor<{ res_type } >
138- }}"""
143+ }}""" ,
144+ f"conv_{ conv_type } " ,
145+ )
139146
140147
141148def dtype_str (dtype : torch .dtype ) -> str :
@@ -153,7 +160,6 @@ def generate_iree_ref(
153160 kernel_type : str ,
154161 kernel_inputs : list [torch .Tensor ],
155162 kernel_outputs : list [torch .Tensor ],
156- options : WaveCompileOptions ,
157163):
158164 """
159165 Generate a reference output for the given kernel type and arguments.
@@ -165,7 +171,7 @@ def generate_iree_ref(
165171 lhs_type = get_type_str (kernel_inputs [0 ].shape , kernel_inputs [0 ].dtype )
166172 rhs_type = get_type_str (kernel_inputs [1 ].shape , kernel_inputs [1 ].dtype )
167173 acc_type = get_type_str (kernel_outputs [0 ].shape , kernel_outputs [0 ].dtype )
168- asm = get_mmt_asm (
174+ asm , func_name = get_mmt_asm (
169175 lhs_type ,
170176 rhs_type ,
171177 acc_type ,
@@ -176,37 +182,38 @@ def generate_iree_ref(
176182 lhs_type = get_type_str (kernel_inputs [0 ].shape , kernel_inputs [0 ].dtype )
177183 rhs_type = get_type_str (kernel_inputs [1 ].shape , kernel_inputs [1 ].dtype )
178184 acc_type = get_type_str (kernel_outputs [0 ].shape , kernel_outputs [0 ].dtype )
179- asm = get_mmt_asm (lhs_type , rhs_type , acc_type , batch = True )
185+ asm , func_name = get_mmt_asm (lhs_type , rhs_type , acc_type , batch = True )
180186 elif kernel_type == "chain_mmt" :
181187 query_type = get_type_str (kernel_inputs [0 ].shape , kernel_inputs [0 ].dtype )
182188 key_type = get_type_str (kernel_inputs [1 ].shape , kernel_inputs [1 ].dtype )
183189 value_type = get_type_str (kernel_inputs [2 ].shape , kernel_inputs [2 ].dtype )
184190 output_type = get_type_str (kernel_outputs [0 ].shape , kernel_outputs [0 ].dtype )
185- asm = get_chain_mmt_asm (query_type , key_type , value_type , output_type )
191+ asm , func_name = get_chain_mmt_asm (
192+ query_type , key_type , value_type , output_type
193+ )
186194 elif kernel_type == "chain_mmt_f8" :
187195 query_type = get_type_str (kernel_inputs [0 ].shape , kernel_inputs [0 ].dtype )
188196 key_type = get_type_str (kernel_inputs [1 ].shape , kernel_inputs [1 ].dtype )
189197 value_type = get_type_str (kernel_inputs [2 ].shape , kernel_inputs [2 ].dtype )
190198 output_type = get_type_str (kernel_outputs [0 ].shape , kernel_outputs [0 ].dtype )
191- asm = get_chain_mmt_f8_asm (query_type , key_type , value_type , output_type )
199+ asm , func_name = get_chain_mmt_f8_asm (
200+ query_type , key_type , value_type , output_type
201+ )
192202 elif kernel_type .startswith (conv_str ):
193203 lhs_type = get_type_str (kernel_inputs [0 ].shape , kernel_inputs [0 ].dtype )
194204 rhs_type = get_type_str (kernel_inputs [1 ].shape , kernel_inputs [1 ].dtype )
195205 acc_type = get_type_str (kernel_outputs [0 ].shape , kernel_outputs [0 ].dtype )
196206 conv_type = kernel_type [len (conv_str ) :]
197- asm = get_conv_asm (
207+ asm , func_name = get_conv_asm (
198208 conv_type , lhs_type , rhs_type , acc_type , int (kwargs ["stride" ])
199209 )
200210 else :
201211 raise ValueError (f"Unknown kernel type: { kernel_type } " )
202212
203- options .func_name = kernel_type
204- options .inplace = False
205- options .kernel_hash = None
206- options .dynamic_symbols_map = {}
207- compile_and_invoke (
208- asm ,
209- kernel_inputs ,
210- kernel_outputs ,
211- options ,
212- )
213+ launchable = Launchable .jit_compile (asm , entry_point = func_name )
214+ res = launchable (* kernel_inputs , outputs = kernel_outputs )
215+ if len (kernel_outputs ) == 1 :
216+ kernel_outputs [0 ][:] = res
217+ else :
218+ for r , k in zip (res , kernel_outputs ):
219+ k [:] = r
0 commit comments