77import sys
88import textwrap
99from itertools import chain , count
10- from typing import Any , Callable , Optional , Protocol , TYPE_CHECKING , Union
10+ from typing import Callable , Optional , Protocol , TYPE_CHECKING , Union
1111
1212import sympy
1313
1414import torch
15- import torch ._higher_order_ops .torchbind
1615import torch ._inductor .async_compile # noqa: F401 required to warm up AsyncCompile pools
1716import torch ._ops
1817from torch ._inductor .runtime .runtime_utils import dynamo_timed
3938
4039 from ..graph import GraphLowering
4140
42- # At most, the list nesting can go one layer deep.
43- _OUTPUT_ARGS_TYPE = list [Union [Optional [str ], list [Optional [str ]]]]
44-
4541
class HasWriteLine (Protocol ):
    """Structural (duck-typed) interface: anything that accepts generated
    output one line at a time via ``writeline``.

    Implementors visible in this codebase include IndentedBuffer-style
    writers; a line may be a plain ``str``, a ``LineContext`` marker, or a
    ``DeferredLineBase`` whose text is resolved later.
    """

    def writeline (self , line : Union [LineContext , DeferredLineBase , str ]) -> None : ...
@@ -1884,18 +1880,17 @@ def codegen_while_loop(self, while_loop):
18841880
18851881 def generate_extern_kernel_args_decl_if_needed (
18861882 self ,
1887- op_overload : Union [ torch . _ops . OpOverload , torch . _ops . HigherOrderOperator ] ,
1888- raw_args : Sequence [ Any ] ,
1889- output_args : _OUTPUT_ARGS_TYPE ,
1890- raw_outputs : Sequence [ ir .Buffer ],
1883+ op_overload ,
1884+ raw_args ,
1885+ output_args : Optional [ list [ str ]] = None ,
1886+ raw_outputs : Optional [ list [ ir .Buffer ]] = None ,
18911887 ):
18921888 schema = None
18931889 if isinstance (op_overload , torch ._higher_order_ops .torchbind .CallTorchBind ):
18941890 obj = raw_args [0 ]
18951891 method = raw_args [1 ]
18961892 schema = op_overload .schema (obj , method )
18971893 else :
1898- assert isinstance (op_overload , torch ._ops .OpOverload ), type (op_overload )
18991894 schema = op_overload ._schema
19001895 assert schema is not None
19011896 arg_types = [x .real_type for x in schema .arguments ]
@@ -1991,9 +1986,7 @@ def fill_args(arg, arg_type):
19911986 else :
19921987 fill_args (arg , arg_type )
19931988
1994- def fill_output_arg (
1995- arg : str , return_type : torch .JitType , is_mutated_output : bool
1996- ) -> None :
1989+ def fill_output_arg (arg , return_type , is_mutated_output : bool ):
19971990 if isinstance (return_type , torch .TensorType ):
19981991 if not is_mutated_output :
19991992 self .writeline (f"AtenTensorHandle { arg } _handle; // output buffer" )
@@ -2028,9 +2021,8 @@ def fill_output_arg(
20282021 # None output is supported, but Optional return types are not yet supported
20292022 if output_arg is None :
20302023 continue
2031- elif isinstance (output_arg , list ):
2024+ elif isinstance (output_arg , ( list , tuple ) ):
20322025 for out in output_arg :
2033- assert out is not None , out
20342026 fill_output_arg (
20352027 out ,
20362028 torch .TensorType .get (),
@@ -2049,73 +2041,73 @@ def generate_fallback_kernel_with_runtime_lookup(
20492041 self ,
20502042 buf_name : str ,
20512043 python_kernel_name : str ,
2052- codegen_args : Sequence [str ],
2053- op_overload : Union [torch ._ops .OpOverload , torch ._ops .HigherOrderOperator ],
2054- raw_args : Sequence [Any ],
2055- outputs : Sequence [ir .Buffer ],
2056- ) -> None :
2057- """Generate a call to a kernel not contained in the C-shim. This results in
2058- different code paths for AOT Inductor vs cpp_wrapper Inductor mode."""
2059-
2060- def extract_output_name (
2061- out : Optional [Union [ir .Buffer , Sequence [ir .Buffer ]]],
2062- ) -> Union [Optional [str ], _OUTPUT_ARGS_TYPE ]:
2044+ cpp_kernel_name : str ,
2045+ codegen_args : list [str ],
2046+ op_overload : Optional [torch ._ops .OpOverload ] = None ,
2047+ raw_args = None ,
2048+ outputs = None ,
2049+ ):
2050+ def extract_output_name (out ):
20632051 if out is None :
20642052 return None
2065- if isinstance (out , (ir .MultiOutput , ir ._CollectiveKernel )):
2053+ elif isinstance (out , (ir .MultiOutput , ir ._CollectiveKernel )):
20662054 return out .get_name ()
2067- if isinstance (out , ir .MutationOutput ):
2055+ elif isinstance (out , ir .MutationOutput ):
20682056 mutated_buf_names = out .get_mutation_names ()
20692057 assert (
20702058 isinstance (mutated_buf_names , list ) and len (mutated_buf_names ) == 1
20712059 ), "Expect only one mutated buffer in MutationOutput"
20722060 return mutated_buf_names [0 ]
2073- if isinstance (out , (list , tuple )):
2074- return [extract_output_name (o ) for o in out ] # type: ignore[misc]
2075- raise AssertionError (f"Unexpected output: { type (out )} " )
2076-
2077- if isinstance (op_overload , torch ._ops .HigherOrderOperator ):
2078- assert isinstance (
2079- op_overload , torch ._higher_order_ops .torchbind .CallTorchBind
2080- ), type (op_overload )
2081- assert len (raw_args ) > 1
2082- obj = raw_args [0 ]
2083- method = raw_args [1 ]
2084- return_schema = op_overload .schema (obj , method ).returns
2085- else :
2086- return_schema = op_overload ._schema .returns
2061+ elif isinstance (out , (list , tuple )):
2062+ return type (out )(extract_output_name (o ) for o in out )
2063+ else :
2064+ raise AssertionError (f"Unexpected output: { type (out )} " )
20872065
20882066 # output_args has the same pytree structure as outputs
2089- if not return_schema :
2067+
2068+ return_schema = None
2069+ if op_overload :
2070+ if isinstance (op_overload , torch ._higher_order_ops .torchbind .CallTorchBind ):
2071+ assert raw_args is not None
2072+ assert len (raw_args ) > 1
2073+ obj = raw_args [0 ]
2074+ method = raw_args [1 ]
2075+ return_schema = op_overload .schema (obj , method ).returns
2076+ else :
2077+ return_schema = op_overload ._schema .returns
2078+ if op_overload and not return_schema :
20902079 # kernel does not return a value
2091- output_args : _OUTPUT_ARGS_TYPE = []
2092- elif isinstance (output_name := extract_output_name (outputs ), str ):
2093- output_args = [output_name ]
2080+ output_args = []
2081+ elif outputs is None :
2082+ # outputs is not specified, the default is to write to buf_name
2083+ output_args = [buf_name ]
20942084 else :
2095- # If the schema indicates a return value, we should have a non-None value by
2096- # this point.
2097- assert isinstance (output_name , list ), type (output_name )
2098- output_args = output_name
2085+ output_args = extract_output_name (outputs )
2086+ if isinstance (output_args , str ):
2087+ output_args = [output_args ]
20992088
2100- # In AOT mode, we use a ProxyExecutor to run fallback kernels.
21012089 if V .graph .aot_mode :
2102- self .generate_fallback_kernel_with_runtime_lookup_aot (
2090+ assert op_overload is not None
2091+ assert raw_args is not None
2092+ assert output_args is not None
2093+
2094+ return self .generate_fallback_kernel_with_runtime_lookup_aot (
21032095 op_overload ,
21042096 raw_args ,
21052097 output_args ,
21062098 outputs ,
21072099 )
2108- return
2109-
2110- assert isinstance ( op_overload , torch . _ops . OpOverload ), type ( op_overload )
2111- self . generate_fallback_kernel_with_runtime_lookup_jit (
2112- buf_name ,
2113- python_kernel_name ,
2114- op_overload ,
2115- raw_args ,
2116- output_args , # type: ignore[arg-type]
2117- outputs ,
2118- )
2100+ else :
2101+ return self . generate_fallback_kernel_with_runtime_lookup_jit (
2102+ buf_name ,
2103+ python_kernel_name ,
2104+ cpp_kernel_name ,
2105+ codegen_args ,
2106+ op_overload ,
2107+ raw_args ,
2108+ output_args , # type: ignore[arg-type]
2109+ outputs ,
2110+ )
21192111
21202112 def generate_scoped_gil_acquire (self , declarations_before_scope , lines_in_scope ):
21212113 scoped_lines = IndentedBuffer ()
@@ -2264,19 +2256,19 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
22642256 self ,
22652257 buf_name : str ,
22662258 python_kernel_name : str ,
2267- op_overload : torch ._ops .OpOverload ,
2268- raw_args : Sequence [Any ],
2269- output_args : Sequence [Optional [str ]],
2270- raw_outputs : Sequence [ir .Buffer ],
2271- ) -> None :
2272- """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can
2273- only be called in cpp_wrapper mode, and assumes that the input is a non-None
2274- OpOverload.
2275-
2276- This function calls into Python to dispatch, which allows it to handle datatypes
2277- that cannot be contained in StableIValue, at the cost of some performance."""
2259+ cpp_kernel_name : str ,
2260+ codegen_args : list [str ],
2261+ op_overload : Optional [torch ._ops .OpOverload ] = None ,
2262+ raw_args = None ,
2263+ output_args : Optional [list [Optional [str ]]] = None ,
2264+ raw_outputs : Optional [list [ir .Buffer ]] = None ,
2265+ ):
2266+ # In the JIT mode, because of the ABI-compatible requirement, we can't directly call
2267+ # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
2268+ # to invoke this custom op.
22782269 self .load_custom_op_wrapper ()
22792270
2271+ assert output_args is not None , "output_args should not be None"
22802272 num_args = len (raw_args )
22812273 py_args_var = f"py_args_{ next (self .arg_var_id )} "
22822274 # First arg is always the python op name
@@ -2290,6 +2282,8 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
22902282 """
22912283 )
22922284
2285+ assert op_overload is not None , "op_overload should not be None"
2286+
22932287 for idx , (raw_arg , schema_arg ) in enumerate (
22942288 zip (raw_args , op_overload ._schema .arguments )
22952289 ):
@@ -2340,11 +2334,11 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
23402334
23412335 def generate_fallback_kernel_with_runtime_lookup_aot (
23422336 self ,
2343- op_overload : Union [ torch . _ops . OpOverload , torch . _ops . HigherOrderOperator ] ,
2344- raw_args : Sequence [ Any ],
2345- output_args : _OUTPUT_ARGS_TYPE ,
2346- raw_outputs : Sequence [ ir .Buffer ],
2347- ) -> None :
2337+ op_overload ,
2338+ raw_args , # contains both args and flatten kwargs
2339+ output_args : Optional [ list [ str ]] = None ,
2340+ raw_outputs : Optional [ list [ ir .Buffer ]] = None ,
2341+ ):
23482342 (
23492343 tensor_call_args ,
23502344 int_call_args ,
0 commit comments