@@ -822,20 +822,12 @@ def codegen_tensor_item(
         if dtype == torch.float16 or dtype == torch.bfloat16:
             scalar_tmp = f"{scalar}_tmp"
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
-
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
-
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
             )
             writer.writeline(f"float {scalar} = float({scalar_tmp});")
         else:
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
-
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
-
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
             )
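For reference, a minimal standalone sketch (not the actual Inductor method) of what codegen_tensor_item now emits once `tensor` is no longer wrapped in convert_arrayref_tensor_to_tensor. DTYPE_TO_CPP is reduced to an illustrative subset here, and the buffer/scalar names are made up.

import torch

# Illustrative subset of the real DTYPE_TO_CPP mapping (assumed values).
DTYPE_TO_CPP = {torch.float16: "half", torch.float32: "float"}

def codegen_tensor_item_sketch(dtype, tensor, scalar, writer_lines):
    dtype_str = str(dtype).split(".")[-1]
    if dtype in (torch.float16, torch.bfloat16):
        # Reduced-precision scalars go through a typed temporary, then a float cast.
        tmp = f"{scalar}_tmp"
        writer_lines.append(f"{DTYPE_TO_CPP[dtype]} {tmp};")
        writer_lines.append(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{tmp}));"
        )
        writer_lines.append(f"float {scalar} = float({tmp});")
    else:
        writer_lines.append(f"{DTYPE_TO_CPP[dtype]} {scalar};")
        writer_lines.append(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
        )

lines = []
codegen_tensor_item_sketch(torch.float16, "buf0", "item0", lines)
# `lines` now holds the C++ statements the wrapper writes for a buf0.item() call,
# with buf0 passed to the aoti_torch_item_* shim directly.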
@@ -939,6 +931,7 @@ def generate_end(self, result):
             f"""
             '''
             )
+
             inductor_entry = CppWrapperCodeCache.load_pybinding(
                 ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)})
             """
@@ -1014,29 +1007,12 @@ def get_c_shim_func_name(self, kernel):
         return shim_fn

     def generate_c_shim_extern_kernel_call(self, kernel, args):
-        wrapped_args = []
         debug_printer_manager = V.graph.wrapper_code.debug_printer
-
-        for x in args:
-            pieces = x.split(", ")
-            for piece in pieces:
-                # We only really *need* convert_arrayref_tensor_to_tensor for
-                # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
-                # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
-                # so just avoid wrapping integers.
-                # Name matching is to find tensor is hacky, but fixing all the
-                # ArrayRefTensor issues is not a priority for now.
-                if isinstance(piece, str) and piece.startswith(
-                    ("buf", "arg", "wrap_with_raii_handle_if_needed")
-                ):
-                    piece = f"convert_arrayref_tensor_to_tensor({piece})"
-                wrapped_args.append(piece)
-
         debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
         with debug_printer_manager:
             shim_fn = self.get_c_shim_func_name(kernel)
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
             )

     def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
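With the piece-splitting and prefix-based wrapping removed, the extern-kernel emission reduces to a single join over the argument strings. A rough sketch of the resulting line shape; the shim and buffer names below are illustrative, not taken from a real graph.

def emit_extern_kernel_call(shim_fn, args):
    # Mirrors the simplified writeline above: the args are joined as-is.
    return f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"

print(emit_extern_kernel_call("aoti_torch_cpu_addmm_out", ["buf1", "arg0_1", "buf0", "arg1_1"]))
# AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf1, arg0_1, buf0, arg1_1));

The removed comment also explains why the old path had to skip integer arguments: the code uses `0` for nullptr, which convert_arrayref_tensor_to_tensor would blindly coerce to int. With no wrapping at all, that special-casing disappears.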
@@ -1121,15 +1097,8 @@ def generate_scatter_fallback(
         cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
         # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
         cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
-        inputs_wrapped = [
-            (
-                f"convert_arrayref_tensor_to_tensor({x})"
-                if isinstance(x, str)
-                else str(x)
-            )
-            for x in inputs
-        ]
-        line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        inputs_wrapped = [str(x) for x in inputs]
+        line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}"

         if python_kernel_name.startswith("aten.scatter_reduce"):
             line += f", {','.join(kwargs)}"
@@ -1150,25 +1119,16 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
         # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding
         # tensor prematurely deallocated, thus this std::vector().data() trick here.
         indices_str = (
-            "std::vector<AtenTensorHandle>{"
-            + (
-                ", ".join(
-                    [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
-                )
-            )
-            + "}.data()"
+            "std::vector<AtenTensorHandle>{" + (", ".join(indices)) + "}.data()"
         )
         args = [
-            f"convert_arrayref_tensor_to_tensor({x})",
+            x,
             indices_str,
             str(len(indices)),
-            f"convert_arrayref_tensor_to_tensor({values})",
+            values,
             accumulate,
         ]
-        args.insert(
-            0, f"convert_arrayref_tensor_to_tensor({x})"
-        )  # set x as the output tensor, this fallback mutates x.
-
+        args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
         self.writeline(self.wrap_kernel_call(kernel, args))

     def add_benchmark_harness(self, output):
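The std::vector().data() trick described in the comment above is kept; only the per-handle convert_arrayref_tensor_to_tensor wrapping goes away. A small sketch of how the argument list is now assembled, with made-up buffer names and a literal accumulate flag:

# Illustrative names; in the real codegen these are buffer names produced by Inductor.
x, values = "buf0", "buf3"
indices = ["buf1", "buf2"]

indices_str = "std::vector<AtenTensorHandle>{" + ", ".join(indices) + "}.data()"
args = [x, indices_str, str(len(indices)), values, "1"]  # "1" stands in for accumulate
args.insert(0, x)  # x is also the output: this fallback mutates x in place
print(", ".join(args))
# buf0, buf0, std::vector<AtenTensorHandle>{buf1, buf2}.data(), 2, buf3, 1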
@@ -1362,12 +1322,10 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
         return f"RAIIAtenTensorHandle({tmp_name})"

     def codegen_reinterpret_view(
-        self, data, size_list, stride_list, offset, writer, dtype=None
+        self, data, size, stride, offset, writer, dtype=None
     ) -> str:
-        dim = str(len(size_list))
+        dim = str(len(size))
         original_offset = offset
-        size = self.codegen_shape_tuple(size_list)
-        stride = self.codegen_shape_tuple(stride_list)
         offset = self.codegen_sizevar(offset)
         call_strs = []
         final_tmp_name = None
@@ -1379,15 +1337,15 @@ def create_reinterpret_call() -> Tuple[str, str]:
                 f"{data.get_name()}",
                 dim,
                 self.codegen_int_array_var(
-                    size,
+                    self.codegen_shape_tuple(size),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(size_list),
+                    known_statically=self.is_statically_known_list_of_ints(size),
                     graph=self.get_codegened_graph(),
                 ),
                 self.codegen_int_array_var(
-                    stride,
+                    self.codegen_shape_tuple(stride),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(stride_list),
+                    known_statically=self.is_statically_known_list_of_ints(stride),
                     graph=self.get_codegened_graph(),
                 ),
                 offset,
@@ -1417,8 +1375,8 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
             return tmp_RAIIAtenTensorHandle, call_strs

         if (
-            size_list == data.layout.size
-            and stride_list == data.layout.stride
+            size == data.layout.size
+            and stride == data.layout.stride
             and original_offset == data.layout.offset
         ):
             # pure dtypeview
@@ -1428,7 +1386,7 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
                 final_tmp_name = tmp_output_name
                 final_tmp_name_is_RAIIAtenTensorHandle = True
             else:
-                return f"{data.get_name()}"
+                return data.get_name()
         else:
             # firstly create reinterpretview
             final_tmp_name, reinterpret_call = create_reinterpret_call()
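Taken together, these hunks change codegen_reinterpret_view to accept the raw size and stride lists and defer the string conversion to the codegen_int_array_var call sites; keeping the lists around is what allows the direct comparison against data.layout.size and data.layout.stride in the dtypeview branch. A small sketch of the deferred conversion, with an assumed stand-in for codegen_shape_tuple:

# Sketch only: the real codegen_shape_tuple also handles symbolic sizes; this
# stand-in covers plain ints just to show where the list-to-string step now happens.
def codegen_shape_tuple_sketch(shape):
    return "{" + ", ".join(str(s) for s in shape) + "}"

size, stride = [8, 16], [16, 1]
layout_size, layout_stride = [8, 16], [16, 1]  # stand-ins for data.layout.size/stride

# The lists themselves remain available for structural checks ...
is_pure_dtypeview = size == layout_size and stride == layout_stride
# ... and are only rendered to C++ text when the int-array variable is emitted.
size_str = codegen_shape_tuple_sketch(size)      # "{8, 16}"
stride_str = codegen_shape_tuple_sketch(stride)  # "{16, 1}"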
@@ -2230,12 +2188,7 @@ def val_to_arg_str(self, val, type_=None) -> str:
         return self.val_to_arg_str_for_prim_type(val, type_)

     def create_tmp_raii_handle_var(self, base_handle):
-        if base_handle.startswith(
-            (
-                "convert_arrayref_tensor_to_tensor",
-                "wrap_with_raii_handle_if_needed",
-            )
-        ):
+        if base_handle.startswith(("wrap_with_raii_handle_if_needed",)):
             # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
             # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
             tmp_var_name = f"var_{next(self.arg_var_id)}"