Commit ba04f55

shoumikhin authored and facebook-github-bot committed
Make IOManager use Module instead of Method. (pytorch#13542)
Summary: Let's not expose Method from Module so that it's not getting misused beyond its owner.

Differential Revision: D80595261
1 parent 53146a4 commit ba04f55
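
As a minimal caller-side sketch of what this change means (the model path is illustrative; Module, IOManager, and the load() overload that defaults to "forward" all come from the diff below):

    // IOManager now binds to a Module and refers to methods by name,
    // so runtime::Method never escapes its owning Module.
    Module module("/path/to/model.pte", Module::LoadMode::File);  // illustrative path
    IOManager io_manager(module);
    const auto error = io_manager.load();  // equivalent to load("forward", "forward")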

File tree

9 files changed: +130 -175 lines changed

examples/models/llava/runner/llava_runner.h

1 addition, 1 deletion

@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-        io_manager_(std::make_unique<IOManager>()),
+        io_manager_(std::make_unique<IOManager>(*module_)),
         tokenizer_path_(tokenizer_path) {
     ET_LOG(
         Info,

extension/llm/runner/io_manager/io_manager.h

72 additions, 32 deletions

@@ -8,12 +8,8 @@
 
 #pragma once
 
-#include <vector>
-
+#include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/executor/method.h>
-#include <executorch/runtime/executor/method_meta.h>
 
 namespace executorch {
 namespace extension {
@@ -29,6 +25,10 @@ namespace llm {
  */
 class ET_EXPERIMENTAL IOManager {
  public:
+
+  explicit IOManager(ET_MODULE_NAMESPACE::Module& module)
+      : module_(module) {}
+
   /**
    * @brief Virtual destructor to allow proper cleanup in derived classes.
    */
@@ -38,88 +38,115 @@ class ET_EXPERIMENTAL IOManager {
    * @brief Load the IO manager with method metadata for prefill and
    * decode operations.
    *
-   * @param program The program prefill and decode methods are loaded from.
    * @param prefill_method The prefill method to initialize with.
    * @param decode_method The decode method to initialize with.
    */
   ET_NODISCARD virtual runtime::Error load(
-      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    (void)program;
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error load() {
+    return load("forward", "forward");
+  }
+
   /**
    * @brief Reset the IO manager state.
    *
    * @param prefill_method The prefill method to reset with.
    * @param decode_method The decode method to reset with.
    */
   ET_NODISCARD virtual runtime::Error reset(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error reset() {
+    return reset("forward", "forward");
+  }
+
   /**
    * @brief Prepare inputs for the prefill phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param prefill_method The prefill method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the prefill method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
+  virtual runtime::Result<std::vector<runtime::EValue>>
   prepare_prefill(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
-    if (prefill_method.inputs_size() != 2) {
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& prefill_method) {
+    auto method_meta = module_.method_meta(prefill_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          prefill_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>>
+  prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_prefill(input, start_pos, "forward");
+  }
+
   /**
    * @brief Prepare inputs for the decode phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param decode_method The decode method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the decode method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
+  virtual runtime::Result<std::vector<runtime::EValue>>
   prepare_decode(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    if (decode_method.inputs_size() != 2) {
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& decode_method) {
+    auto method_meta = module_.method_meta(decode_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          decode_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>>
+  prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_decode(input, start_pos, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the prefill
    * phase.
@@ -128,14 +155,19 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the prefill method execution.
    */
   ET_NODISCARD virtual runtime::Error update_prefill(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)prefill_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& prefill_method) {
     (void)model_outputs;
+    (void)prefill_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error update_prefill(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_prefill(model_outputs, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the decode
    * phase.
@@ -144,13 +176,21 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the decode method execution.
    */
   ET_NODISCARD virtual runtime::Error update_decode(
-      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)decode_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& decode_method) {
     (void)model_outputs;
+    (void)decode_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
+
+  ET_NODISCARD runtime::Error update_decode(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_decode(model_outputs, "forward");
+  }
+
+ private:
+  ET_MODULE_NAMESPACE::Module& module_;
 };
 
 } // namespace llm
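
To make the new surface concrete, here is a hedged sketch of one prefill step against this interface. The helper name run_prefill and the token values are invented for illustration; prepare_prefill, update_prefill, Module::execute, and make_tensor_ptr are existing ExecuTorch APIs:

    #include <executorch/extension/llm/runner/io_manager/io_manager.h>
    #include <executorch/extension/module/module.h>
    #include <executorch/extension/tensor/tensor.h>

    using executorch::extension::Module;
    using executorch::extension::make_tensor_ptr;
    using executorch::extension::llm::IOManager;

    // Hypothetical helper: one prefill pass using the string-based overloads.
    executorch::runtime::Error run_prefill(Module& module, IOManager& io_manager) {
      // Illustrative token IDs and starting position.
      auto tokens = make_tensor_ptr({1, 4}, std::vector<int64_t>{11, 22, 33, 44});
      auto start_pos = make_tensor_ptr({1}, std::vector<int64_t>{0});

      // Defaults to the "forward" method, per the overloads added above.
      auto inputs = io_manager.prepare_prefill(tokens, start_pos);
      if (!inputs.ok()) {
        return inputs.error();
      }
      // The runner, not the IOManager, still executes the method by name.
      auto outputs = module.execute("forward", *inputs);
      if (!outputs.ok()) {
        return outputs.error();
      }
      return io_manager.update_prefill(*outputs);
    }

Note that the IOManager never hands out a Method: it only consults module_.method_meta() internally, which is exactly the encapsulation the summary describes.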

extension/llm/runner/io_manager/targets.bzl

2 additions, 3 deletions

@@ -11,10 +11,9 @@ def define_common_targets():
         exported_headers = [
             "io_manager.h",
         ],
-        deps = [
+        exported_deps = [
             "//executorch/extension/tensor:tensor" + aten_suffix,
-            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-            "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
+            "//executorch/extension/module:module" + aten_suffix,
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",

extension/llm/runner/io_manager/test/TARGETS

4 additions, 6 deletions

@@ -10,14 +10,12 @@ define_common_targets()
 
 runtime.cxx_test(
     name = "test_io_manager",
-    srcs = ["test_io_manager.cpp"],
+    srcs = [
+        "test_io_manager.cpp",
+    ],
     deps = [
         "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/llm/runner/io_manager:io_manager",
-        "//executorch/extension/module:module",
-        "//executorch/extension/tensor:tensor",
-        "//executorch/runtime/executor:program",
-        "//executorch/kernels/portable:generated_lib",
+        "//executorch/kernels/portable:generated_lib",
     ],
     env = {
         "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
