Skip to content

Commit a2dbdec

Browse files
committed
SIL: Refactor get_async_continuation[_addr] to return a RawUnsafeContinuation
1 parent c2bfff0 commit a2dbdec

File tree

18 files changed

+153
-167
lines changed

18 files changed

+153
-167
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,17 +1939,27 @@ class SILBuilder {
19391939
//===--------------------------------------------------------------------===//
19401940

19411941
GetAsyncContinuationInst *createGetAsyncContinuation(SILLocation Loc,
1942-
SILType ContinuationTy) {
1942+
CanType ResumeType,
1943+
bool Throws) {
1944+
auto ContinuationType = SILType::getPrimitiveObjectType(
1945+
getASTContext().TheRawUnsafeContinuationType);
19431946
return insert(new (getModule()) GetAsyncContinuationInst(getSILDebugLocation(Loc),
1944-
ContinuationTy));
1947+
ContinuationType,
1948+
ResumeType,
1949+
Throws));
19451950
}
19461951

19471952
GetAsyncContinuationAddrInst *createGetAsyncContinuationAddr(SILLocation Loc,
19481953
SILValue Operand,
1949-
SILType ContinuationTy) {
1954+
CanType ResumeType,
1955+
bool Throws) {
1956+
auto ContinuationType = SILType::getPrimitiveObjectType(
1957+
getASTContext().TheRawUnsafeContinuationType);
19501958
return insert(new (getModule()) GetAsyncContinuationAddrInst(getSILDebugLocation(Loc),
19511959
Operand,
1952-
ContinuationTy));
1960+
ContinuationType,
1961+
ResumeType,
1962+
Throws));
19531963
}
19541964

19551965
HopToExecutorInst *createHopToExecutor(SILLocation Loc, SILValue Actor) {

include/swift/SIL/SILCloner.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,7 +2961,8 @@ ::visitGetAsyncContinuationInst(GetAsyncContinuationInst *Inst) {
29612961
recordClonedInstruction(Inst,
29622962
getBuilder().createGetAsyncContinuation(
29632963
getOpLocation(Inst->getLoc()),
2964-
getOpType(Inst->getType())));
2964+
getOpASTType(Inst->getFormalResumeType()),
2965+
Inst->throws()));
29652966
}
29662967

29672968
template <typename ImplClass>
@@ -2972,7 +2973,8 @@ ::visitGetAsyncContinuationAddrInst(GetAsyncContinuationAddrInst *Inst) {
29722973
getBuilder().createGetAsyncContinuationAddr(
29732974
getOpLocation(Inst->getLoc()),
29742975
getOpValue(Inst->getOperand()),
2975-
getOpType(Inst->getType())));
2976+
getOpASTType(Inst->getFormalResumeType()),
2977+
Inst->throws()));
29762978
}
29772979

29782980
template <typename ImplClass>

include/swift/SIL/SILInstruction.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3118,15 +3118,22 @@ class GetAsyncContinuationInstBase
31183118
: public SingleValueInstruction
31193119
{
31203120
protected:
3121-
using SingleValueInstruction::SingleValueInstruction;
3122-
3121+
CanType ResumeType;
3122+
bool Throws;
3123+
3124+
GetAsyncContinuationInstBase(SILInstructionKind Kind, SILDebugLocation Loc,
3125+
SILType ContinuationType, CanType ResumeType,
3126+
bool Throws)
3127+
: SingleValueInstruction(Kind, Loc, ContinuationType),
3128+
ResumeType(ResumeType), Throws(Throws) {}
3129+
31233130
public:
31243131
/// Get the type of the value the async task receives on a resume.
3125-
CanType getFormalResumeType() const;
3132+
CanType getFormalResumeType() const { return ResumeType; }
31263133
SILType getLoweredResumeType() const;
31273134

31283135
/// True if the continuation can be used to resume the task by throwing an error.
3129-
bool throws() const;
3136+
bool throws() const { return Throws; }
31303137

31313138
static bool classof(const SILNode *I) {
31323139
return I->getKind() >= SILNodeKind::First_GetAsyncContinuationInstBase &&
@@ -3142,8 +3149,9 @@ class GetAsyncContinuationInst final
31423149
friend SILBuilder;
31433150

31443151
GetAsyncContinuationInst(SILDebugLocation Loc,
3145-
SILType ContinuationTy)
3146-
: InstructionBase(Loc, ContinuationTy)
3152+
SILType ContinuationType, CanType ResumeType,
3153+
bool Throws)
3154+
: InstructionBase(Loc, ContinuationType, ResumeType, Throws)
31473155
{}
31483156

31493157
public:
@@ -3163,9 +3171,10 @@ class GetAsyncContinuationAddrInst final
31633171
{
31643172
friend SILBuilder;
31653173
GetAsyncContinuationAddrInst(SILDebugLocation Loc,
3166-
SILValue Operand,
3167-
SILType ContinuationTy)
3168-
: UnaryInstructionBase(Loc, Operand, ContinuationTy)
3174+
SILValue ResumeBuf,
3175+
SILType ContinuationType, CanType ResumeType,
3176+
bool Throws)
3177+
: UnaryInstructionBase(Loc, ResumeBuf, ContinuationType, ResumeType, Throws)
31693178
{}
31703179
};
31713180

lib/IRGen/IRGenFunction.cpp

Lines changed: 15 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -513,38 +513,17 @@ llvm::Value *IRGenFunction::alignUpToMaximumAlignment(llvm::Type *sizeTy, llvm::
513513
return Builder.CreateAnd(Builder.CreateAdd(val, alignMask), invertedMask);
514514
}
515515

516-
/// Returns the current task \p currTask as an UnsafeContinuation at +1.
516+
/// Returns the current task \p currTask as a Builtin.RawUnsafeContinuation at +1.
517517
static llvm::Value *unsafeContinuationFromTask(IRGenFunction &IGF,
518-
SILType unsafeContinuationTy,
519518
llvm::Value *currTask) {
520519
auto &IGM = IGF.IGM;
521520
auto &Builder = IGF.Builder;
522521

523-
auto &rawPonterTI = IGM.getRawPointerTypeInfo();
524-
auto object =
525-
Builder.CreateBitOrPointerCast(currTask, rawPonterTI.getStorageType());
526-
527-
// Wrap the native object in the UnsafeContinuation struct.
528-
// struct UnsafeContinuation<T> {
529-
// let _continuation : Builtin.RawPointer
530-
// }
531-
auto &unsafeContinuationTI =
532-
cast<LoadableTypeInfo>(IGF.getTypeInfo(unsafeContinuationTy));
533-
auto unsafeContinuationStructTy =
534-
cast<llvm::StructType>(unsafeContinuationTI.getStorageType());
535-
auto fieldTy =
536-
cast<llvm::StructType>(unsafeContinuationStructTy->getElementType(0));
537-
auto reference =
538-
Builder.CreateBitOrPointerCast(object, fieldTy->getElementType(0));
539-
auto field =
540-
Builder.CreateInsertValue(llvm::UndefValue::get(fieldTy), reference, 0);
541-
auto unsafeContinuation = Builder.CreateInsertValue(
542-
llvm::UndefValue::get(unsafeContinuationStructTy), field, 0);
543-
544-
return unsafeContinuation;
545-
}
546-
547-
void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
522+
auto &rawPointerTI = IGM.getRawUnsafeContinuationTypeInfo();
523+
return Builder.CreateBitOrPointerCast(currTask, rawPointerTI.getStorageType());
524+
}
525+
526+
void IRGenFunction::emitGetAsyncContinuation(SILType resumeTy,
548527
StackAddress resultAddr,
549528
Explosion &out) {
550529
// Create the continuation.
@@ -574,10 +553,9 @@ void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
574553
// continuation_context.resumeExecutor = .. // current executor
575554

576555
auto currTask = getAsyncTask();
577-
auto unsafeContinuation =
578-
unsafeContinuationFromTask(*this, unsafeContinuationTy, currTask);
556+
auto unsafeContinuation = unsafeContinuationFromTask(*this, currTask);
579557

580-
// Create and setup the continuation context for UnsafeContinuation<T>.
558+
// Create and setup the continuation context.
581559
// continuation_context.resumeCtxt = currCtxt;
582560
// continuation_context.errResult = nulllptr;
583561
// continuation_context.result = ... // local alloca T
@@ -599,18 +577,9 @@ void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
599577
auto contResultAddr =
600578
Builder.CreateStructGEP(continuationContext.getAddress(), 3);
601579
if (!resultAddr.getAddress().isValid()) {
602-
assert(unsafeContinuationTy.getASTType()
603-
->castTo<BoundGenericType>()
604-
->getGenericArgs()
605-
.size() == 1 &&
606-
"expect UnsafeContinuation<T> to have one generic arg");
607-
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
608-
->castTo<BoundGenericType>()
609-
->getGenericArgs()[0]
610-
->getCanonicalType());
611-
auto &resultTI = getTypeInfo(resultTy);
580+
auto &resumeTI = getTypeInfo(resumeTy);
612581
auto resultAddr =
613-
resultTI.allocateStack(*this, resultTy, "async.continuation.result");
582+
resumeTI.allocateStack(*this, resumeTy, "async.continuation.result");
614583
Builder.CreateStore(Builder.CreateBitOrPointerCast(
615584
resultAddr.getAddress().getAddress(),
616585
contResultAddr->getType()->getPointerElementType()),
@@ -665,7 +634,7 @@ void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
665634
}
666635

667636
void IRGenFunction::emitAwaitAsyncContinuation(
668-
SILType unsafeContinuationTy, bool isIndirectResult,
637+
SILType resumeTy, bool isIndirectResult,
669638
Explosion &outDirectResult, llvm::BasicBlock *&normalBB,
670639
llvm::PHINode *&optionalErrorResult, llvm::BasicBlock *&optionalErrorBB) {
671640
assert(AsyncCoroutineCurrentContinuationContext && "no active continuation");
@@ -756,17 +725,13 @@ void IRGenFunction::emitAwaitAsyncContinuation(
756725
auto resultAddrVal =
757726
Builder.CreateLoad(Address(contResultAddrAddr, pointerAlignment));
758727
// Take the result.
759-
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
760-
->castTo<BoundGenericType>()
761-
->getGenericArgs()[0]
762-
->getCanonicalType());
763-
auto &resultTI = cast<LoadableTypeInfo>(getTypeInfo(resultTy));
764-
auto resultStorageTy = resultTI.getStorageType();
728+
auto &resumeTI = cast<LoadableTypeInfo>(getTypeInfo(resumeTy));
729+
auto resultStorageTy = resumeTI.getStorageType();
765730
auto resultAddr =
766731
Address(Builder.CreateBitOrPointerCast(resultAddrVal,
767732
resultStorageTy->getPointerTo()),
768-
resultTI.getFixedAlignment());
769-
resultTI.loadAsTake(*this, resultAddr, outDirectResult);
733+
resumeTI.getFixedAlignment());
734+
resumeTI.loadAsTake(*this, resultAddr, outDirectResult);
770735
}
771736
Builder.CreateBr(normalBB);
772737
AsyncCoroutineCurrentResume = nullptr;

lib/IRGen/IRGenFunction.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,11 @@ class IRGenFunction {
140140
llvm::Function *createAsyncDispatchFn(const FunctionPointer &fnPtr,
141141
ArrayRef<llvm::Type *> argTypes);
142142

143-
void emitGetAsyncContinuation(SILType silTy, StackAddress optionalResultAddr,
143+
void emitGetAsyncContinuation(SILType resumeTy,
144+
StackAddress optionalResultAddr,
144145
Explosion &out);
145146

146-
void emitAwaitAsyncContinuation(SILType unsafeContinuationTy,
147+
void emitAwaitAsyncContinuation(SILType resumeTy,
147148
bool isIndirectResult,
148149
Explosion &outDirectResult,
149150
llvm::BasicBlock *&normalBB,

lib/IRGen/IRGenSIL.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6372,25 +6372,27 @@ void IRGenModule::emitSILStaticInitializers() {
63726372
void IRGenSILFunction::visitGetAsyncContinuationInst(
63736373
GetAsyncContinuationInst *i) {
63746374
Explosion out;
6375-
emitGetAsyncContinuation(i->getType(), StackAddress(), out);
6375+
emitGetAsyncContinuation(i->getLoweredResumeType(), StackAddress(), out);
63766376
setLoweredExplosion(i, out);
63776377
}
63786378

63796379
void IRGenSILFunction::visitGetAsyncContinuationAddrInst(
63806380
GetAsyncContinuationAddrInst *i) {
63816381
auto resultAddr = getLoweredStackAddress(i->getOperand());
63826382
Explosion out;
6383-
emitGetAsyncContinuation(i->getType(), resultAddr, out);
6383+
emitGetAsyncContinuation(i->getLoweredResumeType(), resultAddr, out);
63846384
setLoweredExplosion(i, out);
63856385
}
63866386

63876387
void IRGenSILFunction::visitAwaitAsyncContinuationInst(
63886388
AwaitAsyncContinuationInst *i) {
63896389
Explosion resumeResult;
63906390

6391-
auto continuationTy = i->getOperand()->getType();
6392-
63936391
bool isIndirect = i->getResumeBB()->args_empty();
6392+
SILType resumeTy;
6393+
if (!isIndirect)
6394+
resumeTy = (*i->getResumeBB()->args_begin())->getType();
6395+
63946396
auto &normalDest = getLoweredBB(i->getResumeBB());
63956397
auto *normalDestBB = normalDest.bb;
63966398

@@ -6400,7 +6402,7 @@ void IRGenSILFunction::visitAwaitAsyncContinuationInst(
64006402
assert(!hasError || getLoweredBB(i->getErrorBB()).phis.size() == 1 &&
64016403
"error basic block should only expect one value");
64026404

6403-
emitAwaitAsyncContinuation(continuationTy, isIndirect, resumeResult,
6405+
emitAwaitAsyncContinuation(resumeTy, isIndirect, resumeResult,
64046406
normalDestBB, errorPhi, errorDestBB);
64056407
if (!isIndirect) {
64066408
unsigned firstIndex = 0;

lib/SIL/IR/SILInstructions.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2844,11 +2844,6 @@ DestructureTupleInst *DestructureTupleInst::create(const SILFunction &F,
28442844
DestructureTupleInst(M, Loc, Operand, Types, OwnershipKinds);
28452845
}
28462846

2847-
CanType GetAsyncContinuationInstBase::getFormalResumeType() const {
2848-
// The resume type is the type argument to the continuation type.
2849-
return getType().castTo<BoundGenericType>().getGenericArgs()[0];
2850-
}
2851-
28522847
SILType GetAsyncContinuationInstBase::getLoweredResumeType() const {
28532848
// The lowered resume type is the maximally-abstracted lowering of the
28542849
// formal resume type.
@@ -2858,12 +2853,6 @@ SILType GetAsyncContinuationInstBase::getLoweredResumeType() const {
28582853
return M.Types.getLoweredType(AbstractionPattern::getOpaque(), formalType, c);
28592854
}
28602855

2861-
bool GetAsyncContinuationInstBase::throws() const {
2862-
// The continuation throws if it's an UnsafeThrowingContinuation
2863-
return getType().castTo<BoundGenericType>()->getDecl()
2864-
== getFunction()->getASTContext().getUnsafeThrowingContinuationDecl();
2865-
}
2866-
28672856
ReturnInst::ReturnInst(SILFunction &func, SILDebugLocation debugLoc,
28682857
SILValue returnValue)
28692858
: UnaryInstructionBase(debugLoc, returnValue),

lib/SIL/IR/SILPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,13 +2077,13 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
20772077
void visitGetAsyncContinuationInst(GetAsyncContinuationInst *GI) {
20782078
if (GI->throws())
20792079
*this << "[throws] ";
2080-
*this << '$' << GI->getFormalResumeType();
2080+
*this << GI->getFormalResumeType();
20812081
}
20822082

20832083
void visitGetAsyncContinuationAddrInst(GetAsyncContinuationAddrInst *GI) {
20842084
if (GI->throws())
20852085
*this << "[throws] ";
2086-
*this << '$' << GI->getFormalResumeType()
2086+
*this << GI->getFormalResumeType()
20872087
<< ", " << getIDAndType(GI->getOperand());
20882088
}
20892089

lib/SIL/Parser/ParseSIL.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5288,8 +5288,8 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
52885288
throws = true;
52895289
}
52905290

5291-
SILType resumeTy;
5292-
if (parseSILType(resumeTy)) {
5291+
CanType resumeTy;
5292+
if (parseASTType(resumeTy)) {
52935293
return true;
52945294
}
52955295

@@ -5304,21 +5304,11 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
53045304
if (parseSILDebugLocation(InstLoc, B))
53055305
return true;
53065306

5307-
auto &M = B.getModule();
5308-
NominalTypeDecl *continuationDecl = throws
5309-
? M.getASTContext().getUnsafeThrowingContinuationDecl()
5310-
: M.getASTContext().getUnsafeContinuationDecl();
5311-
5312-
auto continuationTy = BoundGenericType::get(continuationDecl, Type(),
5313-
resumeTy.getASTType());
5314-
auto continuationSILTy
5315-
= SILType::getPrimitiveObjectType(continuationTy->getCanonicalType());
5316-
53175307
if (Opcode == SILInstructionKind::GetAsyncContinuationAddrInst) {
53185308
ResultVal = B.createGetAsyncContinuationAddr(InstLoc, resumeBuffer,
5319-
continuationSILTy);
5309+
resumeTy, throws);
53205310
} else {
5321-
ResultVal = B.createGetAsyncContinuation(InstLoc, continuationSILTy);
5311+
ResultVal = B.createGetAsyncContinuation(InstLoc, resumeTy, throws);
53225312
}
53235313
break;
53245314
}

lib/SIL/Verifier/SILVerifier.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4892,12 +4892,7 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
48924892

48934893
void checkGetAsyncContinuationInstBase(GetAsyncContinuationInstBase *GACI) {
48944894
auto resultTy = GACI->getType();
4895-
auto &C = resultTy.getASTContext();
4896-
auto resultBGT = resultTy.getAs<BoundGenericType>();
4897-
require(resultBGT, "Instruction type must be a continuation type");
4898-
auto resultDecl = resultBGT->getDecl();
4899-
require(resultDecl == C.getUnsafeContinuationDecl()
4900-
|| resultDecl == C.getUnsafeThrowingContinuationDecl(),
4895+
require(resultTy.is<BuiltinRawUnsafeContinuationType>(),
49014896
"Instruction type must be a continuation type");
49024897
}
49034898

0 commit comments

Comments
 (0)