Commit c351fa5

shoumikhin authored and facebook-github-bot committed
Make IOManager use Module instead of Method. (pytorch#13542)
Summary: Stop exposing Method from Module so that it can't be misused beyond its owner.

Differential Revision: D80595261
1 parent 45e7810 commit c351fa5
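
In practice, the change means a runner constructs the IOManager from its Module and refers to methods by name rather than holding raw Method references. Here is a minimal caller-side sketch of the new interface; the model path and token values are placeholders, not taken from this commit, and it assumes a .pte exported with a single "forward" method:

#include <cstdint>
#include <vector>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using executorch::extension::Module;
using executorch::extension::from_blob;
using executorch::extension::llm::IOManager;
using executorch::runtime::Error;

int main() {
  // IOManager now holds a reference to the Module and queries method
  // metadata itself, so callers never handle a raw Method.
  Module module("/path/to/model.pte", Module::LoadMode::File);  // placeholder path
  IOManager io_manager(module);

  // The no-arg overload defaults both method names to "forward".
  if (io_manager.load() != Error::Ok) {
    return 1;
  }

  // Prefill inputs: token IDs plus the starting position in the context.
  std::vector<int64_t> tokens = {1, 2, 3};  // placeholder token IDs
  int64_t pos = 0;
  auto input =
      from_blob(tokens.data(), {1, 3}, executorch::aten::ScalarType::Long);
  auto start_pos = from_blob(&pos, {1}, executorch::aten::ScalarType::Long);

  auto inputs = io_manager.prepare_prefill(input, start_pos);
  if (!inputs.ok()) {
    return 1;
  }

  // The prepared EValues are what the runner passes to Module::execute().
  auto outputs = module.execute("forward", inputs.get());
  return outputs.ok() ? 0 : 1;
}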

File tree

10 files changed: +134, -180 lines changed

examples/models/llava/runner/llava_runner.h

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-        io_manager_(std::make_unique<IOManager>()),
+        io_manager_(std::make_unique<IOManager>(*module_)),
         tokenizer_path_(tokenizer_path) {
     ET_LOG(
         Info,
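
Note the initializer-list ordering this hunk relies on: members are constructed in declaration order, so module_ must be declared before io_manager_ for *module_ to be valid when IOManager captures the reference. A stripped-down sketch of the pattern (the class name is a hypothetical stand-in, not an exact excerpt of LlavaRunner):

#include <memory>
#include <string>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>

using executorch::extension::Module;
using executorch::extension::llm::IOManager;

class Runner {  // hypothetical stand-in for LlavaRunner
 public:
  explicit Runner(const std::string& model_path)
      : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
        io_manager_(std::make_unique<IOManager>(*module_)) {}

 private:
  // Declaration order matters: module_ is constructed before io_manager_,
  // so the Module reference handed to IOManager is already valid.
  std::unique_ptr<Module> module_;
  std::unique_ptr<IOManager> io_manager_;
};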

extension/llm/runner/io_manager/io_manager.h

Lines changed: 68 additions & 34 deletions
@@ -8,12 +8,8 @@
 
 #pragma once
 
-#include <vector>
-
+#include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/executor/method.h>
-#include <executorch/runtime/executor/method_meta.h>
 
 namespace executorch {
 namespace extension {
@@ -29,6 +25,8 @@ namespace llm {
  */
 class ET_EXPERIMENTAL IOManager {
  public:
+  explicit IOManager(ET_MODULE_NAMESPACE::Module& module) : module_(module) {}
+
   /**
    * @brief Virtual destructor to allow proper cleanup in derived classes.
    */
@@ -38,88 +36,111 @@ class ET_EXPERIMENTAL IOManager {
    * @brief Load the IO manager with method metadata for prefill and
    * decode operations.
    *
-   * @param program The program prefill and decode methods are loaded from.
    * @param prefill_method The prefill method to initialize with.
    * @param decode_method The decode method to initialize with.
    */
   ET_NODISCARD virtual runtime::Error load(
-      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    (void)program;
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error load() {
+    return load("forward", "forward");
+  }
+
   /**
    * @brief Reset the IO manager state.
    *
    * @param prefill_method The prefill method to reset with.
    * @param decode_method The decode method to reset with.
    */
   ET_NODISCARD virtual runtime::Error reset(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error reset() {
+    return reset("forward", "forward");
+  }
+
   /**
    * @brief Prepare inputs for the prefill phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param prefill_method The prefill method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the prefill method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_prefill(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
-    if (prefill_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& prefill_method) {
+    auto method_meta = module_.method_meta(prefill_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          prefill_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_prefill(input, start_pos, "forward");
+  }
+
   /**
    * @brief Prepare inputs for the decode phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param decode_method The decode method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
   * for the decode method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_decode(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    if (decode_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& decode_method) {
+    auto method_meta = module_.method_meta(decode_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          decode_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_decode(input, start_pos, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the prefill
    * phase.
@@ -128,14 +149,19 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the prefill method execution.
    */
   ET_NODISCARD virtual runtime::Error update_prefill(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)prefill_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& prefill_method) {
     (void)model_outputs;
+    (void)prefill_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error update_prefill(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_prefill(model_outputs, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the decode
    * phase.
@@ -144,13 +170,21 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the decode method execution.
    */
   ET_NODISCARD virtual runtime::Error update_decode(
-      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)decode_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& decode_method) {
     (void)model_outputs;
+    (void)decode_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
+
+  ET_NODISCARD runtime::Error update_decode(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_decode(model_outputs, "forward");
+  }
+
+ private:
+  ET_MODULE_NAMESPACE::Module& module_;
 };
 
 } // namespace llm
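
Backends that need real IO preparation subclass this surface the same way as before, except methods are now identified by name and any metadata comes through a Module. Since the base class keeps its module_ member private, a subclass that needs metadata holds its own reference. A hypothetical sketch (ValidatingIOManager is illustrative, not part of this commit):

#include <string>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/platform/log.h>

using executorch::extension::Module;
using executorch::extension::llm::IOManager;
using executorch::runtime::Error;

// Hypothetical subclass that validates method metadata up front.
class ValidatingIOManager : public IOManager {
 public:
  explicit ValidatingIOManager(Module& module)
      : IOManager(module), module_(module) {}

  using IOManager::load;  // keep the no-arg "forward"/"forward" overload

  ET_NODISCARD Error load(
      const std::string& prefill_method,
      const std::string& decode_method) override {
    for (const auto& name : {prefill_method, decode_method}) {
      auto meta = module_.method_meta(name);
      if (!meta.ok()) {
        return meta.error();
      }
      ET_LOG(
          Info,
          "Method '%s' takes %zu inputs.",
          name.c_str(),
          meta->num_inputs());
    }
    return Error::Ok;
  }

 private:
  Module& module_;  // the base keeps its Module private, so hold our own
};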

extension/llm/runner/io_manager/targets.bzl

Lines changed: 2 additions & 3 deletions
@@ -11,10 +11,9 @@ def define_common_targets():
         exported_headers = [
             "io_manager.h",
         ],
-        deps = [
+        exported_deps = [
             "//executorch/extension/tensor:tensor" + aten_suffix,
-            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-            "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
+            "//executorch/extension/module:module" + aten_suffix,
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",

extension/llm/runner/io_manager/test/TARGETS

Lines changed: 4 additions & 6 deletions
@@ -10,14 +10,12 @@ define_common_targets()
 
 runtime.cxx_test(
     name = "test_io_manager",
-    srcs = ["test_io_manager.cpp"],
+    srcs = [
+        "test_io_manager.cpp",
+    ],
     deps = [
         "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/module:module",
-        "//executorch/extension/tensor:tensor",
-        "//executorch/runtime/executor:program",
-        "//executorch/kernels/portable:generated_lib",
+        "//executorch/kernels/portable:generated_lib",
     ],
     env = {
         "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
