@@ -8,12 +8,8 @@
 
 #pragma once
 
-#include <vector>
-
+#include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/executor/method.h>
-#include <executorch/runtime/executor/method_meta.h>
 
 namespace executorch {
 namespace extension {
@@ -29,6 +25,10 @@ namespace llm {
  */
 class ET_EXPERIMENTAL IOManager {
  public:
+
+  explicit IOManager(ET_MODULE_NAMESPACE::Module& module)
+      : module_(module) {}
+
   /**
    * @brief Virtual destructor to allow proper cleanup in derived classes.
    */
@@ -38,88 +38,115 @@ class ET_EXPERIMENTAL IOManager {
    * @brief Load the IO manager with method metadata for prefill and
    * decode operations.
    *
-   * @param program The program prefill and decode methods are loaded from.
    * @param prefill_method The prefill method to initialize with.
    * @param decode_method The decode method to initialize with.
    */
   ET_NODISCARD virtual runtime::Error load(
-      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    (void)program;
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error load() {
+    return load("forward", "forward");
+  }
+
   /**
    * @brief Reset the IO manager state.
    *
    * @param prefill_method The prefill method to reset with.
    * @param decode_method The decode method to reset with.
    */
   ET_NODISCARD virtual runtime::Error reset(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error reset() {
+    return reset("forward", "forward");
+  }
+
   /**
    * @brief Prepare inputs for the prefill phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param prefill_method The prefill method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the prefill method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
+  virtual runtime::Result<std::vector<runtime::EValue>>
   prepare_prefill(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
-    if (prefill_method.inputs_size() != 2) {
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& prefill_method) {
+    auto method_meta = module_.method_meta(prefill_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          prefill_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>>
+  prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_prefill(input, start_pos, "forward");
+  }
+
   /**
    * @brief Prepare inputs for the decode phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param decode_method The decode method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the decode method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
+  virtual runtime::Result<std::vector<runtime::EValue>>
   prepare_decode(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    if (decode_method.inputs_size() != 2) {
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& decode_method) {
+    auto method_meta = module_.method_meta(decode_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          decode_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  runtime::Result<std::vector<runtime::EValue>>
+  prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_decode(input, start_pos, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the prefill
    * phase.
@@ -128,14 +155,19 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the prefill method execution.
    */
   ET_NODISCARD virtual runtime::Error update_prefill(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)prefill_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& prefill_method) {
     (void)model_outputs;
+    (void)prefill_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  ET_NODISCARD runtime::Error update_prefill(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_prefill(model_outputs, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the decode
    * phase.
@@ -144,13 +176,21 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the decode method execution.
    */
   ET_NODISCARD virtual runtime::Error update_decode(
-      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)decode_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& decode_method) {
     (void)model_outputs;
+    (void)decode_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
+
+  ET_NODISCARD runtime::Error update_decode(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_decode(model_outputs, "forward");
+  }
+
+ private:
+  ET_MODULE_NAMESPACE::Module& module_;
 };
 
 } // namespace llm
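
For reference, a minimal sketch of how a caller might drive the new string-keyed API through a Module. The model file name, token values, and the include path for this header are illustrative assumptions, not part of the diff; only APIs shown in the diff (load, prepare_prefill, update_prefill) plus the public Module interface are used.

// Sketch only: assumes the exported "forward" method takes (tokens, start_pos).
#include <cstdint>
#include <vector>

#include <executorch/extension/llm/runner/io_manager/io_manager.h> // assumed header path
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using executorch::extension::Module;
using executorch::extension::make_tensor_ptr;
using executorch::extension::llm::IOManager;
using executorch::runtime::Error;

Error run_prefill_once() {
  Module module("model.pte"); // hypothetical model file
  IOManager io_manager(module); // base IOManager is a no-op pass-through

  auto err = io_manager.load(); // convenience overload, uses "forward"
  if (err != Error::Ok) {
    return err;
  }

  // Token IDs and start position; shapes/dtypes depend on the exported model.
  auto tokens = make_tensor_ptr({1, 3}, std::vector<int64_t>{1, 2, 3});
  auto start_pos = make_tensor_ptr({1}, std::vector<int64_t>{0});

  auto inputs = io_manager.prepare_prefill(tokens, start_pos);
  if (!inputs.ok()) {
    return inputs.error();
  }

  auto outputs = module.execute("forward", *inputs);
  if (!outputs.ok()) {
    return outputs.error();
  }

  return io_manager.update_prefill(*outputs);
}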
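
The string-keyed virtuals remain the extension point for IO managers that do manage state (e.g. caches). A hypothetical override could look like the sketch below; the class and member names are illustrative only and not part of this change.

// Hypothetical subclass: CacheAwareIOManager and caches_ are illustrative names.
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/io_manager/io_manager.h> // assumed header path

namespace example { // hypothetical namespace

using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::IOManager;
using executorch::runtime::EValue;
using executorch::runtime::Result;

class CacheAwareIOManager : public IOManager {
 public:
  explicit CacheAwareIOManager(Module& module) : IOManager(module) {}

  Result<std::vector<EValue>> prepare_decode(
      const TensorPtr& input,
      const TensorPtr& start_pos,
      const std::string& decode_method) override {
    (void)decode_method;
    // A real implementation would append its pre-allocated cache tensors here.
    std::vector<EValue> inputs{input, start_pos};
    for (const auto& cache : caches_) {
      inputs.emplace_back(cache);
    }
    return inputs;
  }

 private:
  std::vector<TensorPtr> caches_; // hypothetical pre-allocated caches
};

} // namespace example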