88
99#pragma once
1010
11- #include < vector>
12-
11+ #include < executorch/extension/module/module.h>
1312#include < executorch/extension/tensor/tensor.h>
14- #include < executorch/runtime/core/exec_aten/exec_aten.h>
15- #include < executorch/runtime/executor/method.h>
16- #include < executorch/runtime/executor/method_meta.h>
1713
1814namespace executorch {
1915namespace extension {
@@ -29,6 +25,9 @@ namespace llm {
2925 */
3026class ET_EXPERIMENTAL IOManager {
3127 public:
28+
29+ explicit IOManager (ET_MODULE_NAMESPACE::Module& module ) : module_(module ) {}
30+
3231 /* *
3332 * @brief Virtual destructor to allow proper cleanup in derived classes.
3433 */
@@ -38,88 +37,111 @@ class ET_EXPERIMENTAL IOManager {
3837 * @brief Load the IO manager with method metadata for prefill and
3938 * decode operations.
4039 *
41- * @param program The program prefill and decode methods are loaded from.
4240 * @param prefill_method The prefill method to initialize with.
4341 * @param decode_method The decode method to initialize with.
4442 */
4543 ET_NODISCARD virtual runtime::Error load (
46- const executorch::ET_RUNTIME_NAMESPACE::Program& program,
47- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
48- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
49- (void )program;
44+ const std::string& prefill_method,
45+ const std::string& decode_method) {
5046 (void )prefill_method;
5147 (void )decode_method;
5248 return runtime::Error::Ok;
5349 }
5450
51+ ET_NODISCARD runtime::Error load () {
52+ return load (" forward" , " forward" );
53+ }
54+
5555 /* *
5656 * @brief Reset the IO manager state.
5757 *
5858 * @param prefill_method The prefill method to reset with.
5959 * @param decode_method The decode method to reset with.
6060 */
6161 ET_NODISCARD virtual runtime::Error reset (
62- executorch::ET_RUNTIME_NAMESPACE::Method & prefill_method,
63- executorch::ET_RUNTIME_NAMESPACE::Method & decode_method) {
62+ const std::string & prefill_method,
63+ const std::string & decode_method) {
6464 (void )prefill_method;
6565 (void )decode_method;
6666 return runtime::Error::Ok;
6767 }
6868
69+ ET_NODISCARD runtime::Error reset () {
70+ return reset (" forward" , " forward" );
71+ }
72+
6973 /* *
7074 * @brief Prepare inputs for the prefill phase of LLM inference.
7175 *
7276 * @param input The input tensor containing token IDs.
7377 * @param start_pos The tensor containing the starting position of the current
7478 * input within the context.
7579 * @param prefill_method The prefill method to prepare inputs for.
76- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
80+ * @return std::vector<runtime::EValue> Vector of prepared inputs
7781 * for the prefill method.
7882 */
79- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
80- prepare_prefill (
81- const executorch::extension::TensorPtr& input,
82- const executorch::extension::TensorPtr& start_pos,
83- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
84- if (prefill_method.inputs_size () != 2 ) {
83+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
84+ const TensorPtr& input,
85+ const TensorPtr& start_pos,
86+ const std::string& prefill_method) {
87+ auto method_meta = module_.method_meta (prefill_method);
88+ if (!method_meta.ok ()) {
89+ return method_meta.error ();
90+ }
91+ if (method_meta->num_inputs () != 2 ) {
8592 ET_LOG (
8693 Error,
8794 " Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
88- prefill_method. inputs_size ());
95+ method_meta-> num_inputs ());
8996 return runtime::Error::InvalidState;
9097 }
9198 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
9299 // here.
93100 return std::vector<runtime::EValue>{input, start_pos};
94101 }
95102
103+ runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
104+ const TensorPtr& input,
105+ const TensorPtr& start_pos) {
106+ return prepare_prefill (input, start_pos, " forward" );
107+ }
108+
96109 /* *
97110 * @brief Prepare inputs for the decode phase of LLM inference.
98111 *
99112 * @param input The input tensor containing token IDs.
100113 * @param start_pos The tensor containing the starting position of the current
101114 * input within the context.
102115 * @param decode_method The decode method to prepare inputs for.
103- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
116+ * @return std::vector<runtime::EValue> Vector of prepared inputs
104117 * for the decode method.
105118 */
106- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
107- prepare_decode (
108- const executorch::extension::TensorPtr& input,
109- const executorch::extension::TensorPtr& start_pos,
110- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
111- if (decode_method.inputs_size () != 2 ) {
119+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode (
120+ const TensorPtr& input,
121+ const TensorPtr& start_pos,
122+ const std::string& decode_method) {
123+ auto method_meta = module_.method_meta (decode_method);
124+ if (!method_meta.ok ()) {
125+ return method_meta.error ();
126+ }
127+ if (method_meta->num_inputs () != 2 ) {
112128 ET_LOG (
113129 Error,
114130 " Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
115- decode_method. inputs_size ());
131+ method_meta-> num_inputs ());
116132 return runtime::Error::InvalidState;
117133 }
118134 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
119135 // here.
120136 return std::vector<runtime::EValue>{input, start_pos};
121137 }
122138
139+ runtime::Result<std::vector<runtime::EValue>> prepare_decode (
140+ const TensorPtr& input,
141+ const TensorPtr& start_pos) {
142+ return prepare_decode (input, start_pos, " forward" );
143+ }
144+
123145 /* *
124146 * @brief Process and update internal state with outputs from the prefill
125147 * phase.
@@ -128,14 +150,19 @@ class ET_EXPERIMENTAL IOManager {
128150 * @param model_outputs Vector of outputs from the prefill method execution.
129151 */
130152 ET_NODISCARD virtual runtime::Error update_prefill (
131- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
132- const std::vector<executorch::runtime::EValue>& model_outputs) {
133- (void )prefill_method;
153+ const std::vector<runtime::EValue>& model_outputs,
154+ const std::string& prefill_method) {
134155 (void )model_outputs;
156+ (void )prefill_method;
135157 // No post inference work to do.
136158 return runtime::Error::Ok;
137159 }
138160
161+ ET_NODISCARD runtime::Error update_prefill (
162+ const std::vector<runtime::EValue>& model_outputs) {
163+ return update_prefill (model_outputs, " forward" );
164+ }
165+
139166 /* *
140167 * @brief Process and update internal state with outputs from the decode
141168 * phase.
@@ -144,13 +171,21 @@ class ET_EXPERIMENTAL IOManager {
144171 * @param model_outputs Vector of outputs from the decode method execution.
145172 */
146173 ET_NODISCARD virtual runtime::Error update_decode (
147- const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
148- const std::vector<executorch::runtime::EValue>& model_outputs) {
149- (void )decode_method;
174+ const std::vector<runtime::EValue>& model_outputs,
175+ const std::string& decode_method) {
150176 (void )model_outputs;
177+ (void )decode_method;
151178 // No post inference work to do.
152179 return runtime::Error::Ok;
153180 }
181+
182+ ET_NODISCARD runtime::Error update_decode (
183+ const std::vector<runtime::EValue>& model_outputs) {
184+ return update_decode (model_outputs, " forward" );
185+ }
186+
187+ private:
188+ ET_MODULE_NAMESPACE::Module& module_;
154189};
155190
156191} // namespace llm
0 commit comments