@@ -124,72 +124,6 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
124124 return False
125125
126126
127- def convert_arg_type (arg : torch .Argument ) -> str :
128- from .cpp import CONTAINER_PYTHON_TO_CPP , PYTHON_TO_CPP
129-
130- # use x.real_type instead of x.type so that we get ScalarType instead of int
131- python_type = repr (arg .real_type ) # type: ignore[attr-defined]
132-
133- if python_type == "Tensor" :
134- # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
135- if arg .alias_info is not None and arg .alias_info .is_write :
136- return f"at::{ python_type } &"
137- else :
138- return f"at::{ python_type } const&"
139-
140- if python_type in PYTHON_TO_CPP :
141- cpp_type = PYTHON_TO_CPP [python_type ]
142- return cpp_type
143-
144- # Convert args of container types e.g. Optional[*]
145- for py_container , cpp_container in CONTAINER_PYTHON_TO_CPP .items ():
146- container_match = re .findall (py_container + r"\[([a-zA-Z_]+)]" , python_type )
147- if len (container_match ) == 1 :
148- contained_type = container_match [0 ]
149- assert contained_type in PYTHON_TO_CPP , (
150- f"unsupported { py_container } type in convert_arg_type: { contained_type } "
151- )
152- cpp_contained_type = PYTHON_TO_CPP [contained_type ]
153- return f"{ cpp_container } <{ cpp_contained_type } >"
154-
155- raise AssertionError (f"unsupport python_type: { python_type } " )
156-
157-
158- def convert_return_type (ret : torch .Argument ) -> str :
159- # use x.real_type instead of x.type so that we get ScalarType instead of int
160- python_type = repr (ret .real_type ) # type: ignore[attr-defined]
161- python_to_cpp = {
162- "Tensor" : "at::Tensor" ,
163- "List[Tensor]" : "std::vector<at::Tensor>" ,
164- }
165-
166- cpp_type = python_to_cpp .get (python_type , None )
167- assert cpp_type is not None , f"NYI return type: { python_type } "
168- # An output aliasing an input is returned by reference only when it's a
169- # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
170- # aliases the input tensor, but the op returns a vector by value.
171- if python_type == "Tensor" and ret .alias_info is not None :
172- cpp_type += "&"
173- return cpp_type
174-
175-
176- def get_cpp_op_schema (kernel : torch ._ops .OpOverload ) -> str :
177- args = kernel ._schema .arguments
178- returns = kernel ._schema .returns
179-
180- num_returns = len (returns )
181- assert num_returns > 0 , "must have at least one return value"
182-
183- if num_returns == 1 :
184- cpp_return_value = convert_return_type (returns [0 ])
185- elif num_returns > 1 :
186- tuple_returns = ", " .join ([convert_return_type (r ) for r in returns ])
187- cpp_return_value = f"std::tuple<{ tuple_returns } >"
188-
189- cpp_arg_type = [f"{ convert_arg_type (arg )} { arg .name } " for arg in args ]
190- return f"{ cpp_return_value } ({ ', ' .join (cpp_arg_type )} )" # type: ignore[possibly-undefined]
191-
192-
193127# TODO: Move to a well known place
194128TritonMetaParams = dict [str , int ]
195129TritonGrid = Union [
0 commit comments