Skip to content

Commit 8264009

Browse files
committed
[inlining] Fix inlining of coroutines to properly handle owned values passed as guaranteed args.
Previously, we were inserting the end_borrow after the call site, rather than after the end_apply, abort_apply.
1 parent cc60f2c commit 8264009

File tree

3 files changed

+154
-33
lines changed

3 files changed

+154
-33
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class SILBuilder {
174174
setInsertionPoint(I);
175175
}
176176

177+
SILBuilder(SILBasicBlock *BB, const SILDebugScope *DS, SILBuilder &B)
178+
: SILBuilder(BB, DS, B.getBuilderContext()) {}
179+
177180
/// Build instructions before the given insertion point, inheriting the debug
178181
/// location.
179182
///
@@ -2264,6 +2267,11 @@ class SILBuilderWithScope : public SILBuilder {
22642267
inheritScopeFrom(InheritScopeFrom);
22652268
}
22662269

2270+
explicit SILBuilderWithScope(SILBasicBlock *BB, SILBuilder &B,
2271+
SILInstruction *InheritScopeFrom)
2272+
: SILBuilder(BB, InheritScopeFrom->getDebugScope(),
2273+
B.getBuilderContext()) {}
2274+
22672275
/// Creates a new SILBuilder with an insertion point at the
22682276
/// beginning of BB and the debug scope from the first
22692277
/// non-metainstruction in the BB.

lib/SILOptimizer/Utils/SILInliner.cpp

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,20 @@ class BeginApplySite {
102102
return BeginApplySite(BeginApply, Loc, Builder);
103103
}
104104

105-
void preprocess(SILBasicBlock *returnToBB) {
105+
void preprocess(SILBasicBlock *returnToBB,
106+
SmallVectorImpl<SILInstruction *> &endBorrowInsertPts) {
106107
SmallVector<EndApplyInst *, 1> endApplyInsts;
107108
SmallVector<AbortApplyInst *, 1> abortApplyInsts;
108109
BeginApply->getCoroutineEndPoints(endApplyInsts, abortApplyInsts);
109110
while (!endApplyInsts.empty()) {
110111
auto *endApply = endApplyInsts.pop_back_val();
111112
collectEndApply(endApply);
113+
endBorrowInsertPts.push_back(&*std::next(endApply->getIterator()));
112114
}
113115
while (!abortApplyInsts.empty()) {
114-
collectAbortApply(abortApplyInsts.pop_back_val());
116+
auto *abortApply = abortApplyInsts.pop_back_val();
117+
collectAbortApply(abortApply);
118+
endBorrowInsertPts.push_back(&*std::next(abortApply->getIterator()));
115119
}
116120
}
117121

@@ -428,29 +432,38 @@ SILInlineCloner::cloneInline(ArrayRef<SILValue> AppliedArgs) {
428432

429433
SmallVector<SILValue, 4> entryArgs;
430434
entryArgs.reserve(AppliedArgs.size());
435+
SmallBitVector borrowedArgs(AppliedArgs.size());
436+
431437
auto calleeConv = getCalleeFunction()->getConventions();
432-
for (unsigned argIdx = 0, endIdx = AppliedArgs.size(); argIdx < endIdx;
433-
++argIdx) {
434-
SILValue callArg = AppliedArgs[argIdx];
438+
for (auto p : llvm::enumerate(AppliedArgs)) {
439+
SILValue callArg = p.value();
440+
unsigned idx = p.index();
435441
// Insert begin/end borrow for guaranteed arguments.
436-
if (argIdx >= calleeConv.getSILArgIndexOfFirstParam()
437-
&& calleeConv.getParamInfoForSILArg(argIdx).isGuaranteed()) {
438-
callArg = borrowFunctionArgument(callArg, Apply);
442+
if (idx >= calleeConv.getSILArgIndexOfFirstParam() &&
443+
calleeConv.getParamInfoForSILArg(idx).isGuaranteed()) {
444+
if (SILValue newValue = borrowFunctionArgument(callArg, Apply)) {
445+
callArg = newValue;
446+
borrowedArgs[idx] = true;
447+
}
439448
}
440449
entryArgs.push_back(callArg);
441450
}
442451

443452
// Create the return block and set ReturnToBB for use in visitTerminator
444453
// callbacks.
445-
SILBasicBlock *callerBB = Apply.getParent();
454+
SILBasicBlock *callerBlock = Apply.getParent();
455+
SILBasicBlock *throwBlock = nullptr;
456+
SmallVector<SILInstruction *, 1> endBorrowInsertPts;
457+
446458
switch (Apply.getKind()) {
447459
case FullApplySiteKind::ApplyInst: {
448460
auto *AI = dyn_cast<ApplyInst>(Apply);
449461

450462
// Split the BB and do NOT create a branch between the old and new
451463
// BBs; we will create the appropriate terminator manually later.
452464
ReturnToBB =
453-
callerBB->split(std::next(Apply.getInstruction()->getIterator()));
465+
callerBlock->split(std::next(Apply.getInstruction()->getIterator()));
466+
endBorrowInsertPts.push_back(&*ReturnToBB->begin());
454467

455468
// Create an argument on the return-to BB representing the returned value.
456469
auto *retArg =
@@ -459,22 +472,49 @@ SILInlineCloner::cloneInline(ArrayRef<SILValue> AppliedArgs) {
459472
AI->replaceAllUsesWith(retArg);
460473
break;
461474
}
462-
case FullApplySiteKind::BeginApplyInst:
475+
case FullApplySiteKind::BeginApplyInst: {
463476
ReturnToBB =
464-
callerBB->split(std::next(Apply.getInstruction()->getIterator()));
465-
BeginApply->preprocess(ReturnToBB);
477+
callerBlock->split(std::next(Apply.getInstruction()->getIterator()));
478+
// For begin_apply, we insert the end_borrow in the end_apply, abort_apply
479+
// blocks to ensure that our borrowed values live over both the body and
480+
// resume block of our coroutine.
481+
BeginApply->preprocess(ReturnToBB, endBorrowInsertPts);
466482
break;
467-
468-
case FullApplySiteKind::TryApplyInst:
469-
ReturnToBB = cast<TryApplyInst>(Apply)->getNormalBB();
483+
}
484+
case FullApplySiteKind::TryApplyInst: {
485+
auto *tai = cast<TryApplyInst>(Apply);
486+
ReturnToBB = tai->getNormalBB();
487+
endBorrowInsertPts.push_back(&*ReturnToBB->begin());
488+
throwBlock = tai->getErrorBB();
470489
break;
471490
}
491+
}
492+
493+
// Then insert end_borrow in our end borrow block and in the throw
494+
// block if we have one.
495+
if (borrowedArgs.any()) {
496+
for (unsigned i : indices(AppliedArgs)) {
497+
if (!borrowedArgs.test(i)) {
498+
continue;
499+
}
500+
501+
for (auto *insertPt : endBorrowInsertPts) {
502+
SILBuilderWithScope returnBuilder(insertPt, getBuilder());
503+
returnBuilder.createEndBorrow(Apply.getLoc(), entryArgs[i]);
504+
}
505+
506+
if (throwBlock) {
507+
SILBuilderWithScope throwBuilder(throwBlock->begin(), getBuilder());
508+
throwBuilder.createEndBorrow(Apply.getLoc(), entryArgs[i]);
509+
}
510+
}
511+
}
472512

473513
// Visit original BBs in depth-first preorder, starting with the
474514
// entry block, cloning all instructions and terminators.
475515
//
476516
// NextIter is initialized during `fixUp`.
477-
cloneFunctionBody(getCalleeFunction(), callerBB, entryArgs);
517+
cloneFunctionBody(getCalleeFunction(), callerBlock, entryArgs);
478518

479519
// For non-throwing applies, the inlined body now unconditionally branches to
480520
// the returned-to-code, which was previously part of the call site's basic
@@ -560,25 +600,11 @@ SILValue SILInlineCloner::borrowFunctionArgument(SILValue callArg,
560600
FullApplySite AI) {
561601
if (!AI.getFunction()->hasOwnership()
562602
|| callArg.getOwnershipKind() != ValueOwnershipKind::Owned) {
563-
return callArg;
603+
return SILValue();
564604
}
565605

566606
SILBuilderWithScope beginBuilder(AI.getInstruction(), getBuilder());
567-
auto *borrow = beginBuilder.createBeginBorrow(AI.getLoc(), callArg);
568-
if (auto *tryAI = dyn_cast<TryApplyInst>(AI)) {
569-
SILBuilderWithScope returnBuilder(tryAI->getNormalBB()->begin(),
570-
getBuilder());
571-
returnBuilder.createEndBorrow(AI.getLoc(), borrow, callArg);
572-
573-
SILBuilderWithScope throwBuilder(tryAI->getErrorBB()->begin(),
574-
getBuilder());
575-
throwBuilder.createEndBorrow(AI.getLoc(), borrow, callArg);
576-
} else {
577-
SILBuilderWithScope returnBuilder(
578-
std::next(AI.getInstruction()->getIterator()), getBuilder());
579-
returnBuilder.createEndBorrow(AI.getLoc(), borrow, callArg);
580-
}
581-
return borrow;
607+
return beginBuilder.createBeginBorrow(AI.getLoc(), callArg);
582608
}
583609

584610
void SILInlineCloner::visitDebugValueInst(DebugValueInst *Inst) {

test/SILOptimizer/mandatory_inlining_ownership.sil

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,3 +414,90 @@ bb0(%0 : @owned $Builtin.NativeObject):
414414
%9999 = tuple()
415415
return %9999 : $()
416416
}
417+
418+
///////////////////////
419+
// Begin Apply Tests //
420+
///////////////////////
421+
422+
// Make sure that we do not violate any ownership invariants after inlining this
423+
// code.
424+
425+
sil @get_hidden_int_field_of_klass : $@convention(method) (@guaranteed Klass) -> Builtin.Int32
426+
sil @int_klass_pair_user : $@convention(method) (Builtin.Int32, @guaranteed Klass) -> ()
427+
428+
sil [transparent] [ossa] @begin_apply_callee : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32 {
429+
bb0(%0 : @guaranteed $Klass):
430+
%2 = alloc_stack $Builtin.Int32
431+
%3 = function_ref @get_hidden_int_field_of_klass : $@convention(method) (@guaranteed Klass) -> Builtin.Int32
432+
%4 = apply %3(%0) : $@convention(method) (@guaranteed Klass) -> Builtin.Int32
433+
store %4 to [trivial] %2 : $*Builtin.Int32
434+
yield %2 : $*Builtin.Int32, resume bb1, unwind bb2
435+
436+
bb1:
437+
%7 = load [trivial] %2 : $*Builtin.Int32
438+
%8 = function_ref @int_klass_pair_user : $@convention(method) (Builtin.Int32, @guaranteed Klass) -> ()
439+
%9 = apply %8(%7, %0) : $@convention(method) (Builtin.Int32, @guaranteed Klass) -> ()
440+
dealloc_stack %2 : $*Builtin.Int32
441+
%11 = tuple ()
442+
return %11 : $()
443+
444+
bb2:
445+
%13 = load [trivial] %2 : $*Builtin.Int32
446+
%14 = function_ref @int_klass_pair_user : $@convention(method) (Builtin.Int32, @guaranteed Klass) -> ()
447+
%15 = apply %14(%13, %0) : $@convention(method) (Builtin.Int32, @guaranteed Klass) -> ()
448+
dealloc_stack %2 : $*Builtin.Int32
449+
unwind
450+
}
451+
452+
// CHECK-LABEL: sil [ossa] @begin_apply_caller : $@convention(method) (@guaranteed Klass) -> @error Error {
453+
// CHECK-NOT: begin_apply
454+
// CHECK: } // end sil function 'begin_apply_caller'
455+
sil [ossa] @begin_apply_caller : $@convention(method) (@guaranteed Klass) -> @error Error {
456+
bb0(%0 : @guaranteed $Klass):
457+
%6 = copy_value %0 : $Klass
458+
%12 = function_ref @begin_apply_callee : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
459+
(%13, %14) = begin_apply %12(%6) : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
460+
end_apply %14
461+
destroy_value %6 : $Klass
462+
%19 = tuple ()
463+
return %19 : $()
464+
}
465+
466+
// CHECK-LABEL: sil [ossa] @begin_apply_caller_2 : $@convention(method) (@guaranteed Klass) -> @error Error {
467+
// CHECK-NOT: begin_apply
468+
// CHECK: } // end sil function 'begin_apply_caller_2'
469+
sil [ossa] @begin_apply_caller_2 : $@convention(method) (@guaranteed Klass) -> @error Error {
470+
bb0(%0 : @guaranteed $Klass):
471+
%6 = copy_value %0 : $Klass
472+
%12 = function_ref @begin_apply_callee : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
473+
(%13, %14) = begin_apply %12(%6) : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
474+
abort_apply %14
475+
destroy_value %6 : $Klass
476+
%19 = tuple ()
477+
return %19 : $()
478+
}
479+
480+
// CHECK-LABEL: sil [ossa] @begin_apply_caller_3 : $@convention(method) (@guaranteed Klass) -> @error Error {
481+
// CHECK-NOT: begin_apply
482+
// CHECK: } // end sil function 'begin_apply_caller_3'
483+
sil [ossa] @begin_apply_caller_3 : $@convention(method) (@guaranteed Klass) -> @error Error {
484+
bb0(%0 : @guaranteed $Klass):
485+
%6 = copy_value %0 : $Klass
486+
%12 = function_ref @begin_apply_callee : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
487+
(%13, %14) = begin_apply %12(%6) : $@yield_once @convention(method) (@guaranteed Klass) -> @yields @inout Builtin.Int32
488+
cond_br undef, bb1, bb2
489+
490+
bb1:
491+
end_apply %14
492+
br bb3
493+
494+
bb2:
495+
abort_apply %14
496+
br bb3
497+
498+
bb3:
499+
destroy_value %6 : $Klass
500+
%19 = tuple ()
501+
return %19 : $()
502+
}
503+

0 commit comments

Comments
 (0)