
Commit 26471fc

angelayi authored and pytorchmergebot committed
[aoti] Initial Metal support (pytorch#153959)
An example generated file: P1816629015
Pull Request resolved: pytorch#153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: pytorch#153964
1 parent b33b7d5 commit 26471fc
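
For context, a minimal end-to-end sketch of the workflow this change enables, mirroring the test added to test/inductor/test_mps_basic.py below (it needs a machine with an MPS device; the "here.pt2" output path is taken from that test):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Export the model, AOT-compile it into a .pt2 package, then load and run it on MPS.
inp = (torch.ones(3, 3, device="mps"), torch.ones(3, 3, device="mps"))
m = M().to("mps")
ep = torch.export.export(m, inp)
path = torch._inductor.aoti_compile_and_package(ep, "here.pt2")
compiled = torch._inductor.aoti_load_package(path)
assert torch.allclose(compiled(*inp), m(*inp))
```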


11 files changed: +165, -17 lines


test/inductor/test_mps_basic.py

Lines changed: 16 additions & 0 deletions
@@ -181,6 +181,22 @@ def fn(x, y):
         )


+class MPSBasicTestsAOTI(TestCase):
+    def test_add_mps(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        inp = (torch.ones(3, 3, device="mps"), torch.ones(3, 3, device="mps"))
+        m = M().to("mps")
+        res2 = m(*inp)
+        ep = torch.export.export(m, inp)
+        path = torch._inductor.aoti_compile_and_package(ep, "here.pt2")
+        m = torch._inductor.aoti_load_package(path)
+        res = m(*inp)
+        assert torch.allclose(res, res2)
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_export/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -165,7 +165,8 @@ def aot_load(so_path: str, device: str) -> Callable:
         runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
     elif device == "xpu" or device.startswith("xpu:"):
         runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)  # type: ignore[assignment, call-arg]
-
+    elif device == "mps" or device.startswith("mps:"):
+        runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1)  # type: ignore[assignment, call-arg]
     else:
         raise RuntimeError("Unsupported device " + device)
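
A hedged usage sketch of the branch added above: aot_load pairs with aot_compile from the same module, which produces the shared library; the toy model is illustrative and not part of this change.

```python
import torch
from torch._export import aot_compile, aot_load

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

inp = (torch.ones(3, 3, device="mps"), torch.ones(3, 3, device="mps"))
# aot_compile returns the path of the compiled shared library.
so_path = aot_compile(M().to("mps"), inp)
# With this change, device="mps" selects AOTIModelContainerRunnerMps.
compiled = aot_load(so_path, device="mps")
print(compiled(*inp))
```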

torch/_inductor/codegen/common.py

Lines changed: 2 additions & 1 deletion
@@ -447,6 +447,7 @@ def init_backend_registration() -> None:
     from .cpp_wrapper_cpu import CppWrapperCpu
     from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
     from .cpp_wrapper_gpu import CppWrapperGpu
+    from .cpp_wrapper_mps import CppWrapperMps
     from .cuda_combined_scheduling import CUDACombinedScheduling
     from .halide import HalideScheduling
     from .mps import MetalScheduling
@@ -494,7 +495,7 @@ def init_backend_registration() -> None:
         "mps",
         MetalScheduling,
         PythonWrapperCodegen,
-        CppWrapperGpu,
+        CppWrapperMps,
     )

     private_backend = torch._C._get_privateuse1_backend_name()
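
For orientation, the arguments edited above belong to the MPS entry in init_backend_registration; a sketch of the updated registration, assuming the enclosing register_backend_for_device call from that function (the call name itself is not shown in the hunk):

```python
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp_wrapper_mps import CppWrapperMps
from torch._inductor.codegen.mps import MetalScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen

register_backend_for_device(
    "mps",
    MetalScheduling,       # Metal kernel scheduling / shader codegen
    PythonWrapperCodegen,  # Python wrapper used for JIT Inductor runs
    CppWrapperMps,         # new C++ wrapper used for AOTInductor (this PR)
)
```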

torch/_inductor/codegen/cpp_utils.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@
     "cpu": "at::kCPU",
     "cuda": "at::kCUDA",
     "xpu": "at::kXPU",
+    "mps": "at::kMPS",
 }

 LAYOUT_TO_ATEN = {

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 3 additions & 1 deletion
@@ -131,7 +131,9 @@ def _generate_kernel_call_helper(
         Only valid when cuda == True.
         """
         assert arg_types is not None and len(call_args) == len(arg_types), (
-            "Mismatch call_args and arg_types in generate_kernel_call"
+            "Mismatch call_args and arg_types in generate_kernel_call:\n"
+            f"call_args: {call_args}\n"
+            f"arg_types: {arg_types}"
         )
         new_args = []
         for idx, arg in enumerate(call_args):

torch/_inductor/codegen/cpp_wrapper_mps.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from typing import Any, Optional
+
+from ..ir import GraphPartitionSignature
+from ..virtualized import V
+from .cpp_wrapper_gpu import CppWrapperGpu
+from .wrapper import PythonWrapperCodegen
+
+
+class CppWrapperMps(CppWrapperGpu):
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[GraphPartitionSignature] = None,
+    ) -> "CppWrapperMps":
+        return CppWrapperMps()
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args: list[str],
+        **kwargs: dict[str, Any],
+    ) -> None:
+        """
+        Generates MPS kernel call code. It should look something like:
+        ```
+        auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
+        auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
+        mps_lib_0_func->runCommandBlock([&] {
+            mps_lib_0_func->startEncoding();
+            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0);
+            aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1);
+            ...
+            mps_lib_0_func->dispatch(9);
+        });
+        ```
+        """
+        new_args = []
+        for idx, arg in enumerate(call_args[:-2]):
+            new_args.append(
+                f"aoti_torch_mps_set_arg({kernel_name}_handle, {idx}, {arg});\n"
+            )
+
+        threads, group_size = call_args[-2], call_args[-1]
+        if threads is None:
+            raise NotImplementedError("No threads or group_size provided")
+        elif group_size is None:
+            new_args.append(f"{kernel_name}->dispatch({threads});\n")
+        else:
+            new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n")
+
+        # debug printer related logic for cpp kernel type.
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(
+            call_args[:-2],
+            kernel_name,
+            None,
+            None,
+            "cpp",
+        )
+        with debug_printer_manager:
+            self.writeline(self.wrap_kernel_call(kernel_name, new_args))
+
+    def wrap_kernel_call(self, name: str, call_args: list[str]) -> str:
+        lib_name = name[: -len("_func")]
+        calling_args = " ".join(call_args)
+        return f"""
+    auto {name} = {lib_name}.getKernelFunction("generated_kernel");
+    auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());
+    {name}->runCommandBlock([&] {{
+        {name}->startEncoding();
+        {calling_args}
+    }});
+    """
+
+    @staticmethod
+    def get_device_include_path(device: str) -> str:
+        assert V.graph.aot_mode
+        return (
+            "#include <torch/csrc/inductor/aoti_include/mps.h>\n"
+            "#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>"
+        )

torch/_inductor/codegen/mps.py

Lines changed: 47 additions & 11 deletions
@@ -775,11 +775,17 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:
         """Called at the end to generate a final kernel string"""
         self.codegen_body()
         code = IndentedBuffer()
-        code.writeline("compile_mps_shader('''")
+
+        if V.graph.cpp_wrapper:
+            code.writeline('(R"MTL(')
+        else:
+            code.writeline("compile_mps_shader('''")
+
         idx_vars = self.active_range_trees()
         with code.indent():
-            for header in self.headers:
-                code.writeline(f"#include <c10/metal/{header}.h>")
+            if not V.graph.cpp_wrapper:
+                for header in self.headers:
+                    code.writeline(f"#include <c10/metal/{header}.h>")
             if self.inside_reduction:
                 total_reduction_size = math.prod(
                     t.numel for t in self.range_trees if t.is_reduction
@@ -833,7 +839,11 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:
             code.splice(self.indexing_code)
            code.splice(self.body)
            code.writeline("}")
-        code.writeline("''')")
+
+        if V.graph.cpp_wrapper:
+            code.writeline(')MTL");')
+        else:
+            code.writeline("''')")

         return code.getvalue()
@@ -858,15 +868,31 @@ def call_kernel(self, name: str, node: Any = None) -> None:
                 )
                 for v in self.active_range_trees()
             ]
-            args += [f"threads=[{', '.join(threads)}]"]
+
+            if V.graph.cpp_wrapper:
+                args += [f"{', '.join(threads)}"]
+            else:
+                args += [f"threads=[{', '.join(threads)}]"]
+        else:
+            if V.graph.cpp_wrapper:
+                raise RuntimeError("We should always have threads?")
+
         if self.inside_reduction:
             threads = [
                 self.pexpr(sympy.Min(v.numel, self.max_threadgroup_size))  # type: ignore[misc]
                 if v.is_reduction
                 else "1"
                 for v in self.active_range_trees()
             ]
-            args += [f"group_size=[{', '.join(threads)}]"]
+            if V.graph.cpp_wrapper:
+                args += [f"{{{', '.join(threads)}}}"]
+            else:
+                args += [f"group_size=[{', '.join(threads)}]"]
+        else:
+            if V.graph.cpp_wrapper:
+                # Add a None so that we always have a group_size in the
+                # arguments. We won't use it if the value is None.
+                args += [None]  # type: ignore[list-item]

         wrapper.generate_kernel_call(
             name,
@@ -900,9 +926,10 @@ def __init__(self, scheduler: Optional[Scheduler]) -> None:
         super().__init__(scheduler)
         wrapper = V.graph.wrapper_code
         if wrapper is not None:
-            wrapper.header.splice(
-                "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
-            )
+            if not V.graph.cpp_wrapper:
+                wrapper.header.splice(
+                    "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
+                )

     def define_kernel(
         self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel
@@ -914,10 +941,19 @@ def define_kernel(
         # TODO: Merge multiple kernels into a single library
        # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling
         mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}"
-        kernel_name = f"{mps_lib_name}.generated_kernel"
+
+        if V.graph.cpp_wrapper:
+            src_code = (
+                f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
+                + src_code
+            )
+            kernel_name = f"{mps_lib_name}_func"
+        else:
+            kernel_name = f"{mps_lib_name}.generated_kernel"
+
         wrapper.src_to_kernel[src_code] = kernel_name
         origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
         metadata_comment = f"{origins}\n{detailed_origins}"
-        wrapper.define_kernel(mps_lib_name, src_code, metadata_comment)
+        wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False)

         return kernel_name

torch/_inductor/codegen/mps_device_op_overrides.py

Lines changed: 8 additions & 0 deletions
@@ -12,5 +12,13 @@ def set_device(self, device_idx: int) -> str:
         assert device_idx == 0
         return "pass # MPS set device"

+    def kernel_driver(self) -> str:
+        return """
+            #include <ATen/native/mps/MetalShaderLibrary.h>
+        """
+
+    def cpp_kernel_type(self) -> str:
+        return "MTLFunction_t"
+

 register_device_op_overrides("mps", MPSDeviceOpOverrides())

torch/_inductor/utils.py

Lines changed: 1 addition & 1 deletion
@@ -2475,7 +2475,7 @@ def is_gpu(device: Optional[str]) -> bool:


 def device_need_guard(device: str) -> bool:
-    return is_gpu(device)
+    return device != "mps" and is_gpu(device)  # TODO: MPS does not expose streams now


 def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:

torch/csrc/inductor/aoti_runner/model_container_runner_mps.h

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-#if !defined(C10_MOBILE) && !defined(ANDROID)
+#if defined(__APPLE__)
 #pragma once

 #include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
