Skip to content

Commit 7d98db6

Browse files
ezhulenev authored and copybara-github committed
[tfrt:jitrt] Compute arguments layout when constructing Executable
name                                      old cpu/op   new cpu/op   delta
BM_GetExecutableUniqueShapes              82.4ns ± 0%  82.6ns ± 0%   +0.29%
BM_GetExecutableSameShapes                66.8ns ± 0%  67.0ns ± 1%   +0.35%
BM_GetExecutableKnownShapes               50.6ns ± 0%  51.0ns ± 0%   +0.80%
BM_InitializeCallFrameUniqueShapes_true   70.6ns ± 2%  68.6ns ± 0%   -2.84%
BM_InitializeCallFrameSameShapes_true     70.3ns ± 0%  68.6ns ± 0%   -2.35%
BM_InitializeCallFrameKnownShapes_true    70.8ns ± 0%  69.2ns ± 3%   -2.22%
BM_InitializeCallFrameUniqueShapes_false  35.7ns ± 0%  25.1ns ± 0%  -29.67%
BM_InitializeCallFrameSameShapes_false    35.7ns ± 1%  25.1ns ± 0%  -29.74%
BM_InitializeCallFrameKnownShapes_false   35.7ns ± 1%  25.1ns ± 0%  -29.72%

PiperOrigin-RevId: 444942223
1 parent df53cdb commit 7d98db6

File tree

2 files changed

+62
-25
lines changed

2 files changed

+62
-25
lines changed

backends/jitrt/include/tfrt/jitrt/jitrt.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,7 @@ class Executable {
776776
using KernelFunctionPtr = void (*)(void**);
777777

778778
// Forward declare types defined below.
779+
struct ArgumentsMemoryLayout;
779780
struct ResultsMemoryLayout;
780781
struct CallFrame;
781782
struct ExecuteOpts;
@@ -786,6 +787,7 @@ class Executable {
786787
std::unique_ptr<mlir::ExecutionEngine> engine,
787788
KernelFunctionPtr fptr, FunctionType signature,
788789
FunctionType runtime_signature,
790+
ArgumentsMemoryLayout arguments_memory_layout,
789791
ResultsMemoryLayout results_memory_layout,
790792
Optional<size_t> specialization,
791793
std::chrono::milliseconds time_to_compile)
@@ -795,6 +797,7 @@ class Executable {
795797
fptr_(fptr),
796798
signature_(std::move(signature)),
797799
runtime_signature_(std::move(runtime_signature)),
800+
arguments_memory_layout_(std::move(arguments_memory_layout)),
798801
results_memory_layout_(std::move(results_memory_layout)),
799802
specialization_(specialization),
800803
time_to_compile_(time_to_compile) {
@@ -887,6 +890,12 @@ class Executable {
887890
llvm::StringRef error;
888891
};
889892

893+
// Requirements for passing arguments to the compiled function.
894+
struct ArgumentsMemoryLayout {
895+
// Currently we always pass arguments as an array of pointers.
896+
size_t num_args_ptrs;
897+
};
898+
890899
// Requirements for the contiguous block of memory to store compiled function
891900
// results. When we invoke a compiled function we allocate a block of memory,
892901
// and pass pointers to pre-computed offsets as output arguments to the
@@ -933,10 +942,15 @@ class Executable {
933942
KernelContext* kernel_context;
934943
};
935944

936-
// Verifies that all types in the entrypoint function signature are supported
937-
// at runtime and we know how to pass arguments and fetch results. Returns
938-
// a pre-computed layout for the function results. If some of the operands
939-
// or results are not supported returns an error.
945+
// Verifies that all operand types in the entrypoint function signature are
946+
// supported at run time. Returns a pre-computed layout for the function
947+
// arguments. If some arguments are not supported returns an error.
948+
static Expected<ArgumentsMemoryLayout> GetArgumentsMemoryLayout(
949+
const FunctionType& signature);
950+
951+
// Verifies that all result types in the entrypoint function signature are
952+
// supported at run time. Returns a pre-computed layout for the function
953+
// results. If some results are not supported returns an error.
940954
static Expected<ResultsMemoryLayout> GetResultsMemoryLayout(
941955
const FunctionType& signature);
942956

@@ -973,6 +987,7 @@ class Executable {
973987
// expected by the runtime.
974988
FunctionType runtime_signature_;
975989

990+
ArgumentsMemoryLayout arguments_memory_layout_;
976991
ResultsMemoryLayout results_memory_layout_;
977992

978993
// Specialization id if this executable is a specialization, or an empty

backends/jitrt/lib/jitrt.cc

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,41 @@ llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner);
114114
} // namespace runtime
115115

116116
//----------------------------------------------------------------------------//
117-
// Get compiled function results memory layout.
117+
// Get compiled function arguments and results memory layouts.
118118
//----------------------------------------------------------------------------//
119119

120-
Expected<Executable::ResultsMemoryLayout> Executable::GetResultsMemoryLayout(
120+
using ArgumentsMemoryLayout = Executable::ArgumentsMemoryLayout;
121+
using ResultsMemoryLayout = Executable::ResultsMemoryLayout;
122+
123+
Expected<ArgumentsMemoryLayout> Executable::GetArgumentsMemoryLayout(
124+
const FunctionType& signature) {
125+
// Size of the arguments pointers array.
126+
size_t num_args_ptrs = 0;
127+
128+
// Verify all operands types and record memory requirements.
129+
for (unsigned i = 0; i < signature.num_operands(); ++i) {
130+
auto* type = signature.operand(i);
131+
132+
// Kernel context passed as an opaque pointer.
133+
if (auto* ctx = dyn_cast<KernelContextOperandType>(type)) {
134+
++num_args_ptrs;
135+
continue;
136+
}
137+
138+
// Memref passed as: 2 pointers + offset + rank * (size + stride)
139+
if (auto* memref = llvm::dyn_cast<MemrefType>(type)) {
140+
num_args_ptrs += 3 + 2 * memref->rank();
141+
continue;
142+
}
143+
144+
return MakeStringError("unknown operand #", i,
145+
" type memory layout: ", *type);
146+
}
147+
148+
return ArgumentsMemoryLayout{num_args_ptrs};
149+
}
150+
151+
Expected<ResultsMemoryLayout> Executable::GetResultsMemoryLayout(
121152
const FunctionType& signature) {
122153
// Size of the memory block required for storing results, and offsets for
123154
// each function result.
@@ -189,21 +220,6 @@ Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor) {
189220
// Executable CallFrame initialization.
190221
// -------------------------------------------------------------------------- //
191222

192-
// Returns the number of call frame arguments required to pass the `memref` to
193-
// the compiled kernel.
194-
static size_t GetArgsCount(const MemrefDesc& memref) {
195-
// Memref layout: 2 pointers + offset + rank * (size + stride)
196-
return 3 + 2 * memref.sizes.size();
197-
}
198-
199-
// Returns the number of call frame arguments required to pass all operands
200-
// to the compiled kernel.
201-
static size_t GetArgsCount(ArrayRef<MemrefDesc> operands) {
202-
size_t n = 0;
203-
for (const MemrefDesc& memref : operands) n += GetArgsCount(memref);
204-
return n;
205-
}
206-
207223
// Unpack `memref` argument into pointers to the data to be compatible with
208224
// compiled MLIR function ABI.
209225
static void AddMemrefArgument(const MemrefDesc& memref, size_t* offset,
@@ -274,8 +290,8 @@ Error Executable::InitializeCallFrame(ArrayRef<MemrefDesc> operands,
274290
}
275291
}
276292

277-
size_t n_args_elems = 1 + GetArgsCount(operands);
278-
call_frame->args.resize_for_overwrite(n_args_elems);
293+
size_t num_args_ptrs = arguments_memory_layout_.num_args_ptrs;
294+
call_frame->args.resize_for_overwrite(num_args_ptrs);
279295

280296
// Add a placeholder for the kernel context as the first argument.
281297
call_frame->args[0] = nullptr;
@@ -288,7 +304,7 @@ Error Executable::InitializeCallFrame(ArrayRef<MemrefDesc> operands,
288304
for (const MemrefDesc& desc : operands)
289305
AddMemrefArgument(desc, &offset, &call_frame->args);
290306

291-
assert(offset == n_args_elems &&
307+
assert(offset == num_args_ptrs &&
292308
"reserved number of args must match the argument offset");
293309

294310
// Allocate storage for results.
@@ -778,6 +794,11 @@ JitCompilationContext::Instantiate(CompilationOptions opts,
778794
auto runtime_signature = FunctionType::Convert(runtime_type);
779795
if (auto err = runtime_signature.takeError()) return std::move(err);
780796

797+
// Get the memory layout for passing function arguments.
798+
auto arguments_memory_layout =
799+
Executable::GetArgumentsMemoryLayout(*runtime_signature);
800+
if (auto err = arguments_memory_layout.takeError()) return std::move(err);
801+
781802
// Get the memory layout for returning function results.
782803
auto results_memory_layout =
783804
Executable::GetResultsMemoryLayout(*runtime_signature);
@@ -858,7 +879,8 @@ JitCompilationContext::Instantiate(CompilationOptions opts,
858879
return Executable(
859880
ctx->name().str(), std::move(memory_mapper), std::move(*engine),
860881
*kernel_fn, std::move(*signature), std::move(*runtime_signature),
861-
std::move(*results_memory_layout), specialization, time_to_compile);
882+
std::move(*arguments_memory_layout), std::move(*results_memory_layout),
883+
specialization, time_to_compile);
862884
}
863885

864886
llvm::Error JitCompilationContext::Specialize(

0 commit comments

Comments
 (0)