@@ -114,10 +114,41 @@ llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner);
114
114
} // namespace runtime
115
115
116
116
// ----------------------------------------------------------------------------//
117
- // Get compiled function results memory layout .
117
+ // Get compiled function arguments and results memory layouts .
118
118
// ----------------------------------------------------------------------------//
119
119
120
- Expected<Executable::ResultsMemoryLayout> Executable::GetResultsMemoryLayout (
120
+ using ArgumentsMemoryLayout = Executable::ArgumentsMemoryLayout;
121
+ using ResultsMemoryLayout = Executable::ResultsMemoryLayout;
122
+
123
+ Expected<ArgumentsMemoryLayout> Executable::GetArgumentsMemoryLayout (
124
+ const FunctionType& signature) {
125
+ // Size of the arguments pointers array.
126
+ size_t num_args_ptrs = 0 ;
127
+
128
+ // Verify all operands types and record memory requirements.
129
+ for (unsigned i = 0 ; i < signature.num_operands (); ++i) {
130
+ auto * type = signature.operand (i);
131
+
132
+ // Kernel context passed as an opaque pointer.
133
+ if (auto * ctx = dyn_cast<KernelContextOperandType>(type)) {
134
+ ++num_args_ptrs;
135
+ continue ;
136
+ }
137
+
138
+ // Memref passed as: 2 pointers + offset + rank * (size + stride)
139
+ if (auto * memref = llvm::dyn_cast<MemrefType>(type)) {
140
+ num_args_ptrs += 3 + 2 * memref->rank ();
141
+ continue ;
142
+ }
143
+
144
+ return MakeStringError (" unknown operand #" , i,
145
+ " type memory layout: " , *type);
146
+ }
147
+
148
+ return ArgumentsMemoryLayout{num_args_ptrs};
149
+ }
150
+
151
+ Expected<ResultsMemoryLayout> Executable::GetResultsMemoryLayout (
121
152
const FunctionType& signature) {
122
153
// Size of the memory block required for storing results, and offsets for
123
154
// each function result.
@@ -189,21 +220,6 @@ Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor) {
189
220
// Executable CallFrame initialization.
190
221
// -------------------------------------------------------------------------- //
191
222
192
- // Returns the number of call frame arguments required to pass the `memref` to
193
- // the compiled kernel.
194
- static size_t GetArgsCount (const MemrefDesc& memref) {
195
- // Memref layout: 2 pointers + offset + rank * (size + stride)
196
- return 3 + 2 * memref.sizes .size ();
197
- }
198
-
199
- // Returns the number of call frame arguments required to pass all operands
200
- // to the compiled kernel.
201
- static size_t GetArgsCount (ArrayRef<MemrefDesc> operands) {
202
- size_t n = 0 ;
203
- for (const MemrefDesc& memref : operands) n += GetArgsCount (memref);
204
- return n;
205
- }
206
-
207
223
// Unpack `memref` argument into pointers to the data to be compatible with
208
224
// compiled MLIR function ABI.
209
225
static void AddMemrefArgument (const MemrefDesc& memref, size_t * offset,
@@ -274,8 +290,8 @@ Error Executable::InitializeCallFrame(ArrayRef<MemrefDesc> operands,
274
290
}
275
291
}
276
292
277
- size_t n_args_elems = 1 + GetArgsCount (operands) ;
278
- call_frame->args .resize_for_overwrite (n_args_elems );
293
+ size_t num_args_ptrs = arguments_memory_layout_. num_args_ptrs ;
294
+ call_frame->args .resize_for_overwrite (num_args_ptrs );
279
295
280
296
// Add a placeholder for the kernel context as the first argument.
281
297
call_frame->args [0 ] = nullptr ;
@@ -288,7 +304,7 @@ Error Executable::InitializeCallFrame(ArrayRef<MemrefDesc> operands,
288
304
for (const MemrefDesc& desc : operands)
289
305
AddMemrefArgument (desc, &offset, &call_frame->args );
290
306
291
- assert (offset == n_args_elems &&
307
+ assert (offset == num_args_ptrs &&
292
308
" reserved number of args must match the argument offset" );
293
309
294
310
// Allocate storage for results.
@@ -778,6 +794,11 @@ JitCompilationContext::Instantiate(CompilationOptions opts,
778
794
auto runtime_signature = FunctionType::Convert (runtime_type);
779
795
if (auto err = runtime_signature.takeError ()) return std::move (err);
780
796
797
+ // Get the memory layout fo passing function arguments.
798
+ auto arguments_memory_layout =
799
+ Executable::GetArgumentsMemoryLayout (*runtime_signature);
800
+ if (auto err = arguments_memory_layout.takeError ()) return std::move (err);
801
+
781
802
// Get the memory layout for returning function results.
782
803
auto results_memory_layout =
783
804
Executable::GetResultsMemoryLayout (*runtime_signature);
@@ -858,7 +879,8 @@ JitCompilationContext::Instantiate(CompilationOptions opts,
858
879
return Executable (
859
880
ctx->name ().str (), std::move (memory_mapper), std::move (*engine),
860
881
*kernel_fn, std::move (*signature), std::move (*runtime_signature),
861
- std::move (*results_memory_layout), specialization, time_to_compile);
882
+ std::move (*arguments_memory_layout), std::move (*results_memory_layout),
883
+ specialization, time_to_compile);
862
884
}
863
885
864
886
llvm::Error JitCompilationContext::Specialize (
0 commit comments