
Commit 95448b2

Revert "[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (pytorch#154371)"
This reverts commit 65b1aed. Reverted pytorch#154371 on behalf of https://github.com/clee2000 due to see henry's comment above. This was reverted internally because it causes a memory leak and OOMs on AMD? ([comment](pytorch#154371 (comment)))
1 parent 30293b8 commit 95448b2
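
For orientation, the signature change being reverted can be summarized with a small, self-contained sketch. The class names below (OpOverload, HigherOrderOperator, Buffer) are hypothetical stand-ins for the real torch._ops and ir types; only the signature shapes, taken from the diff below, are the point.

from typing import Optional, Union

# Hypothetical stand-ins for torch._ops.OpOverload, torch._ops.HigherOrderOperator,
# and ir.Buffer; only the parameter shapes matter here.
class OpOverload: ...
class HigherOrderOperator: ...
class Buffer: ...

# Style removed by this revert: fully typed parameters and a nested-list alias.
_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]

def typed_style(
    op_overload: Union[OpOverload, HigherOrderOperator],
    raw_args: list[object],
    output_args: _OUTPUT_ARGS_TYPE,
    raw_outputs: list[Buffer],
) -> None:
    ...

# Style restored by this revert: loose typing with Optional parameters and defaults.
def restored_style(
    op_overload,
    raw_args,
    output_args: Optional[list[str]] = None,
    raw_outputs: Optional[list[Buffer]] = None,
):
    ...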

File tree

5 files changed (+175, -132 lines)


torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 73 additions & 79 deletions
@@ -7,12 +7,11 @@
 import sys
 import textwrap
 from itertools import chain, count
-from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, Union
+from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
 
 import sympy
 
 import torch
-import torch._higher_order_ops.torchbind
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 import torch._ops
 from torch._inductor.runtime.runtime_utils import dynamo_timed
@@ -39,9 +38,6 @@
 
 from ..graph import GraphLowering
 
-# At most, the list nesting can go one layer deep.
-_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
-
 
 class HasWriteLine(Protocol):
     def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
@@ -1884,18 +1880,17 @@ def codegen_while_loop(self, while_loop):
 
     def generate_extern_kernel_args_decl_if_needed(
         self,
-        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
-        raw_args: Sequence[Any],
-        output_args: _OUTPUT_ARGS_TYPE,
-        raw_outputs: Sequence[ir.Buffer],
+        op_overload,
+        raw_args,
+        output_args: Optional[list[str]] = None,
+        raw_outputs: Optional[list[ir.Buffer]] = None,
     ):
         schema = None
         if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
             obj = raw_args[0]
             method = raw_args[1]
             schema = op_overload.schema(obj, method)
         else:
-            assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
             schema = op_overload._schema
         assert schema is not None
         arg_types = [x.real_type for x in schema.arguments]
@@ -1991,9 +1986,7 @@ def fill_args(arg, arg_type):
             else:
                 fill_args(arg, arg_type)
 
-        def fill_output_arg(
-            arg: str, return_type: torch.JitType, is_mutated_output: bool
-        ) -> None:
+        def fill_output_arg(arg, return_type, is_mutated_output: bool):
             if isinstance(return_type, torch.TensorType):
                 if not is_mutated_output:
                     self.writeline(f"AtenTensorHandle {arg}_handle;  // output buffer")
@@ -2028,9 +2021,8 @@ def fill_output_arg(
             # None output is supported, but Optional return types are not yet supported
             if output_arg is None:
                 continue
-            elif isinstance(output_arg, list):
+            elif isinstance(output_arg, (list, tuple)):
                 for out in output_arg:
-                    assert out is not None, out
                     fill_output_arg(
                         out,
                         torch.TensorType.get(),
@@ -2049,73 +2041,73 @@ def generate_fallback_kernel_with_runtime_lookup(
         self,
         buf_name: str,
         python_kernel_name: str,
-        codegen_args: Sequence[str],
-        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
-        raw_args: Sequence[Any],
-        outputs: Sequence[ir.Buffer],
-    ) -> None:
-        """Generate a call to a kernel not contained in the C-shim. This results in
-        different code paths for AOT Inductor vs cpp_wrapper Inductor mode."""
-
-        def extract_output_name(
-            out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]],
-        ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]:
+        cpp_kernel_name: str,
+        codegen_args: list[str],
+        op_overload: Optional[torch._ops.OpOverload] = None,
+        raw_args=None,
+        outputs=None,
+    ):
+        def extract_output_name(out):
             if out is None:
                 return None
-            if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
+            elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
                 return out.get_name()
-            if isinstance(out, ir.MutationOutput):
+            elif isinstance(out, ir.MutationOutput):
                 mutated_buf_names = out.get_mutation_names()
                 assert (
                     isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
                 ), "Expect only one mutated buffer in MutationOutput"
                 return mutated_buf_names[0]
-            if isinstance(out, (list, tuple)):
-                return [extract_output_name(o) for o in out]  # type: ignore[misc]
-            raise AssertionError(f"Unexpected output: {type(out)}")
-
-        if isinstance(op_overload, torch._ops.HigherOrderOperator):
-            assert isinstance(
-                op_overload, torch._higher_order_ops.torchbind.CallTorchBind
-            ), type(op_overload)
-            assert len(raw_args) > 1
-            obj = raw_args[0]
-            method = raw_args[1]
-            return_schema = op_overload.schema(obj, method).returns
-        else:
-            return_schema = op_overload._schema.returns
+            elif isinstance(out, (list, tuple)):
+                return type(out)(extract_output_name(o) for o in out)
+            else:
+                raise AssertionError(f"Unexpected output: {type(out)}")
 
         # output_args has the same pytree structure as outputs
-        if not return_schema:
+
+        return_schema = None
+        if op_overload:
+            if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
+                assert raw_args is not None
+                assert len(raw_args) > 1
+                obj = raw_args[0]
+                method = raw_args[1]
+                return_schema = op_overload.schema(obj, method).returns
+            else:
+                return_schema = op_overload._schema.returns
+        if op_overload and not return_schema:
             # kernel does not return a value
-            output_args: _OUTPUT_ARGS_TYPE = []
-        elif isinstance(output_name := extract_output_name(outputs), str):
-            output_args = [output_name]
+            output_args = []
+        elif outputs is None:
+            # outputs is not specified, the default is to write to buf_name
+            output_args = [buf_name]
         else:
-            # If the schema indicates a return value, we should have a non-None value by
-            # this point.
-            assert isinstance(output_name, list), type(output_name)
-            output_args = output_name
+            output_args = extract_output_name(outputs)
+            if isinstance(output_args, str):
+                output_args = [output_args]
 
-        # In AOT mode, we use a ProxyExecutor to run fallback kernels.
         if V.graph.aot_mode:
-            self.generate_fallback_kernel_with_runtime_lookup_aot(
+            assert op_overload is not None
+            assert raw_args is not None
+            assert output_args is not None
+
+            return self.generate_fallback_kernel_with_runtime_lookup_aot(
                 op_overload,
                 raw_args,
                 output_args,
                 outputs,
             )
-            return
-
-        assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
-        self.generate_fallback_kernel_with_runtime_lookup_jit(
-            buf_name,
-            python_kernel_name,
-            op_overload,
-            raw_args,
-            output_args,  # type: ignore[arg-type]
-            outputs,
-        )
+        else:
+            return self.generate_fallback_kernel_with_runtime_lookup_jit(
+                buf_name,
+                python_kernel_name,
+                cpp_kernel_name,
+                codegen_args,
+                op_overload,
+                raw_args,
+                output_args,  # type: ignore[arg-type]
+                outputs,
+            )
 
     def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
         scoped_lines = IndentedBuffer()
@@ -2264,19 +2256,19 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
         self,
         buf_name: str,
         python_kernel_name: str,
-        op_overload: torch._ops.OpOverload,
-        raw_args: Sequence[Any],
-        output_args: Sequence[Optional[str]],
-        raw_outputs: Sequence[ir.Buffer],
-    ) -> None:
-        """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can
-        only be called in cpp_wrapper mode, and assumes that the input is a non-None
-        OpOverload.
-
-        This function calls into Python to dispatch, which allows it to handle datatypes
-        that cannot be contained in StableIValue, at the cost of some performance."""
+        cpp_kernel_name: str,
+        codegen_args: list[str],
+        op_overload: Optional[torch._ops.OpOverload] = None,
+        raw_args=None,
+        output_args: Optional[list[Optional[str]]] = None,
+        raw_outputs: Optional[list[ir.Buffer]] = None,
+    ):
+        # In the JIT mode, because of the ABI-compatible requirement, we can't directly call
+        # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
+        # to invoke this custom op.
        self.load_custom_op_wrapper()
 
+        assert output_args is not None, "output_args should not be None"
         num_args = len(raw_args)
         py_args_var = f"py_args_{next(self.arg_var_id)}"
         # First arg is always the python op name
@@ -2290,6 +2282,8 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
             """
         )
 
+        assert op_overload is not None, "op_overload should not be None"
+
         for idx, (raw_arg, schema_arg) in enumerate(
             zip(raw_args, op_overload._schema.arguments)
         ):
@@ -2340,11 +2334,11 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
 
     def generate_fallback_kernel_with_runtime_lookup_aot(
         self,
-        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
-        raw_args: Sequence[Any],
-        output_args: _OUTPUT_ARGS_TYPE,
-        raw_outputs: Sequence[ir.Buffer],
-    ) -> None:
+        op_overload,
+        raw_args,  # contains both args and flatten kwargs
+        output_args: Optional[list[str]] = None,
+        raw_outputs: Optional[list[ir.Buffer]] = None,
+    ):
         (
             tensor_call_args,
             int_call_args,
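
To make the restored output-naming logic concrete, here is a minimal, runnable sketch of the extract_output_name pattern reintroduced above. FakeMultiOutput is a hypothetical stand-in for ir.MultiOutput; the recursion and the type(out)(...) container-preserving call mirror the code in the diff.

class FakeMultiOutput:
    """Stand-in for ir.MultiOutput; only get_name() is modeled."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name

def extract_output_name(out):
    # None is preserved for absent outputs; list/tuple nesting is kept intact.
    if out is None:
        return None
    elif isinstance(out, FakeMultiOutput):
        return out.get_name()
    elif isinstance(out, (list, tuple)):
        # type(out)(...) keeps the container type: a list stays a list, a tuple stays a tuple.
        return type(out)(extract_output_name(o) for o in out)
    else:
        raise AssertionError(f"Unexpected output: {type(out)}")

outputs = [FakeMultiOutput("buf0"), None, (FakeMultiOutput("buf1"), FakeMultiOutput("buf2"))]
print(extract_output_name(outputs))  # ['buf0', None, ('buf1', 'buf2')]

The printed name tree has the same nesting as the outputs tree, which is what the "output_args has the same pytree structure as outputs" comment in the diff refers to.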

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 50 additions & 10 deletions
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional
 
 import sympy
 
@@ -750,16 +749,57 @@ def generate_fallback_kernel_with_runtime_lookup(
         self,
         buf_name: str,
         python_kernel_name: str,
-        codegen_args: Sequence[str],
-        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
-        raw_args: Sequence[Any],
-        outputs: Sequence[ir.Buffer],
-    ) -> None:
+        cpp_kernel_name: str,
+        codegen_args: list[str],
+        op_overload: Optional[torch._ops.OpOverload] = None,
+        raw_args=None,
+        outputs=None,
+    ):
         # No stack allocation when there is a fallback op
         self.allow_stack_allocation = False
-        super().generate_fallback_kernel_with_runtime_lookup(
-            buf_name, python_kernel_name, codegen_args, op_overload, raw_args, outputs
-        )
+
+        def extract_output_name(out):
+            if out is None:
+                return None
+            elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
+                return out.get_name()
+            elif isinstance(out, (list, tuple)):
+                return type(out)(extract_output_name(o) for o in out)
+            else:
+                raise AssertionError(f"Unexpected output: {type(out)}")
+
+        # output_args has the same pytree structure as outputs
+        output_args = None
+        if outputs is None:
+            # outputs is not specified, the default is to write to buf_name
+            output_args = [buf_name]
+        else:
+            output_args = extract_output_name(outputs)
+            if isinstance(output_args, str):
+                output_args = [output_args]
+
+        if V.graph.aot_mode:
+            assert op_overload is not None
+            assert raw_args is not None
+            assert outputs is not None
+
+            return self.generate_fallback_kernel_with_runtime_lookup_aot(
+                op_overload,
+                raw_args,
+                output_args,
+                outputs,
+            )
+        else:
+            return self.generate_fallback_kernel_with_runtime_lookup_jit(
+                buf_name,
+                python_kernel_name,
+                cpp_kernel_name,
+                codegen_args,
+                op_overload,
+                raw_args,
+                output_args,  # type: ignore[arg-type]
+                outputs,
+            )
 
     def codegen_device_copy(self, src, dst, non_blocking: bool):
         # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
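
With this revert, the ArrayRef wrapper repeats the AOT/JIT choice itself rather than delegating to super(). A rough sketch of that override pattern, with hypothetical class and method names, assuming only that the subclass clears a stack-allocation flag before dispatching:

class BaseWrapper:
    def __init__(self) -> None:
        self.allow_stack_allocation = True

    def gen_aot(self, name: str) -> str:
        return f"// AOT proxy-executor call for {name}"

    def gen_jit(self, name: str) -> str:
        return f"// JIT (Python re-dispatch) call for {name}"

class ArrayRefWrapper(BaseWrapper):
    def generate_fallback(self, name: str, aot_mode: bool) -> str:
        # No stack allocation when there is a fallback op (mirrors the flag above).
        self.allow_stack_allocation = False
        # Dispatch directly instead of calling super(), as in the restored code.
        return self.gen_aot(name) if aot_mode else self.gen_jit(name)

w = ArrayRefWrapper()
print(w.generate_fallback("aten::nonzero", aot_mode=False))
print(w.allow_stack_allocation)  # False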

torch/_inductor/codegen/wrapper.py

Lines changed: 6 additions & 5 deletions
@@ -1417,11 +1417,12 @@ def generate_fallback_kernel_with_runtime_lookup(
         self,
         buf_name: str,
         python_kernel_name: str,
-        codegen_args: Sequence[str],
-        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
-        raw_args: Sequence[Any],
-        outputs: Sequence[ir.Buffer],
-    ) -> None:
+        cpp_kernel_name: str,
+        codegen_args: list[str],
+        op_overload: Optional[torch._ops.OpOverload] = None,
+        raw_args=None,
+        outputs=None,
+    ):
         self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})")
 
     def generate(self, is_inference):
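
For the plain Python wrapper, a fallback kernel with runtime lookup reduces to emitting a single call line. A tiny sketch of what the f-string above produces, using hypothetical buffer and kernel names:

# Hypothetical inputs; the f-string mirrors the writeline call in the diff above.
buf_name = "buf3"
python_kernel_name = "torch.ops.aten.nonzero.default"
codegen_args = ["arg0_1"]

line = f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})"
print(line)  # buf3 = torch.ops.aten.nonzero.default(arg0_1)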

0 commit comments
