|
5 | 5 | #ifdef USE_XPU |
6 | 6 | #include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h> |
7 | 7 | #endif |
| 8 | +#ifdef __APPLE__ |
| 9 | +#include <torch/csrc/inductor/aoti_runner/model_container_runner_mps.h> |
| 10 | +#endif |
8 | 11 | #include <torch/csrc/inductor/aoti_runner/pybind.h> |
9 | 12 | #include <torch/csrc/inductor/aoti_torch/tensor_converter.h> |
10 | 13 | #include <torch/csrc/inductor/aoti_torch/utils.h> |
@@ -130,6 +133,41 @@ void initAOTIRunnerBindings(PyObject* module) { |
130 | 133 | "free_inactive_constant_buffer", |
131 | 134 | &AOTIModelContainerRunnerXpu::free_inactive_constant_buffer); |
132 | 135 |
|
| 136 | +#endif |
| 137 | +#ifdef __APPLE__ |
| 138 | + py::class_<AOTIModelContainerRunnerMps>(m, "AOTIModelContainerRunnerMps") |
| 139 | + .def(py::init<const std::string&, int>()) |
| 140 | + .def( |
| 141 | + "run", |
| 142 | + &AOTIModelContainerRunnerMps::run, |
| 143 | + py::arg("inputs"), |
| 144 | + py::arg("stream_handle") = nullptr) |
| 145 | + .def("get_call_spec", &AOTIModelContainerRunnerMps::get_call_spec) |
| 146 | + .def( |
| 147 | + "get_constant_names_to_original_fqns", |
| 148 | + &AOTIModelContainerRunnerMps::getConstantNamesToOriginalFQNs) |
| 149 | + .def( |
| 150 | + "get_constant_names_to_dtypes", |
| 151 | + &AOTIModelContainerRunnerMps::getConstantNamesToDtypes) |
| 152 | + .def( |
| 153 | + "extract_constants_map", |
| 154 | + &AOTIModelContainerRunnerMps::extract_constants_map) |
| 155 | + .def( |
| 156 | + "update_constant_buffer", |
| 157 | + static_cast<void (AOTIModelContainerRunnerMps::*)( |
| 158 | + std::unordered_map<std::string, at::Tensor>&, bool, bool, bool)>( |
| 159 | + &AOTIModelContainerRunnerMps::update_constant_buffer), |
| 160 | + py::arg("tensor_map"), |
| 161 | + py::arg("use_inactive"), |
| 162 | + py::arg("validate_full_updates"), |
| 163 | + py::arg("user_managed") = false) |
| 164 | + .def( |
| 165 | + "swap_constant_buffer", |
| 166 | + &AOTIModelContainerRunnerMps::swap_constant_buffer) |
| 167 | + .def( |
| 168 | + "free_inactive_constant_buffer", |
| 169 | + &AOTIModelContainerRunnerMps::free_inactive_constant_buffer); |
| 170 | + |
133 | 171 | #endif |
134 | 172 |
|
135 | 173 | m.def( |
|
0 commit comments