
Commit fb2c750

desertfire authored and pytorchmergebot committed
[AOTI][refactor] Move convert_arrayref_tensor_to_tensor logic (pytorch#139030)
Summary: Move convert_arrayref_tensor_to_tensor codegen logic to cpp_wrapper_cpu_array_ref.py

Test Plan: CI

Differential Revision: D64904187

Pull Request resolved: pytorch#139030
Approved by: https://github.com/hl475
1 parent 949fdd2 commit fb2c750
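
In short, the commit moves per-argument convert_arrayref_tensor_to_tensor wrapping out of the generic CPU wrapper codegen and into the ArrayRef-specific code path. A minimal, self-contained Python sketch of that override pattern, using a hypothetical codegen_tensor_arg hook standing in for the several real methods touched below:

class CppWrapperCpu:
    # Generic CPU wrapper codegen: emit the tensor handle as-is.
    def codegen_tensor_arg(self, tensor):  # hypothetical hook name
        return tensor

class CppWrapperCpuArrayRef(CppWrapperCpu):
    # ArrayRef variant: ArrayRefTensor buffers must be converted to plain
    # AtenTensorHandles before being passed to C-shim functions.
    def codegen_tensor_arg(self, tensor):
        return f"convert_arrayref_tensor_to_tensor({tensor})"

print(CppWrapperCpu().codegen_tensor_arg("buf0"))
# buf0
print(CppWrapperCpuArrayRef().codegen_tensor_arg("buf0"))
# convert_arrayref_tensor_to_tensor(buf0)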

File tree

2 files changed: +74 -79 lines


torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 18 additions & 65 deletions
@@ -822,20 +822,12 @@ def codegen_tensor_item(
         if dtype == torch.float16 or dtype == torch.bfloat16:
             scalar_tmp = f"{scalar}_tmp"
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
-
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
-
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
             )
             writer.writeline(f"float {scalar} = float({scalar_tmp});")
         else:
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
-
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
-
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
             )
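
This hunk drops the ArrayRef conversion from the generic codegen_tensor_item; the same method, conversion included, is re-added to the ArrayRef subclass later in this commit. A runnable sketch (a stand-in DTYPE_TO_CPP entry and illustrative buffer/scalar names) of the C++ lines each variant emits for a float16 item:

DTYPE_TO_CPP = {"float16": "half"}  # stand-in mapping for this sketch only

def emit_item_lines(tensor, scalar, wrap_arrayref):
    lines = [f"{DTYPE_TO_CPP['float16']} {scalar}_tmp;"]
    if wrap_arrayref:  # only the ArrayRef variant converts the handle first
        tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
    lines.append(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float16({tensor}, &{scalar}_tmp));"
    )
    lines.append(f"float {scalar} = float({scalar}_tmp);")
    return lines

print(emit_item_lines("buf2", "s0", wrap_arrayref=False)[1])
# AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float16(buf2, &s0_tmp));
print(emit_item_lines("buf2", "s0", wrap_arrayref=True)[1])
# AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float16(convert_arrayref_tensor_to_tensor(buf2), &s0_tmp));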
@@ -939,6 +931,7 @@ def generate_end(self, result):
             f"""
             '''
             )
+
             inductor_entry = CppWrapperCodeCache.load_pybinding(
                 ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)})
             """
@@ -1014,29 +1007,12 @@ def get_c_shim_func_name(self, kernel):
         return shim_fn

     def generate_c_shim_extern_kernel_call(self, kernel, args):
-        wrapped_args = []
         debug_printer_manager = V.graph.wrapper_code.debug_printer
-
-        for x in args:
-            pieces = x.split(", ")
-            for piece in pieces:
-                # We only really *need* convert_arrayref_tensor_to_tensor for
-                # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
-                # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
-                # so just avoid wrapping integers.
-                # Name matching is to find tensor is hacky, but fixing all the
-                # ArrayRefTensor issues is not a priority for now.
-                if isinstance(piece, str) and piece.startswith(
-                    ("buf", "arg", "wrap_with_raii_handle_if_needed")
-                ):
-                    piece = f"convert_arrayref_tensor_to_tensor({piece})"
-                wrapped_args.append(piece)
-
         debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
         with debug_printer_manager:
             shim_fn = self.get_c_shim_func_name(kernel)
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
             )

     def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
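
For reference, the loop deleted here split each argument string on ", " and wrapped only pieces that looked like tensor names, because extern-call args use `0` for nullptr and convert_arrayref_tensor_to_tensor would blindly coerce an integer. A runnable sketch of that heuristic (the isinstance check is dropped, since str.split always yields strings):

def wrap_extern_args(args):
    wrapped = []
    for x in args:
        for piece in x.split(", "):
            # Name-prefix matching: wrap likely tensor handles, skip `0`/ints.
            if piece.startswith(("buf", "arg", "wrap_with_raii_handle_if_needed")):
                piece = f"convert_arrayref_tensor_to_tensor({piece})"
            wrapped.append(piece)
    return wrapped

print(wrap_extern_args(["buf0, 0, arg1_1"]))
# ['convert_arrayref_tensor_to_tensor(buf0)', '0', 'convert_arrayref_tensor_to_tensor(arg1_1)']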
@@ -1121,15 +1097,8 @@ def generate_scatter_fallback(
         cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
         # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
         cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
-        inputs_wrapped = [
-            (
-                f"convert_arrayref_tensor_to_tensor({x})"
-                if isinstance(x, str)
-                else str(x)
-            )
-            for x in inputs
-        ]
-        line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        inputs_wrapped = [str(x) for x in inputs]
+        line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}"

         if python_kernel_name.startswith("aten.scatter_reduce"):
             line += f", {','.join(kwargs)}"
@@ -1150,25 +1119,16 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
         # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding
         # tensor prematurely deallocated, thus this std::vector().data() trick here.
         indices_str = (
-            "std::vector<AtenTensorHandle>{"
-            + (
-                ", ".join(
-                    [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
-                )
-            )
-            + "}.data()"
+            "std::vector<AtenTensorHandle>{" + (", ".join(indices)) + "}.data()"
         )
         args = [
-            f"convert_arrayref_tensor_to_tensor({x})",
+            x,
             indices_str,
             str(len(indices)),
-            f"convert_arrayref_tensor_to_tensor({values})",
+            values,
             accumulate,
         ]
-        args.insert(
-            0, f"convert_arrayref_tensor_to_tensor({x})"
-        )  # set x as the output tensor, this fallback mutates x.
-
+        args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
         self.writeline(self.wrap_kernel_call(kernel, args))

     def add_benchmark_harness(self, output):
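
The std::vector trick the surviving comment refers to: materializing the index handles inside a std::vector<AtenTensorHandle>{...} keeps the temporaries alive for the whole enclosing C++ expression, whereas RAIIAtenTensorHandle temporaries in a braced array could be destroyed before the kernel call. A sketch of the string the wrapper now builds (illustrative handle names):

indices = ["tmp_tensor_handle_0", "tmp_tensor_handle_1"]  # illustrative names
indices_str = "std::vector<AtenTensorHandle>{" + ", ".join(indices) + "}.data()"
print(indices_str)
# std::vector<AtenTensorHandle>{tmp_tensor_handle_0, tmp_tensor_handle_1}.data()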
@@ -1362,12 +1322,10 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
         return f"RAIIAtenTensorHandle({tmp_name})"

     def codegen_reinterpret_view(
-        self, data, size_list, stride_list, offset, writer, dtype=None
+        self, data, size, stride, offset, writer, dtype=None
     ) -> str:
-        dim = str(len(size_list))
+        dim = str(len(size))
         original_offset = offset
-        size = self.codegen_shape_tuple(size_list)
-        stride = self.codegen_shape_tuple(stride_list)
         offset = self.codegen_sizevar(offset)
         call_strs = []
         final_tmp_name = None
@@ -1379,15 +1337,15 @@ def create_reinterpret_call() -> Tuple[str, str]:
                 f"{data.get_name()}",
                 dim,
                 self.codegen_int_array_var(
-                    size,
+                    self.codegen_shape_tuple(size),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(size_list),
+                    known_statically=self.is_statically_known_list_of_ints(size),
                     graph=self.get_codegened_graph(),
                 ),
                 self.codegen_int_array_var(
-                    stride,
+                    self.codegen_shape_tuple(stride),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(stride_list),
+                    known_statically=self.is_statically_known_list_of_ints(stride),
                     graph=self.get_codegened_graph(),
                 ),
                 offset,
@@ -1417,8 +1375,8 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
             return tmp_RAIIAtenTensorHandle, call_strs

         if (
-            size_list == data.layout.size
-            and stride_list == data.layout.stride
+            size == data.layout.size
+            and stride == data.layout.stride
             and original_offset == data.layout.offset
         ):
             # pure dtypeview
@@ -1428,7 +1386,7 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
                 final_tmp_name = tmp_output_name
                 final_tmp_name_is_RAIIAtenTensorHandle = True
             else:
-                return f"{data.get_name()}"
+                return data.get_name()
         else:
             # firstly create reinterpretview
             final_tmp_name, reinterpret_call = create_reinterpret_call()
@@ -2230,12 +2188,7 @@ def val_to_arg_str(self, val, type_=None) -> str:
             return self.val_to_arg_str_for_prim_type(val, type_)

     def create_tmp_raii_handle_var(self, base_handle):
-        if base_handle.startswith(
-            (
-                "convert_arrayref_tensor_to_tensor",
-                "wrap_with_raii_handle_if_needed",
-            )
-        ):
+        if base_handle.startswith(("wrap_with_raii_handle_if_needed",)):
             # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
             # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
             tmp_var_name = f"var_{next(self.arg_var_id)}"
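
A side effect of this file's codegen_reinterpret_view change is worth spelling out: size and stride now stay raw lists throughout the method instead of being stringified up front, so checks like size == data.layout.size compare the actual lists, and codegen_shape_tuple is applied only where C++ text is emitted. A sketch with a stand-in codegen_shape_tuple:

def codegen_shape_tuple(shape):  # stand-in for the real helper
    return "{" + ", ".join(str(s) for s in shape) + "}"

size = [2, 3]                     # kept as a list for structural checks
print(size == [2, 3])             # True: list comparison still meaningful
print(codegen_shape_tuple(size))  # "{2, 3}": rendered only at the use site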

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 56 additions & 14 deletions
@@ -775,12 +775,10 @@ def codegen_device_copy(self, src, dst, non_blocking: bool):
         )

     def codegen_reinterpret_view(
-        self, data, size_list, stride_list, offset, writer, dtype=None
+        self, data, size, stride, offset, writer, dtype=None
     ) -> str:
-        dim = str(len(size_list))
+        dim = str(len(size))
         original_offset = offset
-        size = self.codegen_shape_tuple(size_list)
-        stride = self.codegen_shape_tuple(stride_list)
         offset = self.codegen_sizevar(offset)
         call_strs = []
         final_tmp_name = None
@@ -792,15 +790,15 @@ def create_reinterpret_call() -> Tuple[str, str]:
                 f"{data.get_name()}",
                 dim,
                 self.codegen_int_array_var(
-                    size,
+                    self.codegen_shape_tuple(size),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(size_list),
+                    known_statically=self.is_statically_known_list_of_ints(size),
                     graph=self.get_codegened_graph(),
                 ),
                 self.codegen_int_array_var(
-                    stride,
+                    self.codegen_shape_tuple(stride),
                     writer,
-                    known_statically=self.is_statically_known_list_of_ints(stride_list),
+                    known_statically=self.is_statically_known_list_of_ints(stride),
                     graph=self.get_codegened_graph(),
                 ),
                 offset,
@@ -830,8 +828,8 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
             return tmp_RAIIAtenTensorHandle, call_strs

         if (
-            size_list == data.layout.size
-            and stride_list == data.layout.stride
+            size == data.layout.size
+            and stride == data.layout.stride
             and original_offset == data.layout.offset
         ):
             # pure dtypeview
@@ -841,7 +839,7 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:
                 final_tmp_name = tmp_output_name
                 final_tmp_name_is_RAIIAtenTensorHandle = True
             else:
-                return f"{data.get_name()}"
+                return data.get_name()
         else:
             # firstly create reinterpretview
             final_tmp_name, reinterpret_call = create_reinterpret_call()
@@ -861,9 +859,9 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]:

         if (
             self.can_stack_allocate_buffer(data)
-            and self.is_statically_known_list_of_ints(size_list)
-            and self.is_statically_known_list_of_ints(stride_list)
-            and ir.is_contiguous_strides_for_shape(stride_list, size_list)
+            and self.is_statically_known_list_of_ints(size)
+            and self.is_statically_known_list_of_ints(stride)
+            and ir.is_contiguous_strides_for_shape(stride, size)
         ):
             return final_tmp_name

@@ -986,3 +984,47 @@ def val_to_arg_str(self, val, type_=None) -> str:
             return f"{var_name}, {len(val)}"

         return self.val_to_arg_str_for_prim_type(val, type_)
+
+    def codegen_tensor_item(
+        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
+    ):
+        dtype_str = str(dtype).split(".")[-1]
+        writer = indented_buffer or self
+
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            scalar_tmp = f"{scalar}_tmp"
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
+
+            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
+            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
+
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
+            )
+            writer.writeline(f"float {scalar} = float({scalar_tmp});")
+        else:
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
+
+            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
+            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
+
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
+            )
+
+    def create_tmp_raii_handle_var(self, base_handle):
+        if base_handle.startswith(
+            (
+                "convert_arrayref_tensor_to_tensor",
+                "wrap_with_raii_handle_if_needed",
+            )
+        ):
+            # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
+            # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
+            tmp_var_name = f"var_{next(self.arg_var_id)}"
+            return (
+                tmp_var_name,
+                f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n",
+            )
+        else:
+            return "", ""
