Commit bd1bb97

Authored by shoumikhin, committed by facebook-github-bot
Make IOManager use Module instead of Method. (pytorch#13542)
Summary: Let's not expose Method from Module, so it can't be misused beyond its owner. IOManager now holds a reference to its owning Module and identifies the prefill and decode methods by name instead of taking Method objects directly.

Differential Revision: D80595261

1 parent: 9359481 · commit: bd1bb97

10 files changed: +134 additions, −180 deletions
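The change is easiest to see in the shape of the IOManager interface. Below is a minimal before/after sketch of the load signature, condensed from the io_manager.h diff later on this page (namespaces abbreviated for readability):

```cpp
// Before: callers pulled loaded Method objects out of the Module and
// handed them to the IOManager.
//   ET_NODISCARD virtual runtime::Error load(
//       const Program& program, Method& prefill_method, Method& decode_method);

// After: callers name the methods; the IOManager resolves metadata itself
// through the Module reference it was constructed with.
//   ET_NODISCARD virtual runtime::Error load(
//       const std::string& prefill_method, const std::string& decode_method);
//   ET_NODISCARD runtime::Error load();  // convenience: both default to "forward"
```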

examples/models/llava/runner/llava_runner.h
Lines changed: 1 addition & 1 deletion

```diff
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-        io_manager_(std::make_unique<IOManager>()),
+        io_manager_(std::make_unique<IOManager>(*module_)),
         tokenizer_path_(tokenizer_path) {
     ET_LOG(
         Info,
```
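Note the initializer-list ordering: module_ is constructed before io_manager_, so passing *module_ into the IOManager constructor is safe, provided the members are declared in that same order in the class (which the initializer list suggests). The runner no longer needs to load methods itself before wiring up the IOManager; the IOManager resolves method metadata on demand through the Module.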

extension/llm/runner/io_manager/io_manager.h
Lines changed: 69 additions & 34 deletions

```diff
@@ -8,12 +8,8 @@
 
 #pragma once
 
-#include <vector>
-
+#include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/executor/method.h>
-#include <executorch/runtime/executor/method_meta.h>
 
 namespace executorch {
 namespace extension {
@@ -29,6 +25,9 @@ namespace llm {
  */
 class ET_EXPERIMENTAL IOManager {
  public:
+
+  explicit IOManager(ET_MODULE_NAMESPACE::Module& module) : module_(module) {}
+
   /**
    * @brief Virtual destructor to allow proper cleanup in derived classes.
    */
@@ -38,88 +37,111 @@ class ET_EXPERIMENTAL IOManager {
    * @brief Load the IO manager with method metadata for prefill and
    * decode operations.
    *
-   * @param program The program prefill and decode methods are loaded from.
    * @param prefill_method The prefill method to initialize with.
    * @param decode_method The decode method to initialize with.
    */
   ET_NODISCARD virtual runtime::Error load(
-      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    (void)program;
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error load() {
+    return load("forward", "forward");
+  }
+
   /**
    * @brief Reset the IO manager state.
    *
    * @param prefill_method The prefill method to reset with.
    * @param decode_method The decode method to reset with.
    */
   ET_NODISCARD virtual runtime::Error reset(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error reset() {
+    return reset("forward", "forward");
+  }
+
   /**
    * @brief Prepare inputs for the prefill phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param prefill_method The prefill method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the prefill method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_prefill(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
-    if (prefill_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& prefill_method) {
+    auto method_meta = module_.method_meta(prefill_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          prefill_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_prefill(input, start_pos, "forward");
+  }
+
   /**
    * @brief Prepare inputs for the decode phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param decode_method The decode method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the decode method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_decode(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    if (decode_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& decode_method) {
+    auto method_meta = module_.method_meta(decode_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          decode_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_decode(input, start_pos, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the prefill
    * phase.
@@ -128,14 +150,19 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the prefill method execution.
    */
   ET_NODISCARD virtual runtime::Error update_prefill(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)prefill_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& prefill_method) {
     (void)model_outputs;
+    (void)prefill_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error update_prefill(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_prefill(model_outputs, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the decode
    * phase.
@@ -144,13 +171,21 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the decode method execution.
    */
   ET_NODISCARD virtual runtime::Error update_decode(
-      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)decode_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& decode_method) {
     (void)model_outputs;
+    (void)decode_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
+
+  ET_NODISCARD runtime::Error update_decode(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_decode(model_outputs, "forward");
+  }
+
+ private:
+  ET_MODULE_NAMESPACE::Module& module_;
 };
 
 } // namespace llm
```
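With Method out of the public interface, a caller drives the whole flow through Module and method names. Below is a minimal sketch of how a runner might use the new API. Assumptions are labeled: the hypothetical helper run_prefill_once, the token values, and the make_tensor_ptr tensor-creation calls are illustrative, and the model's prefill/decode entry points are both named "forward" (the default):

```cpp
#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using executorch::extension::make_tensor_ptr;
using executorch::extension::Module;
using executorch::extension::llm::IOManager;
using executorch::runtime::Error;

// Hypothetical helper: one prefill pass over a short prompt.
Error run_prefill_once(const std::string& model_path) {
  Module module(model_path, Module::LoadMode::File);
  IOManager io_manager(module);  // Holds a reference; module must outlive it.

  // No method names given: prefill and decode both default to "forward".
  Error err = io_manager.load();
  if (err != Error::Ok) {
    return err;
  }

  // Illustrative token IDs and starting position (1-D int64 tensors).
  auto tokens = make_tensor_ptr<int64_t>({5, 7, 11});
  auto start_pos = make_tensor_ptr<int64_t>({0});

  // The IOManager checks input arity against the Module's method metadata.
  auto inputs = io_manager.prepare_prefill(tokens, start_pos);
  if (!inputs.ok()) {
    return inputs.error();
  }

  auto outputs = module.execute("forward", *inputs);
  if (!outputs.ok()) {
    return outputs.error();
  }
  return io_manager.update_prefill(*outputs);
}
```

The convenience overloads keep the common single-method case terse, while the string-keyed variants still let a subclass target differently named prefill and decode methods.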

extension/llm/runner/io_manager/targets.bzl
Lines changed: 2 additions & 3 deletions

```diff
@@ -11,10 +11,9 @@ def define_common_targets():
         exported_headers = [
             "io_manager.h",
         ],
-        deps = [
+        exported_deps = [
            "//executorch/extension/tensor:tensor" + aten_suffix,
-            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-            "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
+            "//executorch/extension/module:module" + aten_suffix,
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",
```

extension/llm/runner/io_manager/test/TARGETS
Lines changed: 4 additions & 6 deletions

```diff
@@ -10,14 +10,12 @@ define_common_targets()
 
 runtime.cxx_test(
     name = "test_io_manager",
-    srcs = ["test_io_manager.cpp"],
+    srcs = [
+        "test_io_manager.cpp",
+    ],
     deps = [
         "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/module:module",
-        "//executorch/extension/tensor:tensor",
-        "//executorch/runtime/executor:program",
-        "//executorch/kernels/portable:generated_lib",
+        "//executorch/kernels/portable:generated_lib",
     ],
     env = {
         "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
```
