Skip to content

Commit df53cdb

Browse files
ezhulenev authored and copybara-github committed
[tfrt:jitrt] Add a flag to disable operands verification when executing JitRt program
PiperOrigin-RevId: 444937420
1 parent 4709555 commit df53cdb

File tree

3 files changed

+135
-41
lines changed

3 files changed

+135
-41
lines changed

backends/jitrt/cpp_tests/jit_executable_test.cc

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616

1717
#include <memory>
18+
#include <utility>
1819

1920
#include "benchmark/benchmark.h"
2021
#include "gtest/gtest.h"
@@ -47,9 +48,9 @@ using SymbolicShape = SymbolicShapesResolver::SymbolicShape;
4748

4849
static const char* mlir_module = R"(
4950
func.func @compute(%arg0: memref<?x?xf32>,
50-
%arg1: memref<?x?xf32>,
51-
%arg3: memref<?x?xf32>,
52-
%arg4: memref<16x32xf32>) {
51+
%arg1: memref<?x?xf32>,
52+
%arg3: memref<?x?xf32>,
53+
%arg4: memref<16x32xf32>) {
5354
func.return
5455
})";
5556

@@ -72,6 +73,7 @@ SmallVector<MemrefDesc> GetFakeMemrefs(SmallVector<SymbolicShape> shapes) {
7273
MemrefDesc desc;
7374
desc.dtype = DType::F32;
7475
desc.sizes.insert(desc.sizes.begin(), shape.begin(), shape.end());
76+
desc.strides.append(shape.size(), 0); // we don't need real strides
7577
memrefs.push_back(std::move(desc));
7678
}
7779

@@ -111,6 +113,43 @@ void BenchmarkGetExecutable(benchmark::State& state,
111113
}
112114
}
113115

116+
// Benchmarks `Executable::InitializeCallFrame` for the given `operands`, with
// operands verification requested (`verify == true`) or not.
//
// The executable is instantiated and compiled once before the timed loop
// (specialization is set to `kAlways`). Any failure during setup, or an error
// returned from `InitializeCallFrame` inside the loop, is reported via
// `TFRT_LOG(FATAL)`.
void BenchmarkInitializeCallFrame(benchmark::State& state,
                                  SmallVector<MemrefDesc> operands,
                                  bool verify) {
  auto host = CreateSingleThreadedHostContext();

  CompilationOptions opts;
  opts.specialization = CompilationOptions::Specialization::kAlways;
  opts.register_dialects = RegisterDefaultJitRtDialects;

  CompilationPipelineOptions copts;
  opts.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
    CreateDefaultJitRtCompilationPipeline(pm, copts);
  };

  llvm::Expected<JitExecutable> jit_executable =
      JitExecutable::Instantiate(mlir_module, entrypoint, opts);
  if (auto err = jit_executable.takeError()) TFRT_LOG(FATAL) << err;

  // Get the executable.
  Expected<AsyncValuePtr<Executable>> executable =
      jit_executable->GetExecutable(operands);
  if (auto err = executable.takeError()) TFRT_LOG(FATAL) << err;

  // Check that compilation was successful.
  host->Quiesce();
  if (executable->IsError()) TFRT_LOG(FATAL) << executable->GetError();

  for (auto _ : state) {
    Executable::CallFrame call_frame;
    // Inspect the returned error instead of dropping it: an unexpected failure
    // would otherwise be silently benchmarked.
    // NOTE(review): `Error` is presumably `llvm::Error`, which must be checked
    // before it is destroyed (enforced in builds with ABI-breaking checks) --
    // confirm against the TFRT `Error` alias.
    if (auto err =
            (*executable)->InitializeCallFrame(operands, &call_frame, verify))
      TFRT_LOG(FATAL) << err;
    benchmark::DoNotOptimize(call_frame);
  }
}
150+
151+
// -------------------------------------------------------------------------- //
152+
114153
#define BM_GetExecutable(NAME, OPERANDS) \
115154
static void BM_GetExecutable##NAME(benchmark::State& state) { \
116155
BenchmarkGetExecutable(state, OPERANDS); \
@@ -126,5 +165,38 @@ BM_GetExecutable(SameShapes,
126165
BM_GetExecutable(KnownShapes,
127166
GetFakeMemrefs({{16, 32}, {16, 32}, {16, 32}, {16, 32}}));
128167

168+
// -------------------------------------------------------------------------- //
169+
170+
// Defines a benchmark function named `BM_InitializeCallFrame<NAME>_<VERIFY>`
// that forwards `OPERANDS` and `VERIFY` to `BenchmarkInitializeCallFrame`, and
// registers it with `BENCHMARK`.
//
// `VERIFY` is token-pasted into the function name, so it must be a single
// token that can end an identifier (the uses below pass `true` / `false`).
#define BM_InitializeCallFrame(NAME, OPERANDS, VERIFY) \
171+
static void BM_InitializeCallFrame##NAME##_##VERIFY( \
172+
benchmark::State& state) { \
173+
BenchmarkInitializeCallFrame(state, OPERANDS, VERIFY); \
174+
} \
175+
BENCHMARK(BM_InitializeCallFrame##NAME##_##VERIFY)
176+
177+
BM_InitializeCallFrame(UniqueShapes,
178+
GetFakeMemrefs({{10, 11}, {12, 13}, {14, 15}, {16, 32}}),
179+
true);
180+
181+
BM_InitializeCallFrame(SameShapes,
182+
GetFakeMemrefs({{10, 11}, {10, 11}, {10, 11}, {16, 32}}),
183+
true);
184+
185+
BM_InitializeCallFrame(KnownShapes,
186+
GetFakeMemrefs({{16, 32}, {16, 32}, {16, 32}, {16, 32}}),
187+
true);
188+
189+
BM_InitializeCallFrame(UniqueShapes,
190+
GetFakeMemrefs({{10, 11}, {12, 13}, {14, 15}, {16, 32}}),
191+
false);
192+
193+
BM_InitializeCallFrame(SameShapes,
194+
GetFakeMemrefs({{10, 11}, {10, 11}, {10, 11}, {16, 32}}),
195+
false);
196+
197+
BM_InitializeCallFrame(KnownShapes,
198+
GetFakeMemrefs({{16, 32}, {16, 32}, {16, 32}, {16, 32}}),
199+
false);
200+
129201
} // namespace
130202
} // namespace tfrt

backends/jitrt/include/tfrt/jitrt/jitrt.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -802,16 +802,22 @@ class Executable {
802802
}
803803

804804
// Initializes call frame by adding all operands as pointers to the arguments
805-
// vector. Also allocates storage for the returned values. Return values
806-
// storage requirements inferred from the kernel function signature.
805+
// vector. Also allocates storage for the returned values.
806+
//
807+
// If `verify_operands` is true (in debug mode it's always on, independent of
808+
// the argument value) this function also verifies that operands passed at run
809+
// time match the executable entrypoint signature (e.g. all statically known
810+
// dimensions of the memrefs match the operands). Returns an error if it finds
811+
// a mismatch.
807812
//
808813
// This function leaves the kernel context argument (the first argument of a
809814
// kernel function) uninitialized. It will be initialized in the `Execute`
810815
// function right before the actual execution.
811816
//
812817
// See mlir::ExecutionEngine `packFunctionArguments` for the details.
813818
Error InitializeCallFrame(ArrayRef<MemrefDesc> operands,
814-
CallFrame* call_frame) const;
819+
CallFrame* call_frame,
820+
bool verify_operands = true) const;
815821

816822
// Converts returned values owned by the call frame using provided value
817823
// converter. If result conversion fails (e.g. result type is not supported)
@@ -822,16 +828,20 @@ class Executable {
822828
Error ReturnResults(const ReturnValueConverterBase& results,
823829
CallFrame* call_frame) const;
824830

825-
// Executes compiled function with given operands. If operands passed at
826-
// runtime are not compatible with the compiled function signature, allocates
827-
// error async values for all results.
831+
// Executes compiled function with given operands.
832+
//
833+
// If `verify_operands` is true (in debug mode it's always on, independent of
834+
// the argument value) this function also verifies that operands passed at run
835+
// time match the executable entrypoint signature. If some of the operands
836+
// do not match the expected type, this function allocates error async values
837+
// for all results and returns an error.
828838
//
829839
// Returns compiled function results via the user-provided results converter.
830840
// If compiled function execution completed in the error state, emits error
831841
// async value for all results.
832842
Error Execute(ArrayRef<MemrefDesc> operands,
833843
const ReturnValueConverterBase& results,
834-
const ExecuteOpts& opts) const;
844+
const ExecuteOpts& opts, bool verify_operands = true) const;
835845

836846
// Executes compiled function using user provided call frame.
837847
//

backends/jitrt/lib/jitrt.cc

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -224,41 +224,53 @@ static void AddMemrefArgument(const MemrefDesc& memref, size_t* offset,
224224
for (const Index& stride : memref.strides) add_arg(&stride);
225225
}
226226

227+
// Decides whether run time operands must be verified.
//
// In debug builds (`NDEBUG` is not defined) verification is always enabled and
// `verify_operands` is ignored. In release builds (`NDEBUG` is defined) the
// caller's choice is returned unchanged.
static bool VerifyOperands(bool verify_operands) {
#if defined(NDEBUG)
  return verify_operands;
#else
  (void)verify_operands;  // Intentionally unused: debug builds always verify.
  return true;
#endif
}
234+
227235
Error Executable::InitializeCallFrame(ArrayRef<MemrefDesc> operands,
228-
CallFrame* call_frame) const {
236+
CallFrame* call_frame,
237+
bool verify_operands) const {
229238
// TODO(ezhulenev): If executable is specialized for operands shapes then
230239
// there is no need to verify them once more here. However currently we rely
231240
// on a hash code to look up specializations, and this can lead to collisions.
241+
if (VerifyOperands(verify_operands)) {
242+
// We verify run time operands against the run time signature.
243+
const FunctionType& signature = runtime_signature_;
244+
245+
// Make sure that we call the kernel with the correct number of operands.
246+
// We subtract one operand from the signature because it corresponds to the
247+
// context that we prepend to the given operands.
248+
if (LLVM_UNLIKELY(operands.size() != signature.num_operands() - 1))
249+
return MakeStringError(
250+
"number of operands doesn't match the function signature: ",
251+
operands.size(), " vs ", signature.num_operands() - 1);
252+
253+
// Verify that all operands passed at runtime are compatible with compiled
254+
// function signature.
255+
auto kctx = dyn_cast<KernelContextOperandType>(signature.operand(0));
256+
if (LLVM_UNLIKELY(!kctx)) {
257+
return MakeStringError(
258+
"expected KernelContext in first argument of signature, got: ",
259+
signature.operand(0));
260+
}
232261

233-
// Make sure that we call the kernel with the correct number of operands.
234-
// We subtract one operand from the signature because it corresponds to the
235-
// context that we prepend to the given operands.
236-
if (LLVM_UNLIKELY(operands.size() != runtime_signature_.num_operands() - 1))
237-
return MakeStringError(
238-
"number of operands doesn't match the function signature: ",
239-
operands.size(), " vs ", runtime_signature_.num_operands() - 1);
240-
241-
// Verify that all operands passed at runtime are compatible with compiled
242-
// function signature.
243-
auto kctx = dyn_cast<KernelContextOperandType>(runtime_signature_.operand(0));
244-
if (LLVM_UNLIKELY(!kctx)) {
245-
return MakeStringError(
246-
"expected KernelContext in first argument of "
247-
"signature, got: ",
248-
runtime_signature_.operand(0));
249-
}
250-
251-
// We use 0-based index for operands, because the kernel context operand is an
252-
// internal implementation detail, and in case of an error users should get
253-
// back operand index corresponding to the user provided signature.
254-
for (unsigned i = 0; i < operands.size(); ++i) {
255-
unsigned idx = i + 1; // use 1-based index to fetch runtime operand
262+
// We use 0-based index for operands, because the kernel context operand is
263+
// an internal implementation detail, and in case of an error users should
264+
// get back operand index corresponding to the user provided signature.
265+
for (unsigned i = 0; i < operands.size(); ++i) {
266+
unsigned idx = i + 1; // use 1-based index to fetch runtime operand
256267

257-
if (auto* memref = dyn_cast<MemrefType>(runtime_signature_.operand(idx))) {
258-
if (auto err = VerifyMemrefOperand(i, *memref, operands[i])) return err;
259-
} else {
260-
return MakeStringError("expected memref operand at #", i,
261-
", got: ", *runtime_signature_.operand(i));
268+
if (auto* memref = dyn_cast<MemrefType>(signature.operand(idx))) {
269+
if (auto err = VerifyMemrefOperand(i, *memref, operands[i])) return err;
270+
} else {
271+
return MakeStringError("expected memref operand at #", i,
272+
", got: ", *signature.operand(i));
273+
}
262274
}
263275
}
264276

@@ -415,7 +427,7 @@ Error ReturnErrors(const ReturnValueConverterBase& results, Error error) {
415427

416428
Error Executable::Execute(ArrayRef<MemrefDesc> operands,
417429
const ReturnValueConverterBase& results,
418-
const ExecuteOpts& opts) const {
430+
const ExecuteOpts& opts, bool verify_operands) const {
419431
// CallFrame can be allocated on the stack because compiled function will
420432
// unpack all the arguments it needs, and async regions will not access
421433
// the data after the initial function will return the result.
@@ -445,7 +457,7 @@ Error Executable::Execute(ArrayRef<MemrefDesc> operands,
445457

446458
// Compiled function takes arguments and results as `void**` type erased
447459
// pointer. See mlir::ExecutionEngine `packFunctionArguments` for the details.
448-
if (auto err = InitializeCallFrame(operands, &call_frame))
460+
if (auto err = InitializeCallFrame(operands, &call_frame, verify_operands))
449461
return ReturnErrors(results, std::move(err));
450462

451463
Execute(call_frame, opts);

0 commit comments

Comments
 (0)