Skip to content

Commit b33b7d5

Browse files
angelayipytorchmergebot
authored andcommitted
[aoti] Add MPS runner and shim (pytorch#153964)
Added AOTIModelContainerRunnerMps and a shim for mps fallback ops. I also added a mps-specific shim which contains one operator, which will be used to set arguments being passed to the Metal kernel: ``` AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg( AOTIMetalKernelFunctionHandle func, unsigned idx, AtenTensorHandle tensor); ``` Pull Request resolved: pytorch#153964 Approved by: https://github.com/malfet, https://github.com/desertfire
1 parent 269fa80 commit b33b7d5

File tree

15 files changed

+275
-1
lines changed

15 files changed

+275
-1
lines changed

caffe2/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
275275
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"
276276
)
277277
endif()
278+
if(USE_MPS)
279+
list(APPEND GENERATED_CXX_TORCH
280+
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_mps.cpp"
281+
)
282+
endif()
278283
endif()
279284

280285
set(GENERATED_H_TORCH
@@ -703,6 +708,8 @@ list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
703708

704709
if(USE_MPS)
705710
list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS})
711+
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp)
712+
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp)
706713
if(CAN_COMPILE_METAL)
707714
file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)
708715
list(APPEND Caffe2_CPU_SRCS ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)

torch/_C/_aoti.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def alloc_tensor_by_stealing_from_void_ptr(
1919
class AOTIModelContainerRunnerCpu: ...
2020
class AOTIModelContainerRunnerCuda: ...
2121
class AOTIModelContainerRunnerXpu: ...
22+
class AOTIModelContainerRunnerMps: ...
2223

2324
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
2425
class AOTIModelPackageLoader: ...
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#pragma once
2+
3+
#include <torch/csrc/inductor/aoti_include/common.h>
4+
#include <torch/csrc/inductor/cpp_wrapper/device_internal/mps.h>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#if defined(__APPLE__)
2+
#include <torch/csrc/inductor/aoti_runner/model_container_runner_mps.h>
3+
4+
namespace torch::inductor {
5+
6+
AOTIModelContainerRunnerMps::AOTIModelContainerRunnerMps(
7+
const std::string& model_so_path,
8+
size_t num_models,
9+
bool run_single_threaded)
10+
: AOTIModelContainerRunner(
11+
model_so_path,
12+
num_models,
13+
"mps",
14+
"",
15+
run_single_threaded) {}
16+
17+
AOTIModelContainerRunnerMps::~AOTIModelContainerRunnerMps() = default;
18+
19+
namespace {
20+
std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_mps(
21+
const std::string& model_so_path,
22+
size_t num_models,
23+
const std::string& device_str,
24+
const std::string& cubin_dir,
25+
const bool run_single_threaded) {
26+
if (device_str != "mps") {
27+
throw std::runtime_error("Incorrect device passed to aoti_runner_mps");
28+
}
29+
return std::make_unique<AOTIModelContainerRunnerMps>(
30+
model_so_path, num_models, run_single_threaded);
31+
}
32+
} // namespace
33+
34+
static RegisterAOTIModelRunner register_mps_runner(
35+
"mps",
36+
&create_aoti_runner_mps);
37+
38+
} // namespace torch::inductor
39+
#endif
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#if !defined(C10_MOBILE) && !defined(ANDROID)
2+
#pragma once
3+
4+
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
5+
6+
namespace torch::inductor {
7+
class TORCH_API AOTIModelContainerRunnerMps : public AOTIModelContainerRunner {
8+
public:
9+
AOTIModelContainerRunnerMps(
10+
const std::string& model_so_path,
11+
size_t num_models = 1,
12+
const bool run_single_threaded = false);
13+
14+
~AOTIModelContainerRunnerMps() override;
15+
};
16+
17+
} // namespace torch::inductor
18+
#endif

torch/csrc/inductor/aoti_runner/pybind.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#ifdef USE_XPU
66
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
77
#endif
8+
#ifdef __APPLE__
9+
#include <torch/csrc/inductor/aoti_runner/model_container_runner_mps.h>
10+
#endif
811
#include <torch/csrc/inductor/aoti_runner/pybind.h>
912
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
1013
#include <torch/csrc/inductor/aoti_torch/utils.h>
@@ -130,6 +133,41 @@ void initAOTIRunnerBindings(PyObject* module) {
130133
"free_inactive_constant_buffer",
131134
&AOTIModelContainerRunnerXpu::free_inactive_constant_buffer);
132135

136+
#endif
137+
#ifdef __APPLE__
138+
py::class_<AOTIModelContainerRunnerMps>(m, "AOTIModelContainerRunnerMps")
139+
.def(py::init<const std::string&, int>())
140+
.def(
141+
"run",
142+
&AOTIModelContainerRunnerMps::run,
143+
py::arg("inputs"),
144+
py::arg("stream_handle") = nullptr)
145+
.def("get_call_spec", &AOTIModelContainerRunnerMps::get_call_spec)
146+
.def(
147+
"get_constant_names_to_original_fqns",
148+
&AOTIModelContainerRunnerMps::getConstantNamesToOriginalFQNs)
149+
.def(
150+
"get_constant_names_to_dtypes",
151+
&AOTIModelContainerRunnerMps::getConstantNamesToDtypes)
152+
.def(
153+
"extract_constants_map",
154+
&AOTIModelContainerRunnerMps::extract_constants_map)
155+
.def(
156+
"update_constant_buffer",
157+
static_cast<void (AOTIModelContainerRunnerMps::*)(
158+
std::unordered_map<std::string, at::Tensor>&, bool, bool, bool)>(
159+
&AOTIModelContainerRunnerMps::update_constant_buffer),
160+
py::arg("tensor_map"),
161+
py::arg("use_inactive"),
162+
py::arg("validate_full_updates"),
163+
py::arg("user_managed") = false)
164+
.def(
165+
"swap_constant_buffer",
166+
&AOTIModelContainerRunnerMps::swap_constant_buffer)
167+
.def(
168+
"free_inactive_constant_buffer",
169+
&AOTIModelContainerRunnerMps::free_inactive_constant_buffer);
170+
133171
#endif
134172

135173
m.def(

torch/csrc/inductor/aoti_runtime/model.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ inline void parse_device_str(
100100
const std::string& device_str,
101101
int32_t& device_type,
102102
int32_t& device_idx) {
103-
std::regex re("(cpu|cuda|xpu)(:([0-9]+))?");
103+
std::regex re("(cpu|cuda|xpu|mps)(:([0-9]+))?");
104104
std::smatch sm;
105105
bool matched = std::regex_match(device_str, sm, re);
106106
AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str);
@@ -112,6 +112,10 @@ inline void parse_device_str(
112112
#ifdef USE_XPU
113113
} else if (sm[1].str() == "xpu") {
114114
device_type = aoti_torch_device_type_xpu();
115+
#endif
116+
#ifdef __APPLE__
117+
} else if (sm[1].str() == "mps") {
118+
device_type = aoti_torch_device_type_mps();
115119
#endif
116120
} else {
117121
AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str);

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
106106
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
107107
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_meta();
108108
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu();
109+
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_mps();
109110
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1();
110111

111112
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef AOTI_TORCH_SHIM_MPS
2+
#define AOTI_TORCH_SHIM_MPS
3+
4+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
5+
6+
#ifdef __cplusplus
7+
extern "C" {
8+
#endif
9+
10+
struct AOTIMetalKernelFunctionOpaque;
11+
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
12+
13+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
14+
AOTIMetalKernelFunctionHandle func,
15+
unsigned idx,
16+
AtenTensorHandle tensor);
17+
18+
#ifdef __cplusplus
19+
} // extern "C"
20+
#endif
21+
22+
#endif // AOTI_TORCH_SHIM_MPS

0 commit comments

Comments
 (0)