Skip to content

Commit aa44c1d

Browse files
committed
[Distributed] IRGen: Emit an early return upon argument decoding failure
1 parent b82c67d commit aa44c1d

File tree

2 files changed

+101
-20
lines changed

2 files changed

+101
-20
lines changed

lib/IRGen/GenDistributed.cpp

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class DistributedAccessor {
123123
unsigned expectedWitnessTables,
124124
Explosion &arguments);
125125

126+
/// Emit an async return from accessor which does cleanup of
127+
/// all the argument allocations.
128+
void emitReturn(llvm::Value *errorValue);
129+
126130
FunctionPointer getPointerToTarget() const;
127131

128132
Callee getCalleeForDistributedTarget(llvm::Value *self) const;
@@ -131,6 +135,12 @@ class DistributedAccessor {
131135
/// could be used to decode argument values to pass to its invocation.
132136
static ArgumentDecoderInfo findArgumentDecoder(IRGenModule &IGM,
133137
SILFunction *thunk);
138+
139+
/// The result type of the accessor.
140+
SILType getResultType() const;
141+
142+
/// The error type of this accessor.
143+
SILType getErrorType() const;
134144
};
135145

136146
} // end namespace
@@ -316,6 +326,9 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
316326
// substitution Argument -> <argument metadata>
317327
decodeArgs.add(argumentType);
318328

329+
Address calleeErrorSlot;
330+
llvm::Value *decodeError = nullptr;
331+
319332
emission->begin();
320333
{
321334
emission->setArgs(decodeArgs, /*isOutlined=*/false,
@@ -325,14 +338,43 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
325338
emission->emitToExplosion(result, /*isOutlined=*/false);
326339
assert(result.empty());
327340

328-
// TODO: Add error handling a new block that uses `emitAsyncReturn`
329-
// if error slot is non-null.
341+
// Load error from the slot to emit an early return if necessary.
342+
{
343+
SILFunctionConventions conv(ArgumentDecoder.Type, IGM.getSILModule());
344+
SILType errorType =
345+
conv.getSILErrorType(IGM.getMaximalTypeExpansionContext());
346+
347+
calleeErrorSlot =
348+
emission->getCalleeErrorSlot(errorType, /*isCalleeAsync=*/true);
349+
decodeError = IGF.Builder.CreateLoad(calleeErrorSlot);
350+
}
330351
}
331352
emission->end();
332353

333354
// Remember to deallocate later.
334355
AllocatedArguments.push_back(resultValue);
335356

357+
// Check whether the error slot has been set and if so
358+
// emit an early return from accessor.
359+
{
360+
auto contBB = IGF.createBasicBlock("");
361+
auto errorBB = IGF.createBasicBlock("on-error");
362+
363+
auto nullError = llvm::Constant::getNullValue(decodeError->getType());
364+
auto hasError = IGF.Builder.CreateICmpNE(decodeError, nullError);
365+
366+
IGF.Builder.CreateCondBr(hasError, errorBB, contBB);
367+
{
368+
IGF.Builder.emitBlock(errorBB);
369+
// Emit an early return if argument decoding failed.
370+
emitReturn(decodeError);
371+
}
372+
373+
IGF.Builder.emitBlock(contBB);
374+
// Reset value of the slot back to `null`
375+
IGF.Builder.CreateStore(nullError, calleeErrorSlot);
376+
}
377+
336378
switch (param.getConvention()) {
337379
case ParameterConvention::Indirect_In:
338380
case ParameterConvention::Indirect_In_Constant: {
@@ -412,10 +454,28 @@ void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
412454
}
413455
}
414456

457+
void DistributedAccessor::emitReturn(llvm::Value *errorValue) {
458+
// Deallocate all of the copied arguments. Since allocations happened
459+
// on stack they have to be deallocated in reverse order.
460+
{
461+
for (auto alloca = AllocatedArguments.rbegin();
462+
alloca != AllocatedArguments.rend(); ++alloca) {
463+
IGF.emitDeallocateDynamicAlloca(*alloca);
464+
}
465+
}
466+
467+
Explosion voidResult;
468+
469+
Explosion error;
470+
error.add(errorValue);
471+
472+
emitAsyncReturn(IGF, AsyncLayout, getResultType(), AccessorType, voidResult,
473+
error);
474+
}
475+
415476
void DistributedAccessor::emit() {
416477
auto targetTy = Target->getLoweredFunctionType();
417478
SILFunctionConventions targetConv(targetTy, IGF.getSILModule());
418-
SILFunctionConventions accessorConv(AccessorType, IGF.getSILModule());
419479
TypeExpansionContext expansionContext = IGM.getMaximalTypeExpansionContext();
420480

421481
auto params = IGF.collectParameters();
@@ -510,7 +570,7 @@ void DistributedAccessor::emit() {
510570
// using computed argument explosion.
511571
{
512572
Explosion result;
513-
Explosion error;
573+
llvm::Value *targetError = nullptr;
514574

515575
auto callee = getCalleeForDistributedTarget(actorSelf);
516576
auto emission =
@@ -536,27 +596,16 @@ void DistributedAccessor::emit() {
536596
{
537597
assert(targetTy->hasErrorResult());
538598

539-
SILType errorType = accessorConv.getSILErrorType(expansionContext);
540599
Address calleeErrorSlot =
541-
emission->getCalleeErrorSlot(errorType, /*isCalleeAsync=*/true);
542-
error.add(IGF.Builder.CreateLoad(calleeErrorSlot));
600+
emission->getCalleeErrorSlot(getErrorType(), /*isCalleeAsync=*/true);
601+
targetError = IGF.Builder.CreateLoad(calleeErrorSlot);
543602
}
544603

545604
emission->end();
546605

547-
// Deallocate all of the copied arguments. Since allocations happened
548-
// on stack they have to be deallocated in reverse order.
549-
{
550-
while (!AllocatedArguments.empty()) {
551-
auto argument = AllocatedArguments.pop_back_val();
552-
IGF.emitDeallocateDynamicAlloca(argument);
553-
}
554-
}
555-
556-
Explosion voidResult;
557-
emitAsyncReturn(IGF, AsyncLayout,
558-
accessorConv.getSILResultType(expansionContext),
559-
AccessorType, voidResult, error);
606+
// Emit an async return that does allocation cleanup and propagates error
607+
// (if any) back to the caller.
608+
emitReturn(targetError);
560609
}
561610
}
562611

@@ -604,6 +653,16 @@ DistributedAccessor::findArgumentDecoder(IRGenModule &IGM, SILFunction *thunk) {
604653
return {.Type = methodTy, .Fn = methodPtr};
605654
}
606655

656+
SILType DistributedAccessor::getResultType() const {
657+
SILFunctionConventions conv(AccessorType, IGF.getSILModule());
658+
return conv.getSILResultType(IGM.getMaximalTypeExpansionContext());
659+
}
660+
661+
SILType DistributedAccessor::getErrorType() const {
662+
SILFunctionConventions conv(AccessorType, IGF.getSILModule());
663+
return conv.getSILErrorType(IGM.getMaximalTypeExpansionContext());
664+
}
665+
607666
Callee ArgumentDecoderInfo::getCallee(llvm::Value *decoder) const {
608667
CalleeInfo info(Type, Type, SubstitutionMap());
609668
return {std::move(info), Fn, decoder};

test/Distributed/Runtime/distributed_actor_remoteCall.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ distributed actor Greeter {
9696
distributed func genericOptional<T: Codable>(t: T?) {
9797
print("---> T = \(t!), type(of:) = \(type(of: t))")
9898
}
99+
100+
distributed func expectsDecodeError(v: Int???) {
101+
}
99102
}
100103

101104

@@ -203,6 +206,11 @@ class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvoc
203206
fatalError("Cannot cast argument\(anyArgument) to expected \(Argument.self)")
204207
}
205208

209+
if (argumentIndex == 0 && Argument.self == Int???.self) {
210+
throw ExecuteDistributedTargetError(message: "Failed to decode of Int??? (for a test)")
211+
}
212+
213+
206214
argumentIndex += 1
207215
return argument
208216
}
@@ -244,6 +252,7 @@ let generic3Name = "$s4main7GreeterC8generic31a1b1cyx_Sayq_Gq0_tSeRzSERzSeR_SER_
244252
let generic4Name = "$s4main7GreeterC8generic41a1b1cyx_AA1SVyq_GSayq0_GtSeRzSERzSeR_SER_SeR0_SER0_r1_lFTE"
245253
let generic5Name = "$s4main7GreeterC8generic51a1b1c1dyx_AA1SVyq_Gq0_q1_tSeRzSERzSeR_SER_SeR0_SER0_SeR1_SER1_r2_lFTE"
246254
let genericOptionalName = "$s4main7GreeterC15genericOptional1tyxSg_tSeRzSERzlFTE"
255+
let expectsDecodeErrorName = "$s4main7GreeterC18expectsDecodeError1vySiSgSgSg_tFTE"
247256

248257
func test() async throws {
249258
let system = DefaultDistributedActorSystem()
@@ -419,6 +428,19 @@ func test() async throws {
419428
// CHECK: ---> T = [0.0, 3737844653.0], type(of:) = Optional<Array<Double>>
420429
// CHECK-NEXT: RETURN: ()
421430

431+
let decodeErrInvocation = system.makeInvocationEncoder()
432+
433+
try decodeErrInvocation.recordArgument(42)
434+
try decodeErrInvocation.doneRecording()
435+
436+
try await system.executeDistributedTarget(
437+
on: local,
438+
mangledTargetName: expectsDecodeErrorName,
439+
invocationDecoder: decodeErrInvocation,
440+
handler: FakeResultHandler()
441+
)
442+
// CHECK: ERROR: ExecuteDistributedTargetError(message: "Failed to decode of Int??? (for a test)")
443+
422444
print("done")
423445
// CHECK-NEXT: done
424446
}

0 commit comments

Comments
 (0)