Skip to content

Commit 3abd543

Browse files
authored
[TKW] Drop inplace flag (iree-org#863)
`generate_iree_ref` was switched to the `turbine.runtime` launcher (iree-org#861), and it was the last user of the `inplace=False` option, so drop the flag and related code. `test_scalar_codegen` doesn't really use this flag, as it isn't currently run with the IREE runtime.

Signed-off-by: Ivan Butygin <[email protected]>
1 parent 08e6244 commit 3abd543

File tree

4 files changed

+16
-83
lines changed

4 files changed

+16
-83
lines changed

iree/turbine/kernel/wave/compile_options.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class WaveCompileOptions:
2828
# === Runtime options ===
2929
kernel_launch_info: KernelLaunchInfo = field(default_factory=KernelLaunchInfo)
3030
kernel_usages: tuple[KernelBufferUsage] = None
31-
inplace: bool = True
3231

3332
# === Backend options ===
3433
backend: str = "rocm"

iree/turbine/kernel/wave/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def construct_inputs(
3131
bench_with_constant_weights = options.bench_with_constant_weights
3232
tempfiles = []
3333
inputs = []
34-
all_inputs = kernel_inputs + kernel_outputs if options.inplace else kernel_inputs
34+
all_inputs = kernel_inputs + kernel_outputs
3535
all_inputs += options.dynamic_symbols_map.values()
3636
if bench_with_constant_weights:
3737
for inp in all_inputs:

iree/turbine/kernel/wave/utils/run_utils.py

Lines changed: 15 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from typing import Callable, Optional, Any
1111
import ctypes
1212
from ..compile_options import WaveCompileOptions
13-
from .compile_utils import compile_to_vmfb
14-
from .classes import KernelLaunchInfo
1513
from ..profiling import benchmark_module
1614

1715

@@ -24,12 +22,6 @@ def compute_grid(kernel_dynamic_dims: tuple[int], grid_fn: Callable):
2422
return [int(x) for x in grid_fn(list(kernel_dynamic_dims))]
2523

2624

27-
def _read_file(name, mode):
28-
with open(name, mode) as file:
29-
data = file.read()
30-
return data
31-
32-
3325
def _write_file(name, mode, data):
3426
with open(name, mode) as file:
3527
file.write(data)
@@ -50,37 +42,6 @@ def get_device_uuid(device_list: list[str], device_str: str) -> tuple[int, str]:
5042
return device_str
5143

5244

53-
def _invoke(vm_context, device, entry_function, inputs, outputs, dynamic_dims):
54-
arg_list = rt.VmVariantList(len(inputs) + len(dynamic_dims))
55-
ret_list = rt.VmVariantList(len(outputs))
56-
57-
for input in inputs:
58-
if isinstance(input, torch.Tensor):
59-
input_cpu = input.cpu().contiguous()
60-
device_array = rt.asdevicearray(device, input_cpu)
61-
arg_list.push_ref(device_array._buffer_view)
62-
else:
63-
raise ValueError(f"Unsupported input type: {type(input)}")
64-
65-
for dynamic_dim in dynamic_dims:
66-
if isinstance(dynamic_dim, int):
67-
arg_list.push_int(dynamic_dim)
68-
else:
69-
raise ValueError(f"Unsupported dynamic dim type: {type(dynamic_dim)}")
70-
71-
vm_context.invoke(entry_function, arg_list, ret_list)
72-
73-
for i, ret in enumerate(outputs):
74-
device_buffer_view = rt.HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
75-
device_array = rt.DeviceArray(device, device_buffer_view)
76-
77-
# TODO: Make to_host accept out array/buffer, so we can avoid extra data copy.
78-
host_array = device_array.to_host()
79-
80-
# Convert to torch tensor without actually importing torch.
81-
ret[:] = type(ret)(host_array)
82-
83-
8445
_dl_tensor_name = ctypes.create_string_buffer(b"dltensor")
8546
_set_capsule_name = ctypes.pythonapi.PyCapsule_SetName
8647

@@ -173,14 +134,13 @@ def invoke_vmfb(
173134
options.benchmark_repetitions
174135
)
175136

176-
if options.inplace:
177-
# Select device as the GPU, where input tensors are coming from.
178-
device_list = tuple(
179-
input.device
180-
for input in kernel_inputs + kernel_outputs
181-
if isinstance(input, torch.Tensor)
182-
)
183-
device = get_device_uuid(device_list, device)
137+
# Select device as the GPU, where input tensors are coming from.
138+
device_list = tuple(
139+
input.device
140+
for input in kernel_inputs + kernel_outputs
141+
if isinstance(input, torch.Tensor)
142+
)
143+
device = get_device_uuid(device_list, device)
184144

185145
rt_config = rt.Config(device)
186146
device = rt_config.device
@@ -202,24 +162,14 @@ def invoke_vmfb(
202162
if options.kernel_hash:
203163
RUNTIME_CACHE[options.kernel_hash] = (ctx, func)
204164

205-
if options.inplace:
206-
_inplace_invoke(
207-
ctx.vm_context,
208-
device,
209-
func,
210-
kernel_inputs,
211-
kernel_outputs,
212-
options.dynamic_symbols_map.values(),
213-
)
214-
else:
215-
_invoke(
216-
ctx.vm_context,
217-
device,
218-
func,
219-
kernel_inputs,
220-
kernel_outputs,
221-
options.dynamic_symbols_map.values(),
222-
)
165+
_inplace_invoke(
166+
ctx.vm_context,
167+
device,
168+
func,
169+
kernel_inputs,
170+
kernel_outputs,
171+
options.dynamic_symbols_map.values(),
172+
)
223173

224174
if options.run_bench:
225175
benchmark_results = benchmark_module(
@@ -278,21 +228,6 @@ def invoke_with_wave_runtime(
278228
wave_runtime.launch(kernel_launch_info, kernel_args, dyn_dims, scalar_args)
279229

280230

281-
def compile_and_invoke(
282-
asm: str,
283-
kernel_inputs: list[torch.Tensor],
284-
kernel_outputs: list[torch.Tensor],
285-
options: WaveCompileOptions,
286-
):
287-
compiled_wave_vmfb = compile_to_vmfb(asm, options)
288-
invoke_vmfb(
289-
compiled_wave_vmfb,
290-
options,
291-
kernel_inputs,
292-
kernel_outputs,
293-
)
294-
295-
296231
def get_default_arch() -> str:
297232
"""Return default ROCM architecture"""
298233
if not torch.cuda.is_available():

tests/kernel/wave/wave_e2e_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,6 @@ def test(
16801680
},
16811681
canonicalize=True,
16821682
run_bench=run_bench,
1683-
inplace=False,
16841683
wave_runtime=True,
16851684
)
16861685
test = wave_compile(options, test)

0 commit comments

Comments (0)