Skip to content

Commit aae8198

Browse files
ezhulenevcopybara-github
authored andcommitted
[tfrt:jitrt] Add rt.status support to JitRt API intrinsics
PiperOrigin-RevId: 445052118
1 parent eaf4cbb commit aae8198

File tree

10 files changed

+115
-26
lines changed

10 files changed

+115
-26
lines changed

backends/jitrt/include/tfrt/jitrt/opdefs/rt_base.td

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,23 @@ class RT_Type<string name, string typeMnemonic> : TypeDef<RuntimeDialect,
4444
let mnemonic = typeMnemonic;
4545
}
4646

47+
// -------------------------------------------------------------------------- //
48+
// Types for integrating JitRt kernels with the runtime.
49+
// -------------------------------------------------------------------------- //
50+
4751
// This is an opaque handle to tfrt::jitrt::KernelContextType.
4852
def KernelContextType : RT_Type<"KernelContext", "kernel_context"> {
4953
let summary = "Kernel Context type";
5054
let description = [{
51-
Opaque handle used for interacting with the TFRT run-time.
55+
Opaque handle used for interacting with the JitRt runtime.
56+
}];
57+
}
58+
59+
// This is an opaque handle to tfrt::jitrt::StatusType.
60+
def StatusType : RT_Type<"Status", "status"> {
61+
let summary = "Status type";
62+
let description = [{
63+
A status type returned from the JitRt runtime API intrinsics.
5264
}];
5365
}
5466

backends/jitrt/include/tfrt/jitrt/opdefs/rt_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/BuiltinTypes.h"
2323
#include "mlir/IR/Dialect.h"
2424
#include "mlir/IR/OpDefinition.h"
25+
#include "mlir/IR/OpImplementation.h"
2526
#include "mlir/IR/Types.h"
2627
#include "tfrt/jitrt/opdefs/rt_dialect.h.inc"
2728

backends/jitrt/include/tfrt/jitrt/opdefs/rt_ops.td

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,20 @@ def SetErrorOp : RT_Op<"set_error"> {
117117
let assemblyFormat = "$ctx `,` $error attr-dict";
118118
}
119119

120+
//===----------------------------------------------------------------------===//
121+
// IsOkOp
122+
//===----------------------------------------------------------------------===//
123+
124+
def IsOkOp : RT_Op<"is_ok"> {
125+
let summary = "returns true if status is ok";
126+
let description = "Checks if the runtime status is ok.";
127+
128+
let arguments = (ins StatusType:$status);
129+
let results = (outs I1:$ok);
130+
131+
let assemblyFormat = "$status attr-dict";
132+
}
133+
120134
//===----------------------------------------------------------------------===//
121135
// CustomCallOp
122136
//===----------------------------------------------------------------------===//
@@ -130,13 +144,18 @@ def CustomCallOp : RT_Op<"custom_call"> {
130144
on top of the JitRt, for example this can be used as an extension mechanism
131145
to register vendor specific kernels (e.g. call oneDNN convolution).
132146

147+
Returns `!rt.status` value which can be checked to see if the custom call
148+
was successful.
149+
133150
Example:
134151

135152
```mlir
136153
func @compute(%ctx: !rt.kernel_context, %arg0: memref<?xf32>,
137154
%arg1: memref<?xf32>) {
138-
%0 = rt.custom_call "one_dnn.some_operation"(%arg0, %arg1)
139-
: (memref<?xf32>, memref<?xf32>) -> !one_dnn.status
155+
%status = rt.custom_call "one_dnn.some_operation"(%arg0, %arg1)
156+
: (memref<?xf32>, memref<?xf32>) -> ()
157+
%0 = rt.is_ok %status
158+
cf.assert %0, "failed to call one_dnn custom call"
140159
return
141160
}
142161
```
@@ -152,10 +171,13 @@ def CustomCallOp : RT_Op<"custom_call"> {
152171
Variadic<AnyType>:$operands
153172
);
154173

155-
let results = (outs Variadic<AnyType>);
174+
let results = (outs
175+
StatusType:$status,
176+
Variadic<AnyType>:$results
177+
);
156178

157179
let assemblyFormat = [{
158-
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
180+
$callee `(` $operands `)` attr-dict `:` functional-type($operands, $results)
159181
}];
160182
}
161183

backends/jitrt/include/tfrt/jitrt/runtime.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ extern "C" void *runtimeGetResultStorage(KernelContext *, int64_t);
5656
// Sets kernel context to an error state.
5757
extern "C" void runtimeSetError(KernelContext *, const char *);
5858

59-
// Calls the custom call function registered with the runtime.
60-
extern "C" void runtimeCustomCall(const char *, void **args);
59+
// Calls the custom call function registered with the runtime. Returns true
60+
// if the custom call was successful.
61+
extern "C" bool runtimeCustomCall(const char *, void **args);
6162

6263
} // namespace runtime
6364
} // namespace jitrt

backends/jitrt/lib/conversion/rt_to_llvm.cc

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ struct RuntimeAPI {
113113
static FunctionType CustomCallFunctionType(MLIRContext *ctx) {
114114
auto callee = OpaquePointerType(ctx);
115115
auto args = CustomCallArgumentsType(ctx);
116-
return FunctionType::get(ctx, {callee, args}, {});
116+
auto i1 = IntegerType::get(ctx, 1);
117+
return FunctionType::get(ctx, {callee, args}, {i1});
117118
}
118119
};
119120

@@ -139,11 +140,16 @@ class RuntimeTypeConverter : public TypeConverter {
139140
RuntimeTypeConverter() {
140141
addConversion([](Type type) { return type; });
141142
addConversion(ConvertKernelContextType);
143+
addConversion(ConvertStatusType);
142144
}
143145

144146
static llvm::Optional<Type> ConvertKernelContextType(KernelContextType type) {
145147
return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
146148
}
149+
150+
static llvm::Optional<Type> ConvertStatusType(StatusType type) {
151+
return IntegerType::get(type.getContext(), 1);
152+
}
147153
};
148154

149155
// -------------------------------------------------------------------------- //
@@ -253,6 +259,23 @@ class SetErrorOpLowering : public OpConversionPattern<SetErrorOp> {
253259
}
254260
};
255261

262+
//===----------------------------------------------------------------------===//
263+
// Convert rt.is_ok to the corresponding runtime API call.
264+
//===----------------------------------------------------------------------===//
265+
266+
class IsOkOpLowering : public OpConversionPattern<IsOkOp> {
267+
public:
268+
using OpConversionPattern::OpConversionPattern;
269+
270+
LogicalResult matchAndRewrite(
271+
IsOkOp op, OpAdaptor adaptor,
272+
ConversionPatternRewriter &rewriter) const override {
273+
// Just pass through the converted operand.
274+
rewriter.replaceOp(op, adaptor.status());
275+
return success();
276+
}
277+
};
278+
256279
//===----------------------------------------------------------------------===//
257280
// Convert rt.custom_call to the corresponding runtime API call.
258281
//===----------------------------------------------------------------------===//
@@ -444,7 +467,8 @@ class CustomCallOpLowering : public OpConversionPattern<CustomCallOp> {
444467
if (failed(args)) return op.emitOpError() << "failed to encode arguments";
445468

446469
// Call runtime API to call the custom call target.
447-
rewriter.replaceOpWithNewOp<CallOp>(op, kCustomCall, TypeRange(),
470+
auto i1 = rewriter.getI1Type();
471+
rewriter.replaceOpWithNewOp<CallOp>(op, kCustomCall, TypeRange(i1),
448472
ValueRange({callee, *args}));
449473

450474
return success();
@@ -471,10 +495,11 @@ void ConvertRuntimeToLLVMPass::runOnOperation() {
471495
// We use conversion to LLVM type to lower all runtime operands to LLVM types.
472496
LLVMTypeConverter llvm_converter(ctx);
473497
llvm_converter.addConversion(RuntimeTypeConverter::ConvertKernelContextType);
498+
llvm_converter.addConversion(RuntimeTypeConverter::ConvertStatusType);
474499

475500
// Lower from the runtime operations to the runtime API function calls.
476-
patterns.add<SetOutputOpLowering, SetErrorOpLowering, CustomCallOpLowering>(
477-
llvm_converter, ctx);
501+
patterns.add<SetOutputOpLowering, SetErrorOpLowering, IsOkOpLowering,
502+
CustomCallOpLowering>(llvm_converter, ctx);
478503

479504
// Convert function signatures and call sites.
480505
mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,

backends/jitrt/lib/jitrt.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ extern "C" void runtimeSetError(KernelContext* ctx, const char* error) {
12571257
ctx->call_frame->error = {error};
12581258
}
12591259

1260-
extern "C" void runtimeCustomCall(const char* callee, void** args) {
1260+
extern "C" bool runtimeCustomCall(const char* callee, void** args) {
12611261
assert(callee && "callee must be not null");
12621262

12631263
// Default custom calls registry for the JitRt kernels.
@@ -1267,14 +1267,13 @@ extern "C" void runtimeCustomCall(const char* callee, void** args) {
12671267
return registry;
12681268
}();
12691269

1270-
// TODO(ezhulenev): Return failure if custom call is not registered.
12711270
auto* custom_call = registry->Find(callee);
1272-
assert(custom_call && "unknown custom call");
1271+
if (custom_call == nullptr) return false;
12731272

1274-
// TODO(ezhulenev): Handle failures in custom calls.
12751273
auto result = custom_call->call(args);
1276-
assert(mlir::succeeded(result) && "failed custom call");
1277-
(void)result;
1274+
if (mlir::failed(result)) return false;
1275+
1276+
return true;
12781277
}
12791278

12801279
llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle) {

backends/jitrt/mlir_tests/jitrt/compile.assert.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ module @kernels attributes { tfrt.compiled } {
3131
}
3232

3333
// CHECK: --- Running 'runtime_error'
34-
func.func @runtime_error() -> !tfrt.chain {
34+
func.func @runtime_error() -> !t.tensor {
3535
%ch0 = tfrt.new.chain
3636

3737
// Allocate and initialize input tensor.
@@ -40,9 +40,10 @@ func.func @runtime_error() -> !tfrt.chain {
4040

4141
%executable = jitrt.compile { kernel = @kernels::@main }
4242

43-
// expected-error @+1 {{Dimension 0 must have size 0}}
4443
%output = jitrt.execute %executable[%input_ready](%input)
4544
: (!t.tensor) -> (!t.tensor)
4645

47-
tfrt.return %ch0 : !tfrt.chain
46+
// CHECK: returned <<error: compiled kernel run time error:
47+
// CHECK-SAME: Dimension 0 must have size 0>>
48+
tfrt.return %output : !t.tensor
4849
}

backends/jitrt/mlir_tests/jitrt/compile.custom_call.mlir

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ module @kernels attributes { tfrt.compiled } {
2020
%c1 = arith.constant 1 : index
2121
%0 = memref.dim %input, %c0 : memref<?x?xf32>
2222
%1 = memref.dim %input, %c1 : memref<?x?xf32>
23-
%output = memref.alloc(%0, %1) : memref<?x?xf32>
2423

25-
rt.custom_call "testlib.times_two"(%input, %output)
24+
// Reverse dimension order to test invalid custom call arguments below.
25+
%output = memref.alloc(%1, %0) : memref<?x?xf32>
26+
27+
%status = rt.custom_call "testlib.times_two"(%input, %output)
2628
: (memref<?x?xf32>, memref<?x?xf32>) -> ()
29+
%ok = rt.is_ok %status
30+
cf.assert %ok, "failed to call custom call 'testlib.times_two'"
2731

2832
func.return %output : memref<?x?xf32>
2933
}
@@ -55,4 +59,23 @@ func.func @compiled_custom_call() -> !tfrt.chain {
5559
%printed = tfrt.print.i1 %cmp, %cmp_ch
5660

5761
tfrt.return %printed : !tfrt.chain
58-
}
62+
}
63+
64+
// CHECK: --- Running 'compiled_custom_call_error'
65+
func.func @compiled_custom_call_error() -> !t.tensor {
66+
%ch0 = tfrt.new.chain
67+
68+
// Allocate and initialize input tensor.
69+
%input = tfrt_dht.create_uninitialized_tensor.f32.2 [16 : i64, 4 : i64]
70+
%ch1 = tfrt_dht.fill_tensor_with_constant.f32 %input, %ch0 1.0 : f32
71+
72+
// Compile a kernel with a custom call.
73+
%executable = jitrt.compile { kernel = @kernels::@main }
74+
75+
// Execute compiled kernel with tensor operands.
76+
%output = jitrt.execute %executable[%ch1](%input) : (!t.tensor) -> !t.tensor
77+
78+
// CHECK: returned <<error: compiled kernel run time error:
79+
// CHECK-SAME: failed to call custom call 'testlib.times_two'>>
80+
tfrt.return %output : !t.tensor
81+
}

backends/jitrt/mlir_tests/rt/ops.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ func.func @set_error(%arg0: !rt.kernel_context) {
4343
// CHECK: %[[MEMREF:.*]]: memref<?xf32>
4444
func.func @custom_call(%arg0: !rt.kernel_context, %arg1: memref<?xf32>) -> f32 {
4545
// CHECK: rt.custom_call "f32_reduce"(%[[MEMREF]]) : (memref<?xf32>) -> f32
46-
%0 = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> f32
46+
%status, %0 = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> f32
47+
%ok = rt.is_ok %status
48+
cf.assert %ok, "failed to call custom call"
4749
func.return %0 : f32
4850
}

backends/jitrt/mlir_tests/rt/rt_to_llvm.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ func.func @custom_call(%arg0: !rt.kernel_context, %arg1: memref<?xf32>) {
7979
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
8080
// CHECK: %[[ARGS:.*]] = llvm.alloca %[[C3]] x !llvm.ptr<i8>
8181

82-
// CHECK: call @runtimeCustomCall(%[[CALLEE]], %[[ARGS]])
83-
rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> ()
82+
// CHECK: %[[STATUS:.*]] = call @runtimeCustomCall(%[[CALLEE]], %[[ARGS]])
83+
// CHECK: cf.assert %[[STATUS]], "oops"
84+
%status = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> ()
85+
%ok = rt.is_ok %status
86+
cf.assert %ok, "oops"
8487

8588
func.return
8689
}

0 commit comments

Comments
 (0)