Skip to content

Commit aef8f50

Browse files
authored
Merge pull request swiftlang#34210 from jckarter/async-await-sil-verifier
SIL: Verify invariants of async_continuation instructions.
2 parents 59e8043 + 3364c51 commit aef8f50

File tree

6 files changed

+151
-53
lines changed

6 files changed

+151
-53
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,10 @@ class SILInstruction
514514
"Operand does not belong to a SILInstruction");
515515
return isTypeDependentOperand(Op.getOperandNumber());
516516
}
517+
518+
/// Returns true if evaluation of this instruction may cause suspension of an
519+
/// async task.
520+
bool maySuspend() const;
517521

518522
private:
519523
/// Predicate used to filter OperandValueRange.
@@ -3039,12 +3043,34 @@ class KeyPathPattern final
30393043
}
30403044
};
30413045

3042-
/// Accesses the continuation for an async task, to prepare a primitive suspend operation.
3046+
/// Base class for instructions that access the continuation of an async task,
3047+
/// in order to set up a suspension.
30433048
/// The continuation must be consumed by an AwaitAsyncContinuation instruction locally,
30443049
/// and must dynamically be resumed exactly once during the program's ensuing execution.
3050+
class GetAsyncContinuationInstBase
3051+
: public SingleValueInstruction
3052+
{
3053+
protected:
3054+
using SingleValueInstruction::SingleValueInstruction;
3055+
3056+
public:
3057+
/// Get the type of the value the async task receives on a resume.
3058+
CanType getFormalResumeType() const;
3059+
SILType getLoweredResumeType() const;
3060+
3061+
/// True if the continuation can be used to resume the task by throwing an error.
3062+
bool throws() const;
3063+
3064+
static bool classof(const SILNode *I) {
3065+
return I->getKind() >= SILNodeKind::First_GetAsyncContinuationInstBase &&
3066+
I->getKind() <= SILNodeKind::Last_GetAsyncContinuationInstBase;
3067+
}
3068+
};
3069+
3070+
/// Accesses the continuation for an async task, to prepare a primitive suspend operation.
30453071
class GetAsyncContinuationInst final
30463072
: public InstructionBase<SILInstructionKind::GetAsyncContinuationInst,
3047-
SingleValueInstruction>
3073+
GetAsyncContinuationInstBase>
30483074
{
30493075
friend SILBuilder;
30503076

@@ -3054,14 +3080,6 @@ class GetAsyncContinuationInst final
30543080
{}
30553081

30563082
public:
3057-
3058-
/// Get the type of the value the async task receives on a resume.
3059-
CanType getFormalResumeType() const;
3060-
SILType getLoweredResumeType() const;
3061-
3062-
/// True if the continuation can be used to resume the task by throwing an error.
3063-
bool throws() const;
3064-
30653083
ArrayRef<Operand> getAllOperands() const { return {}; }
30663084
MutableArrayRef<Operand> getAllOperands() { return {}; }
30673085
};
@@ -3074,23 +3092,14 @@ class GetAsyncContinuationInst final
30743092
/// buffer that receives the incoming value when the continuation is resumed.
30753093
class GetAsyncContinuationAddrInst final
30763094
: public UnaryInstructionBase<SILInstructionKind::GetAsyncContinuationAddrInst,
3077-
SingleValueInstruction>
3095+
GetAsyncContinuationInstBase>
30783096
{
30793097
friend SILBuilder;
30803098
GetAsyncContinuationAddrInst(SILDebugLocation Loc,
30813099
SILValue Operand,
30823100
SILType ContinuationTy)
30833101
: UnaryInstructionBase(Loc, Operand, ContinuationTy)
30843102
{}
3085-
3086-
public:
3087-
3088-
/// Get the type of the value the async task receives on a resume.
3089-
CanType getFormalResumeType() const;
3090-
SILType getLoweredResumeType() const;
3091-
3092-
/// True if the continuation can be used to resume the task by throwing an error.
3093-
bool throws() const;
30943103
};
30953104

30963105
/// Instantiates a key path object.

include/swift/SIL/SILNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ class alignas(8) SILNode {
452452
/// If this is a SILArgument or a SILInstruction get its parent module,
453453
/// otherwise return null.
454454
SILModule *getModule() const;
455-
455+
456456
/// Pretty-print the node. If the node is an instruction, the output
457457
/// will be valid SIL assembly; otherwise, it will be an arbitrary
458458
/// format suitable for debugging.

include/swift/SIL/SILNodes.def

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -719,10 +719,12 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
719719
// be tightened, though we want to be careful that passes that try to do
720720
// code motion or eliminate this instruction don't do so without awareness of
721721
// its structural requirements.
722-
SINGLE_VALUE_INST(GetAsyncContinuationInst, get_async_continuation,
723-
SingleValueInstruction, MayHaveSideEffects, MayRelease)
724-
SINGLE_VALUE_INST(GetAsyncContinuationAddrInst, get_async_continuation_addr,
725-
SingleValueInstruction, MayHaveSideEffects, MayRelease)
722+
ABSTRACT_SINGLE_VALUE_INST(GetAsyncContinuationInstBase, SingleValueInstruction)
723+
SINGLE_VALUE_INST(GetAsyncContinuationInst, get_async_continuation,
724+
GetAsyncContinuationInstBase, MayHaveSideEffects, MayRelease)
725+
SINGLE_VALUE_INST(GetAsyncContinuationAddrInst, get_async_continuation_addr,
726+
GetAsyncContinuationInstBase, MayHaveSideEffects, MayRelease)
727+
SINGLE_VALUE_INST_RANGE(GetAsyncContinuationInstBase, GetAsyncContinuationInst, GetAsyncContinuationAddrInst)
726728

727729
// Key paths
728730
// TODO: The only "side effect" is potentially retaining the returned key path

lib/SIL/IR/SILInstruction.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,21 @@ MultipleValueInstruction *MultipleValueInstructionResult::getParent() {
15191519
return reinterpret_cast<MultipleValueInstruction *>(value);
15201520
}
15211521

1522+
/// Returns true if evaluation of this node may cause suspension of an
1523+
/// async task.
1524+
bool SILInstruction::maySuspend() const {
1525+
// await_async_continuation always suspends the current task.
1526+
if (isa<AwaitAsyncContinuationInst>(this))
1527+
return true;
1528+
1529+
// Fully applying an async function may suspend the caller.
1530+
if (auto applySite = FullApplySite::isa(const_cast<SILInstruction*>(this))) {
1531+
return applySite.getOrigCalleeType()->isAsync();
1532+
}
1533+
1534+
return false;
1535+
}
1536+
15221537
#ifndef NDEBUG
15231538

15241539
//---
@@ -1536,7 +1551,7 @@ MultipleValueInstruction *MultipleValueInstructionResult::getParent() {
15361551
// Check that all subclasses of MultipleValueInstructionResult are the same size
15371552
// as MultipleValueInstructionResult.
15381553
//
1539-
// If this changes, we just need to expand the size fo SILInstructionResultArray
1554+
// If this changes, we just need to expand the size of SILInstructionResultArray
15401555
// to contain a stride. But we assume this now so we should enforce it.
15411556
#define MULTIPLE_VALUE_INST_RESULT(ID, PARENT) \
15421557
static_assert( \

lib/SIL/IR/SILInstructions.cpp

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2931,12 +2931,12 @@ DestructureTupleInst *DestructureTupleInst::create(const SILFunction &F,
29312931
DestructureTupleInst(M, Loc, Operand, Types, OwnershipKinds);
29322932
}
29332933

2934-
CanType GetAsyncContinuationInst::getFormalResumeType() const {
2934+
CanType GetAsyncContinuationInstBase::getFormalResumeType() const {
29352935
// The resume type is the type argument to the continuation type.
29362936
return getType().castTo<BoundGenericType>().getGenericArgs()[0];
29372937
}
29382938

2939-
SILType GetAsyncContinuationInst::getLoweredResumeType() const {
2939+
SILType GetAsyncContinuationInstBase::getLoweredResumeType() const {
29402940
// The lowered resume type is the maximally-abstracted lowering of the
29412941
// formal resume type.
29422942
auto formalType = getFormalResumeType();
@@ -2945,29 +2945,8 @@ SILType GetAsyncContinuationInst::getLoweredResumeType() const {
29452945
return M.Types.getLoweredType(AbstractionPattern::getOpaque(), formalType, c);
29462946
}
29472947

2948-
bool GetAsyncContinuationInst::throws() const {
2948+
bool GetAsyncContinuationInstBase::throws() const {
29492949
// The continuation throws if it's an UnsafeThrowingContinuation
29502950
return getType().castTo<BoundGenericType>()->getDecl()
29512951
== getFunction()->getASTContext().getUnsafeThrowingContinuationDecl();
29522952
}
2953-
2954-
CanType GetAsyncContinuationAddrInst::getFormalResumeType() const {
2955-
// The resume type is the type argument to the continuation type.
2956-
return getType().castTo<BoundGenericType>().getGenericArgs()[0];
2957-
}
2958-
2959-
SILType GetAsyncContinuationAddrInst::getLoweredResumeType() const {
2960-
// The lowered resume type is the maximally-abstracted lowering of the
2961-
// formal resume type.
2962-
auto formalType = getFormalResumeType();
2963-
auto &M = getFunction()->getModule();
2964-
auto c = getFunction()->getTypeExpansionContext();
2965-
return M.Types.getLoweredType(AbstractionPattern::getOpaque(), formalType, c);
2966-
}
2967-
2968-
bool GetAsyncContinuationAddrInst::throws() const {
2969-
// The continuation throws if it's an UnsafeThrowingContinuation
2970-
return getType().castTo<BoundGenericType>()->getDecl()
2971-
== getFunction()->getASTContext().getUnsafeThrowingContinuationDecl();
2972-
}
2973-

lib/SIL/Verifier/SILVerifier.cpp

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
662662

663663
// Used for dominance checking within a basic block.
664664
llvm::DenseMap<const SILInstruction *, unsigned> InstNumbers;
665-
665+
666666
DeadEndBlocks DEBlocks;
667667
LoadBorrowNeverInvalidatedAnalysis loadBorrowNeverInvalidatedAnalysis;
668668
bool SingleFunction = true;
@@ -4748,6 +4748,72 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
47484748
"Type of witness instruction does not match actual type of "
47494749
"witnessed function");
47504750
}
4751+
4752+
void checkGetAsyncContinuationInstBase(GetAsyncContinuationInstBase *GACI) {
4753+
auto resultTy = GACI->getType();
4754+
auto &C = resultTy.getASTContext();
4755+
auto resultBGT = resultTy.getAs<BoundGenericType>();
4756+
require(resultBGT, "Instruction type must be a continuation type");
4757+
auto resultDecl = resultBGT->getDecl();
4758+
require(resultDecl == C.getUnsafeContinuationDecl()
4759+
|| resultDecl == C.getUnsafeThrowingContinuationDecl(),
4760+
"Instruction type must be a continuation type");
4761+
}
4762+
4763+
void checkGetAsyncContinuationInst(GetAsyncContinuationInst *GACI) {
4764+
checkGetAsyncContinuationInstBase(GACI);
4765+
}
4766+
4767+
void checkGetAsyncContinuationAddrInst(GetAsyncContinuationAddrInst *GACI) {
4768+
checkGetAsyncContinuationInstBase(GACI);
4769+
4770+
requireSameType(GACI->getOperand()->getType(),
4771+
GACI->getLoweredResumeType().getAddressType(),
4772+
"Operand type must match continuation resume type");
4773+
}
4774+
4775+
void checkAwaitAsyncContinuationInst(AwaitAsyncContinuationInst *AACI) {
4776+
// The operand must be a GetAsyncContinuation* instruction.
4777+
auto cont = dyn_cast<GetAsyncContinuationInstBase>(AACI->getOperand());
4778+
require(cont, "can only await the result of a get_async_continuation instruction");
4779+
bool isAddressForm = isa<GetAsyncContinuationAddrInst>(cont);
4780+
4781+
auto &C = cont->getType().getASTContext();
4782+
4783+
// The shape of the successors depends on the continuation instruction being
4784+
// awaited.
4785+
require((bool)AACI->getErrorBB() == cont->throws(),
4786+
"must have an error successor if and only if the continuation is throwing");
4787+
if (cont->throws()) {
4788+
require(AACI->getErrorBB()->getNumArguments() == 1,
4789+
"error successor must take one argument");
4790+
auto arg = AACI->getErrorBB()->getArgument(0);
4791+
auto errorType = C.getErrorDecl()->getDeclaredType()->getCanonicalType();
4792+
requireSameType(arg->getType(),
4793+
SILType::getPrimitiveObjectType(errorType),
4794+
"error successor argument must have Error type");
4795+
4796+
if (AACI->getFunction()->hasOwnership()) {
4797+
require(arg->getOwnershipKind() == ValueOwnershipKind::Owned,
4798+
"error successor argument must be owned");
4799+
}
4800+
}
4801+
if (isAddressForm) {
4802+
require(AACI->getResumeBB()->getNumArguments() == 0,
4803+
"resume successor must take no arguments for get_async_continuation_addr");
4804+
} else {
4805+
require(AACI->getResumeBB()->getNumArguments() == 1,
4806+
"resume successor must take one argument for get_async_continuation");
4807+
auto arg = AACI->getResumeBB()->getArgument(0);
4808+
4809+
requireSameType(arg->getType(), cont->getLoweredResumeType(),
4810+
"resume successor must take an argument of the continuation resume type");
4811+
if (AACI->getFunction()->hasOwnership()) {
4812+
require(arg->getOwnershipKind() == ValueOwnershipKind::Owned,
4813+
"resume successor argument must be owned");
4814+
}
4815+
}
4816+
}
47514817

47524818
// This verifies that the entry block of a SIL function doesn't have
47534819
// any predecessors and also verifies the entry point arguments.
@@ -4902,13 +4968,17 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
49024968
std::set<SILInstruction*> ActiveOps;
49034969

49044970
CFGState CFG = Normal;
4971+
4972+
GetAsyncContinuationInstBase *GotAsyncContinuation = nullptr;
49054973
};
49064974
};
49074975

49084976
/// Verify the various control-flow-sensitive rules of SIL:
49094977
///
49104978
/// - stack allocations and deallocations must obey a stack discipline
49114979
/// - accesses must be uniquely ended
4980+
/// - async continuations must be awaited before getting the continuation again, suspending
4981+
/// the task, or exiting the function
49124982
/// - flow-sensitive states must be equivalent on all paths into a block
49134983
void verifyFlowSensitiveRules(SILFunction *F) {
49144984
// Do a traversal of the basic blocks.
@@ -4924,6 +4994,20 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
49244994
for (SILInstruction &i : *BB) {
49254995
CurInstruction = &i;
49264996

4997+
if (i.maySuspend()) {
4998+
// Instructions that may suspend an async context must not happen
4999+
// while the continuation is being accessed, with the exception of
5000+
// the AwaitAsyncContinuationInst that completes suspending the task.
5001+
if (auto aaci = dyn_cast<AwaitAsyncContinuationInst>(&i)) {
5002+
require(state.GotAsyncContinuation == aaci->getOperand(),
5003+
"encountered await_async_continuation that doesn't match active gotten continuation");
5004+
state.GotAsyncContinuation = nullptr;
5005+
} else {
5006+
require(!state.GotAsyncContinuation,
5007+
"cannot suspend async task while unawaited continuation is active");
5008+
}
5009+
}
5010+
49275011
if (i.isAllocatingStack()) {
49285012
state.Stack.push_back(cast<SingleValueInstruction>(&i));
49295013

@@ -4946,13 +5030,18 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
49465030
bool present = state.ActiveOps.erase(beginOp);
49475031
require(present, "operation has already been ended");
49485032
}
4949-
5033+
} else if (auto gaci = dyn_cast<GetAsyncContinuationInstBase>(&i)) {
5034+
require(!state.GotAsyncContinuation,
5035+
"get_async_continuation while unawaited continuation is already active");
5036+
state.GotAsyncContinuation = gaci;
49505037
} else if (auto term = dyn_cast<TermInst>(&i)) {
49515038
if (term->isFunctionExiting()) {
49525039
require(state.Stack.empty(),
49535040
"return with stack allocs that haven't been deallocated");
49545041
require(state.ActiveOps.empty(),
49555042
"return with operations still active");
5043+
require(!state.GotAsyncContinuation,
5044+
"return with unawaited async continuation");
49565045

49575046
if (isa<UnwindInst>(term)) {
49585047
require(state.CFG == VerifyFlowSensitiveRulesDetails::YieldUnwind,
@@ -4970,12 +5059,14 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
49705059
}
49715060
}
49725061
}
4973-
5062+
49745063
if (isa<YieldInst>(term)) {
49755064
require(state.CFG != VerifyFlowSensitiveRulesDetails::YieldOnceResume,
49765065
"encountered multiple 'yield's along single path");
49775066
require(state.CFG == VerifyFlowSensitiveRulesDetails::Normal,
49785067
"encountered 'yield' on abnormal CFG path");
5068+
require(!state.GotAsyncContinuation,
5069+
"encountered 'yield' while an unawaited continuation is active");
49795070
}
49805071

49815072
auto successors = term->getSuccessors();
@@ -5037,6 +5128,8 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
50375128
"inconsistent active-operations sets entering basic block");
50385129
require(state.CFG == foundState.CFG,
50395130
"inconsistent coroutine states entering basic block");
5131+
require(state.GotAsyncContinuation == foundState.GotAsyncContinuation,
5132+
"inconsistent active async continuations entering basic block");
50405133
}
50415134
}
50425135
}

0 commit comments

Comments
 (0)