@@ -10,8 +10,6 @@
 from typing import Callable, Optional, Any
 import ctypes
 from ..compile_options import WaveCompileOptions
-from .compile_utils import compile_to_vmfb
-from .classes import KernelLaunchInfo
 from ..profiling import benchmark_module


@@ -24,12 +22,6 @@ def compute_grid(kernel_dynamic_dims: tuple[int], grid_fn: Callable):
     return [int(x) for x in grid_fn(list(kernel_dynamic_dims))]


-def _read_file(name, mode):
-    with open(name, mode) as file:
-        data = file.read()
-    return data
-
-
 def _write_file(name, mode, data):
     with open(name, mode) as file:
         file.write(data)
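
Note: the callback contract implied by `compute_grid` above is that `grid_fn` receives the dynamic dims as a list and returns the launch grid, which is then coerced to ints. A minimal sketch with a hypothetical 64x64 tiling (the lambda is illustrative, not from the repo):

```python
# Hypothetical grid callback: tile an (M, N) problem into 64x64 workgroups.
grid_fn = lambda dims: [dims[0] // 64, dims[1] // 64, 1]

# compute_grid coerces each entry to int, so this yields [2, 4, 1].
assert compute_grid((128, 256), grid_fn) == [2, 4, 1]
```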
@@ -50,37 +42,6 @@ def get_device_uuid(device_list: list[str], device_str: str) -> tuple[int, str]:
     return device_str


-def _invoke(vm_context, device, entry_function, inputs, outputs, dynamic_dims):
-    arg_list = rt.VmVariantList(len(inputs) + len(dynamic_dims))
-    ret_list = rt.VmVariantList(len(outputs))
-
-    for input in inputs:
-        if isinstance(input, torch.Tensor):
-            input_cpu = input.cpu().contiguous()
-            device_array = rt.asdevicearray(device, input_cpu)
-            arg_list.push_ref(device_array._buffer_view)
-        else:
-            raise ValueError(f"Unsupported input type: {type(input)}")
-
-    for dynamic_dim in dynamic_dims:
-        if isinstance(dynamic_dim, int):
-            arg_list.push_int(dynamic_dim)
-        else:
-            raise ValueError(f"Unsupported dynamic dim type: {type(dynamic_dim)}")
-
-    vm_context.invoke(entry_function, arg_list, ret_list)
-
-    for i, ret in enumerate(outputs):
-        device_buffer_view = rt.HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
-        device_array = rt.DeviceArray(device, device_buffer_view)
-
-        # TODO: Make to_host accept out array/buffer, so we can avoid extra data copy.
-        host_array = device_array.to_host()
-
-        # Convert to torch tensor without actually importing torch.
-        ret[:] = type(ret)(host_array)
-
-
 _dl_tensor_name = ctypes.create_string_buffer(b"dltensor")
 _set_capsule_name = ctypes.pythonapi.PyCapsule_SetName

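Note on the `_dl_tensor_name` / `_set_capsule_name` context lines kept above: the removed `_invoke` staged every tensor through host memory (`input.cpu()`, `to_host()`), while the surviving in-place path passes tensors by DLPack capsule. Per the DLPack convention, a capsule is named `"dltensor"` when produced and renamed `"used_dltensor"` once a consumer imports it, so resetting the name through the C API is what lets a capsule be handed out again. A minimal sketch of that convention; the renaming call is an assumption about how `_inplace_invoke` uses these module-level handles:

```python
import ctypes
import torch
from torch.utils.dlpack import to_dlpack

# PyCapsule_SetName stores the pointer rather than copying the string,
# so the name buffer must outlive the capsule; hence module scope.
_dl_tensor_name = ctypes.create_string_buffer(b"dltensor")
_set_capsule_name = ctypes.pythonapi.PyCapsule_SetName
_set_capsule_name.argtypes = [ctypes.py_object, ctypes.c_char_p]
_set_capsule_name.restype = ctypes.c_int

capsule = to_dlpack(torch.ones(4))
# After a consumer imports the capsule it is renamed "used_dltensor";
# resetting the name marks it as consumable again.
_set_capsule_name(capsule, _dl_tensor_name)
```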
@@ -173,14 +134,13 @@ def invoke_vmfb(
             options.benchmark_repetitions
         )

-    if options.inplace:
-        # Select device as the GPU, where input tensors are coming from.
-        device_list = tuple(
-            input.device
-            for input in kernel_inputs + kernel_outputs
-            if isinstance(input, torch.Tensor)
-        )
-        device = get_device_uuid(device_list, device)
+    # Select the GPU device that the input tensors live on.
+    device_list = tuple(
+        input.device
+        for input in kernel_inputs + kernel_outputs
+        if isinstance(input, torch.Tensor)
+    )
+    device = get_device_uuid(device_list, device)

     rt_config = rt.Config(device)
     device = rt_config.device
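
The hunk above makes UUID-based device selection unconditional: the IREE device is always derived from wherever the torch tensors live, rather than only when `options.inplace` was set. A rough sketch of the idea behind `get_device_uuid` (only `return device_str` is visible in this diff; the helper name, the URI format, and the torch `uuid` property are assumptions):

```python
import torch

def select_device_uri(device_list: tuple[torch.device, ...], device_str: str) -> str:
    """Map the torch device shared by all tensors to a runtime device URI."""
    gpu_devices = {d for d in device_list if d.type == "cuda"}
    if not gpu_devices:
        return device_str  # no GPU tensors; keep the configured device string
    if len(gpu_devices) > 1:
        raise ValueError(f"Tensors live on multiple devices: {gpu_devices}")
    index = gpu_devices.pop().index or 0
    # Assumption: torch exposes the GPU UUID on device properties and the
    # runtime accepts a UUID-addressed URI such as "hip://GPU-<uuid>".
    uuid = torch.cuda.get_device_properties(index).uuid
    return f"hip://GPU-{uuid}"
```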
@@ -202,24 +162,14 @@ def invoke_vmfb(
     if options.kernel_hash:
         RUNTIME_CACHE[options.kernel_hash] = (ctx, func)

-    if options.inplace:
-        _inplace_invoke(
-            ctx.vm_context,
-            device,
-            func,
-            kernel_inputs,
-            kernel_outputs,
-            options.dynamic_symbols_map.values(),
-        )
-    else:
-        _invoke(
-            ctx.vm_context,
-            device,
-            func,
-            kernel_inputs,
-            kernel_outputs,
-            options.dynamic_symbols_map.values(),
-        )
+    _inplace_invoke(
+        ctx.vm_context,
+        device,
+        func,
+        kernel_inputs,
+        kernel_outputs,
+        options.dynamic_symbols_map.values(),
+    )

     if options.run_bench:
         benchmark_results = benchmark_module(
@@ -278,21 +228,6 @@ def invoke_with_wave_runtime(
     wave_runtime.launch(kernel_launch_info, kernel_args, dyn_dims, scalar_args)


-def compile_and_invoke(
-    asm: str,
-    kernel_inputs: list[torch.Tensor],
-    kernel_outputs: list[torch.Tensor],
-    options: WaveCompileOptions,
-):
-    compiled_wave_vmfb = compile_to_vmfb(asm, options)
-    invoke_vmfb(
-        compiled_wave_vmfb,
-        options,
-        kernel_inputs,
-        kernel_outputs,
-    )
-
-
 def get_default_arch() -> str:
     """Return default ROCM architecture"""
     if not torch.cuda.is_available():
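
With `compile_and_invoke` removed, callers compose the two steps themselves. A sketch based directly on the deleted helper's body (`asm`, `options`, `kernel_inputs`, and `kernel_outputs` are supplied by the caller; whether call sites were migrated exactly this way is not shown in this diff):

```python
from .compile_utils import compile_to_vmfb  # import dropped from this module above

vmfb = compile_to_vmfb(asm, options)
invoke_vmfb(vmfb, options, kernel_inputs, kernel_outputs)
```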