@@ -1039,9 +1039,11 @@ class AsyncPartialApplicationForwarderEmission
1039
1039
: public PartialApplicationForwarderEmission {
1040
1040
using super = PartialApplicationForwarderEmission;
1041
1041
AsyncContextLayout layout;
1042
- llvm::Value *contextBuffer;
1042
+ llvm::Value *calleeFunction;
1043
+ llvm::Value *currentResumeFn;
1043
1044
Size contextSize;
1044
1045
Address context;
1046
+ Address calleeContextBuffer;
1045
1047
unsigned currentArgumentIndex;
1046
1048
struct DynamicFunction {
1047
1049
using Kind = DynamicFunctionKind;
@@ -1060,23 +1062,11 @@ class AsyncPartialApplicationForwarderEmission
1060
1062
};
1061
1063
Optional<Self> self = llvm::None;
1062
1064
1063
- llvm::Value *loadValue (ElementLayout layout) {
1064
- Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
1065
- auto &ti = cast<LoadableTypeInfo>(layout.getType ());
1066
- Explosion explosion;
1067
- ti.loadAsTake (subIGF, addr, explosion);
1068
- return explosion.claimNext ();
1069
- }
1070
1065
void saveValue (ElementLayout layout, Explosion &explosion) {
1071
1066
Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
1072
1067
auto &ti = cast<LoadableTypeInfo>(layout.getType ());
1073
1068
ti.initialize (subIGF, explosion, addr, /* isOutlined*/ false );
1074
1069
}
1075
- void loadValue (ElementLayout layout, Explosion &explosion) {
1076
- Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
1077
- auto &ti = cast<LoadableTypeInfo>(layout.getType ());
1078
- ti.loadAsTake (subIGF, addr, explosion);
1079
- }
1080
1070
1081
1071
public:
1082
1072
AsyncPartialApplicationForwarderEmission (
@@ -1099,9 +1089,59 @@ class AsyncPartialApplicationForwarderEmission
1099
1089
void begin () override { super::begin (); }
1100
1090
1101
1091
void mapAsyncParameters () override {
1102
- contextBuffer = origParams.claimNext ();
1103
- context = layout.emitCastTo (subIGF, contextBuffer);
1104
- args.add (contextBuffer);
1092
+ // Ignore the original context.
1093
+ (void )origParams.claimNext ();
1094
+
1095
+ llvm::Value *dynamicContextSize32;
1096
+ auto initialContextSize = Size (0 );
1097
+ std::tie (calleeFunction, dynamicContextSize32) = getAsyncFunctionAndSize (
1098
+ subIGF, origType->getRepresentation (), *staticFnPtr,
1099
+ nullptr , std::make_pair (true , true ), initialContextSize);
1100
+ auto *dynamicContextSize =
1101
+ subIGF.Builder .CreateZExt (dynamicContextSize32, subIGF.IGM .SizeTy );
1102
+ calleeContextBuffer =
1103
+ emitAllocAsyncContext (subIGF, dynamicContextSize);
1104
+ context = layout.emitCastTo (subIGF, calleeContextBuffer.getAddress ());
1105
+ auto calleeContext =
1106
+ layout.emitCastTo (subIGF, calleeContextBuffer.getAddress ());
1107
+ args.add (subIGF.Builder .CreateBitOrPointerCast (
1108
+ calleeContextBuffer.getAddress (), IGM.SwiftContextPtrTy ));
1109
+
1110
+ // Set caller info into the context.
1111
+ { // caller context
1112
+ Explosion explosion;
1113
+ auto fieldLayout = layout.getParentLayout ();
1114
+ auto *context = subIGF.getAsyncContext ();
1115
+ if (auto schema =
1116
+ subIGF.IGM .getOptions ().PointerAuth .AsyncContextParent ) {
1117
+ Address fieldAddr =
1118
+ fieldLayout.project (subIGF, calleeContext, /* offsets*/ llvm::None);
1119
+ auto authInfo = PointerAuthInfo::emit (
1120
+ subIGF, schema, fieldAddr.getAddress (), PointerAuthEntity ());
1121
+ context = emitPointerAuthSign (subIGF, context, authInfo);
1122
+ }
1123
+ explosion.add (context);
1124
+ saveValue (fieldLayout, explosion);
1125
+ }
1126
+ { // Return to caller function.
1127
+ auto fieldLayout = layout.getResumeParentLayout ();
1128
+ currentResumeFn = subIGF.Builder .CreateIntrinsicCall (
1129
+ llvm::Intrinsic::coro_async_resume, {});
1130
+ auto fnVal = currentResumeFn;
1131
+ // Sign the pointer.
1132
+ if (auto schema = subIGF.IGM .getOptions ().PointerAuth .AsyncContextResume ) {
1133
+ Address fieldAddr =
1134
+ fieldLayout.project (subIGF, calleeContext, /* offsets*/ llvm::None);
1135
+ auto authInfo = PointerAuthInfo::emit (
1136
+ subIGF, schema, fieldAddr.getAddress (), PointerAuthEntity ());
1137
+ fnVal = emitPointerAuthSign (subIGF, fnVal, authInfo);
1138
+ }
1139
+ fnVal = subIGF.Builder .CreateBitCast (
1140
+ fnVal, subIGF.IGM .TaskContinuationFunctionPtrTy );
1141
+ Explosion explosion;
1142
+ explosion.add (fnVal);
1143
+ saveValue (fieldLayout, explosion);
1144
+ }
1105
1145
}
1106
1146
void gatherArgumentsFromApply () override {
1107
1147
super::gatherArgumentsFromApply (true );
@@ -1127,13 +1167,87 @@ class AsyncPartialApplicationForwarderEmission
1127
1167
// Nothing to do here. The error result pointer is already in the
1128
1168
// appropriate position.
1129
1169
}
1170
+ FunctionPointer getFunctionPointerForDispatchCall (const FunctionPointer &fn) {
1171
+ auto &IGM = subIGF.IGM ;
1172
+ // Strip off the return type. The original function pointer signature
1173
+ // captured both the entry point type and the resume function type.
1174
+ auto *fnTy = llvm::FunctionType::get (
1175
+ IGM.VoidTy , fn.getSignature ().getType ()->params (), false /* vaargs*/ );
1176
+ auto signature =
1177
+ Signature (fnTy, fn.getSignature ().getAttributes (), IGM.SwiftAsyncCC );
1178
+ auto fnPtr =
1179
+ FunctionPointer (FunctionPointer::Kind::Function, fn.getRawPointer (),
1180
+ fn.getAuthInfo (), signature);
1181
+ return fnPtr;
1182
+ }
1130
1183
llvm::CallInst *createCall (FunctionPointer &fnPtr) override {
1131
- return subIGF.Builder .CreateCall (fnPtr.getAsFunction (subIGF),
1132
- args.claimAll ());
1184
+ auto newFnPtr = FunctionPointer (
1185
+ FunctionPointer::Kind::Function, fnPtr.getPointer (subIGF),
1186
+ fnPtr.getAuthInfo (), Signature::forAsyncAwait (subIGF.IGM , origType));
1187
+ auto &Builder = subIGF.Builder ;
1188
+
1189
+ auto argValues = args.claimAll ();
1190
+
1191
+ // Setup the suspend point.
1192
+ SmallVector<llvm::Value *, 8 > arguments;
1193
+ auto signature = newFnPtr.getSignature ();
1194
+ auto asyncContextIndex = signature.getAsyncContextIndex ();
1195
+ auto paramAttributeFlags =
1196
+ asyncContextIndex |
1197
+ (signature.getAsyncResumeFunctionSwiftSelfIndex () << 8 );
1198
+ // Index of swiftasync context | ((index of swiftself) << 8).
1199
+ arguments.push_back (
1200
+ IGM.getInt32 (paramAttributeFlags));
1201
+ arguments.push_back (currentResumeFn);
1202
+ auto resumeProjFn = subIGF.getOrCreateResumePrjFn ();
1203
+ arguments.push_back (
1204
+ Builder.CreateBitOrPointerCast (resumeProjFn, IGM.Int8PtrTy ));
1205
+ auto dispatchFn = subIGF.createAsyncDispatchFn (
1206
+ getFunctionPointerForDispatchCall (newFnPtr), argValues);
1207
+ arguments.push_back (
1208
+ Builder.CreateBitOrPointerCast (dispatchFn, IGM.Int8PtrTy ));
1209
+ arguments.push_back (
1210
+ Builder.CreateBitOrPointerCast (newFnPtr.getRawPointer (), IGM.Int8PtrTy ));
1211
+ if (auto authInfo = newFnPtr.getAuthInfo ()) {
1212
+ arguments.push_back (newFnPtr.getAuthInfo ().getDiscriminator ());
1213
+ }
1214
+ for (auto arg : argValues)
1215
+ arguments.push_back (arg);
1216
+ auto resultTy =
1217
+ cast<llvm::StructType>(signature.getType ()->getReturnType ());
1218
+ return subIGF.emitSuspendAsyncCall (asyncContextIndex, resultTy, arguments);
1133
1219
}
1134
1220
void createReturn (llvm::CallInst *call) override {
1135
- call->setTailCallKind (IGM.AsyncTailCallKind );
1136
- subIGF.Builder .CreateRetVoid ();
1221
+ emitDeallocAsyncContext (subIGF, calleeContextBuffer);
1222
+ auto numAsyncContextParams =
1223
+ Signature::forAsyncReturn (IGM, outType).getAsyncContextIndex () + 1 ;
1224
+ llvm::Value *result = call;
1225
+ auto *suspendResultTy = cast<llvm::StructType>(result->getType ());
1226
+ Explosion resultExplosion;
1227
+ Explosion errorExplosion;
1228
+ SILFunctionConventions conv (outType, subIGF.getSILModule ());
1229
+ auto hasError = outType->hasErrorResult ();
1230
+
1231
+ Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
1232
+ SmallVector<llvm::Value *, 16 > nativeResultsStorage;
1233
+
1234
+ if (suspendResultTy->getNumElements () == numAsyncContextParams) {
1235
+ // no result to forward.
1236
+ assert (!hasError);
1237
+ } else {
1238
+ auto &Builder = subIGF.Builder ;
1239
+ auto resultTys =
1240
+ makeArrayRef (suspendResultTy->element_begin () + numAsyncContextParams,
1241
+ suspendResultTy->element_end ());
1242
+
1243
+ for (unsigned i = 0 , e = resultTys.size (); i != e; ++i) {
1244
+ llvm::Value *elt =
1245
+ Builder.CreateExtractValue (result, numAsyncContextParams + i);
1246
+ nativeResultsStorage.push_back (elt);
1247
+ }
1248
+ nativeResults = nativeResultsStorage;
1249
+ }
1250
+ emitAsyncReturn (subIGF, layout, origType, nativeResults);
1137
1251
}
1138
1252
void end () override {
1139
1253
assert (context.isValid ());
@@ -1180,6 +1294,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
1180
1294
llvm::AttributeList outAttrs = outSig.getAttributes ();
1181
1295
llvm::FunctionType *fwdTy = outSig.getType ();
1182
1296
SILFunctionConventions outConv (outType, IGM.getSILModule ());
1297
+ Optional<AsyncContextLayout> asyncLayout;
1183
1298
1184
1299
StringRef FnName;
1185
1300
if (staticFnPtr)
@@ -1203,19 +1318,29 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
1203
1318
1204
1319
IRGenFunction subIGF (IGM, fwd);
1205
1320
if (origType->isAsync ()) {
1206
- subIGF.setupAsync (
1207
- Signature::forAsyncEntry (IGM, outType).getAsyncContextIndex ());
1321
+ auto asyncContextIdx =
1322
+ Signature::forAsyncEntry (IGM, outType).getAsyncContextIndex ();
1323
+ asyncLayout.emplace (irgen::getAsyncContextLayout (
1324
+ IGM, origType, substType, subs, /* suppress generics*/ false ,
1325
+ FunctionPointer::Kind (
1326
+ FunctionPointer::BasicKind::AsyncFunctionPointer)));
1327
+
1328
+ subIGF.setupAsync (asyncContextIdx);
1208
1329
1209
- auto *calleeAFP = staticFnPtr->getDirectPointer ();
1330
+ // auto *calleeAFP = staticFnPtr->getDirectPointer();
1210
1331
LinkEntity entity = LinkEntity::forPartialApplyForwarder (fwd);
1211
- auto size = Size (0 );
1212
1332
assert (!asyncFunctionPtr &&
1213
1333
" already had an async function pointer to the forwarder?!" );
1214
- asyncFunctionPtr = emitAsyncFunctionPointer (IGM, fwd, entity, size);
1334
+ emitAsyncFunctionEntry (subIGF, *asyncLayout, entity, asyncContextIdx);
1335
+ asyncFunctionPtr =
1336
+ emitAsyncFunctionPointer (IGM, fwd, entity, asyncLayout->getSize ());
1337
+ // TODO: if calleeAFP is definition:
1338
+ #if 0
1215
1339
subIGF.Builder.CreateIntrinsicCall(
1216
1340
llvm::Intrinsic::coro_async_size_replace,
1217
1341
{subIGF.Builder.CreateBitCast(asyncFunctionPtr, IGM.Int8PtrTy),
1218
1342
subIGF.Builder.CreateBitCast(calleeAFP, IGM.Int8PtrTy)});
1343
+ #endif
1219
1344
}
1220
1345
if (IGM.DebugInfo )
1221
1346
IGM.DebugInfo ->emitArtificialFunction (subIGF, fwd);
@@ -1679,7 +1804,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
1679
1804
1680
1805
llvm::CallInst *call = emission->createCall (fnPtr);
1681
1806
1682
- if (addressesToDeallocate.empty () && !needsAllocas &&
1807
+ if (!origType-> isAsync () && addressesToDeallocate.empty () && !needsAllocas &&
1683
1808
(!consumesContext || !dependsOnContextLifetime))
1684
1809
call->setTailCall ();
1685
1810
0 commit comments