#pragma once

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

namespace executorch {
namespace extension {
 */
3026class ET_EXPERIMENTAL IOManager {
3127 public:
28+ explicit IOManager (ET_MODULE_NAMESPACE::Module& module ) : module_(module ) {}
29+
3230 /* *
3331 * @brief Virtual destructor to allow proper cleanup in derived classes.
3432 */
@@ -38,88 +36,111 @@ class ET_EXPERIMENTAL IOManager {
3836 * @brief Load the IO manager with method metadata for prefill and
3937 * decode operations.
4038 *
41- * @param program The program prefill and decode methods are loaded from.
4239 * @param prefill_method The prefill method to initialize with.
4340 * @param decode_method The decode method to initialize with.
4441 */
4542 ET_NODISCARD virtual runtime::Error load (
46- const executorch::ET_RUNTIME_NAMESPACE::Program& program,
47- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
48- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
49- (void )program;
43+ const std::string& prefill_method,
44+ const std::string& decode_method) {
5045 (void )prefill_method;
5146 (void )decode_method;
5247 return runtime::Error::Ok;
5348 }
5449
50+ ET_NODISCARD runtime::Error load () {
51+ return load (" forward" , " forward" );
52+ }
53+
5554 /* *
5655 * @brief Reset the IO manager state.
5756 *
5857 * @param prefill_method The prefill method to reset with.
5958 * @param decode_method The decode method to reset with.
6059 */
6160 ET_NODISCARD virtual runtime::Error reset (
62- executorch::ET_RUNTIME_NAMESPACE::Method & prefill_method,
63- executorch::ET_RUNTIME_NAMESPACE::Method & decode_method) {
61+ const std::string & prefill_method,
62+ const std::string & decode_method) {
6463 (void )prefill_method;
6564 (void )decode_method;
6665 return runtime::Error::Ok;
6766 }
6867
68+ ET_NODISCARD runtime::Error reset () {
69+ return reset (" forward" , " forward" );
70+ }
71+
6972 /* *
7073 * @brief Prepare inputs for the prefill phase of LLM inference.
7174 *
7275 * @param input The input tensor containing token IDs.
7376 * @param start_pos The tensor containing the starting position of the current
7477 * input within the context.
7578 * @param prefill_method The prefill method to prepare inputs for.
76- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
79+ * @return std::vector<runtime::EValue> Vector of prepared inputs
7780 * for the prefill method.
7881 */
79- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
80- prepare_prefill (
81- const executorch::extension::TensorPtr& input,
82- const executorch::extension::TensorPtr& start_pos,
83- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
84- if (prefill_method.inputs_size () != 2 ) {
82+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
83+ const TensorPtr& input,
84+ const TensorPtr& start_pos,
85+ const std::string& prefill_method) {
86+ auto method_meta = module_.method_meta (prefill_method);
87+ if (!method_meta.ok ()) {
88+ return method_meta.error ();
89+ }
90+ if (method_meta->num_inputs () != 2 ) {
8591 ET_LOG (
8692 Error,
8793 " Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
88- prefill_method. inputs_size ());
94+ method_meta-> num_inputs ());
8995 return runtime::Error::InvalidState;
9096 }
9197 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
9298 // here.
9399 return std::vector<runtime::EValue>{input, start_pos};
94100 }
95101
102+ runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
103+ const TensorPtr& input,
104+ const TensorPtr& start_pos) {
105+ return prepare_prefill (input, start_pos, " forward" );
106+ }
107+
96108 /* *
97109 * @brief Prepare inputs for the decode phase of LLM inference.
98110 *
99111 * @param input The input tensor containing token IDs.
100112 * @param start_pos The tensor containing the starting position of the current
101113 * input within the context.
102114 * @param decode_method The decode method to prepare inputs for.
103- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
115+ * @return std::vector<runtime::EValue> Vector of prepared inputs
104116 * for the decode method.
105117 */
106- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
107- prepare_decode (
108- const executorch::extension::TensorPtr& input,
109- const executorch::extension::TensorPtr& start_pos,
110- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
111- if (decode_method.inputs_size () != 2 ) {
118+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode (
119+ const TensorPtr& input,
120+ const TensorPtr& start_pos,
121+ const std::string& decode_method) {
122+ auto method_meta = module_.method_meta (decode_method);
123+ if (!method_meta.ok ()) {
124+ return method_meta.error ();
125+ }
126+ if (method_meta->num_inputs () != 2 ) {
112127 ET_LOG (
113128 Error,
114129 " Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
115- decode_method. inputs_size ());
130+ method_meta-> num_inputs ());
116131 return runtime::Error::InvalidState;
117132 }
118133 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
119134 // here.
120135 return std::vector<runtime::EValue>{input, start_pos};
121136 }
122137
138+ runtime::Result<std::vector<runtime::EValue>> prepare_decode (
139+ const TensorPtr& input,
140+ const TensorPtr& start_pos) {
141+ return prepare_decode (input, start_pos, " forward" );
142+ }
143+
123144 /* *
124145 * @brief Process and update internal state with outputs from the prefill
125146 * phase.
@@ -128,14 +149,19 @@ class ET_EXPERIMENTAL IOManager {
128149 * @param model_outputs Vector of outputs from the prefill method execution.
129150 */
130151 ET_NODISCARD virtual runtime::Error update_prefill (
131- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
132- const std::vector<executorch::runtime::EValue>& model_outputs) {
133- (void )prefill_method;
152+ const std::vector<runtime::EValue>& model_outputs,
153+ const std::string& prefill_method) {
134154 (void )model_outputs;
155+ (void )prefill_method;
135156 // No post inference work to do.
136157 return runtime::Error::Ok;
137158 }
138159
160+ ET_NODISCARD runtime::Error update_prefill (
161+ const std::vector<runtime::EValue>& model_outputs) {
162+ return update_prefill (model_outputs, " forward" );
163+ }
164+
139165 /* *
140166 * @brief Process and update internal state with outputs from the decode
141167 * phase.
@@ -144,13 +170,21 @@ class ET_EXPERIMENTAL IOManager {
144170 * @param model_outputs Vector of outputs from the decode method execution.
145171 */
146172 ET_NODISCARD virtual runtime::Error update_decode (
147- const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
148- const std::vector<executorch::runtime::EValue>& model_outputs) {
149- (void )decode_method;
173+ const std::vector<runtime::EValue>& model_outputs,
174+ const std::string& decode_method) {
150175 (void )model_outputs;
176+ (void )decode_method;
151177 // No post inference work to do.
152178 return runtime::Error::Ok;
153179 }
180+
181+ ET_NODISCARD runtime::Error update_decode (
182+ const std::vector<runtime::EValue>& model_outputs) {
183+ return update_decode (model_outputs, " forward" );
184+ }
185+
186+ private:
187+ ET_MODULE_NAMESPACE::Module& module_;
154188};

} // namespace llm