Skip to content

Commit 3de7f6f

Browse files
Merge pull request swiftlang#36700 from aschwaighofer/async_partial_apply_forwarder_fix
IRGen: Create full async functions for partial apply forwarders
2 parents 435f0de + 7ae0b1d commit 3de7f6f

File tree

2 files changed

+160
-26
lines changed

2 files changed

+160
-26
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 151 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,9 +1039,11 @@ class AsyncPartialApplicationForwarderEmission
10391039
: public PartialApplicationForwarderEmission {
10401040
using super = PartialApplicationForwarderEmission;
10411041
AsyncContextLayout layout;
1042-
llvm::Value *contextBuffer;
1042+
llvm::Value *calleeFunction;
1043+
llvm::Value *currentResumeFn;
10431044
Size contextSize;
10441045
Address context;
1046+
Address calleeContextBuffer;
10451047
unsigned currentArgumentIndex;
10461048
struct DynamicFunction {
10471049
using Kind = DynamicFunctionKind;
@@ -1060,23 +1062,11 @@ class AsyncPartialApplicationForwarderEmission
10601062
};
10611063
Optional<Self> self = llvm::None;
10621064

1063-
llvm::Value *loadValue(ElementLayout layout) {
1064-
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
1065-
auto &ti = cast<LoadableTypeInfo>(layout.getType());
1066-
Explosion explosion;
1067-
ti.loadAsTake(subIGF, addr, explosion);
1068-
return explosion.claimNext();
1069-
}
10701065
void saveValue(ElementLayout layout, Explosion &explosion) {
10711066
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
10721067
auto &ti = cast<LoadableTypeInfo>(layout.getType());
10731068
ti.initialize(subIGF, explosion, addr, /*isOutlined*/ false);
10741069
}
1075-
void loadValue(ElementLayout layout, Explosion &explosion) {
1076-
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
1077-
auto &ti = cast<LoadableTypeInfo>(layout.getType());
1078-
ti.loadAsTake(subIGF, addr, explosion);
1079-
}
10801070

10811071
public:
10821072
AsyncPartialApplicationForwarderEmission(
@@ -1099,9 +1089,59 @@ class AsyncPartialApplicationForwarderEmission
10991089
void begin() override { super::begin(); }
11001090

11011091
void mapAsyncParameters() override {
1102-
contextBuffer = origParams.claimNext();
1103-
context = layout.emitCastTo(subIGF, contextBuffer);
1104-
args.add(contextBuffer);
1092+
// Ignore the original context.
1093+
(void)origParams.claimNext();
1094+
1095+
llvm::Value *dynamicContextSize32;
1096+
auto initialContextSize = Size(0);
1097+
std::tie(calleeFunction, dynamicContextSize32) = getAsyncFunctionAndSize(
1098+
subIGF, origType->getRepresentation(), *staticFnPtr,
1099+
nullptr, std::make_pair(true, true), initialContextSize);
1100+
auto *dynamicContextSize =
1101+
subIGF.Builder.CreateZExt(dynamicContextSize32, subIGF.IGM.SizeTy);
1102+
calleeContextBuffer =
1103+
emitAllocAsyncContext(subIGF, dynamicContextSize);
1104+
context = layout.emitCastTo(subIGF, calleeContextBuffer.getAddress());
1105+
auto calleeContext =
1106+
layout.emitCastTo(subIGF, calleeContextBuffer.getAddress());
1107+
args.add(subIGF.Builder.CreateBitOrPointerCast(
1108+
calleeContextBuffer.getAddress(), IGM.SwiftContextPtrTy));
1109+
1110+
// Set caller info into the context.
1111+
{ // caller context
1112+
Explosion explosion;
1113+
auto fieldLayout = layout.getParentLayout();
1114+
auto *context = subIGF.getAsyncContext();
1115+
if (auto schema =
1116+
subIGF.IGM.getOptions().PointerAuth.AsyncContextParent) {
1117+
Address fieldAddr =
1118+
fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None);
1119+
auto authInfo = PointerAuthInfo::emit(
1120+
subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity());
1121+
context = emitPointerAuthSign(subIGF, context, authInfo);
1122+
}
1123+
explosion.add(context);
1124+
saveValue(fieldLayout, explosion);
1125+
}
1126+
{ // Return to caller function.
1127+
auto fieldLayout = layout.getResumeParentLayout();
1128+
currentResumeFn = subIGF.Builder.CreateIntrinsicCall(
1129+
llvm::Intrinsic::coro_async_resume, {});
1130+
auto fnVal = currentResumeFn;
1131+
// Sign the pointer.
1132+
if (auto schema = subIGF.IGM.getOptions().PointerAuth.AsyncContextResume) {
1133+
Address fieldAddr =
1134+
fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None);
1135+
auto authInfo = PointerAuthInfo::emit(
1136+
subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity());
1137+
fnVal = emitPointerAuthSign(subIGF, fnVal, authInfo);
1138+
}
1139+
fnVal = subIGF.Builder.CreateBitCast(
1140+
fnVal, subIGF.IGM.TaskContinuationFunctionPtrTy);
1141+
Explosion explosion;
1142+
explosion.add(fnVal);
1143+
saveValue(fieldLayout, explosion);
1144+
}
11051145
}
11061146
void gatherArgumentsFromApply() override {
11071147
super::gatherArgumentsFromApply(true);
@@ -1127,13 +1167,87 @@ class AsyncPartialApplicationForwarderEmission
11271167
// Nothing to do here. The error result pointer is already in the
11281168
// appropriate position.
11291169
}
1170+
FunctionPointer getFunctionPointerForDispatchCall(const FunctionPointer &fn) {
1171+
auto &IGM = subIGF.IGM;
1172+
// Strip off the return type. The original function pointer signature
1173+
// captured both the entry point type and the resume function type.
1174+
auto *fnTy = llvm::FunctionType::get(
1175+
IGM.VoidTy, fn.getSignature().getType()->params(), false /*vaargs*/);
1176+
auto signature =
1177+
Signature(fnTy, fn.getSignature().getAttributes(), IGM.SwiftAsyncCC);
1178+
auto fnPtr =
1179+
FunctionPointer(FunctionPointer::Kind::Function, fn.getRawPointer(),
1180+
fn.getAuthInfo(), signature);
1181+
return fnPtr;
1182+
}
11301183
llvm::CallInst *createCall(FunctionPointer &fnPtr) override {
1131-
return subIGF.Builder.CreateCall(fnPtr.getAsFunction(subIGF),
1132-
args.claimAll());
1184+
auto newFnPtr = FunctionPointer(
1185+
FunctionPointer::Kind::Function, fnPtr.getPointer(subIGF),
1186+
fnPtr.getAuthInfo(), Signature::forAsyncAwait(subIGF.IGM, origType));
1187+
auto &Builder = subIGF.Builder;
1188+
1189+
auto argValues = args.claimAll();
1190+
1191+
// Setup the suspend point.
1192+
SmallVector<llvm::Value *, 8> arguments;
1193+
auto signature = newFnPtr.getSignature();
1194+
auto asyncContextIndex = signature.getAsyncContextIndex();
1195+
auto paramAttributeFlags =
1196+
asyncContextIndex |
1197+
(signature.getAsyncResumeFunctionSwiftSelfIndex() << 8);
1198+
// Index of swiftasync context | ((index of swiftself) << 8).
1199+
arguments.push_back(
1200+
IGM.getInt32(paramAttributeFlags));
1201+
arguments.push_back(currentResumeFn);
1202+
auto resumeProjFn = subIGF.getOrCreateResumePrjFn();
1203+
arguments.push_back(
1204+
Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy));
1205+
auto dispatchFn = subIGF.createAsyncDispatchFn(
1206+
getFunctionPointerForDispatchCall(newFnPtr), argValues);
1207+
arguments.push_back(
1208+
Builder.CreateBitOrPointerCast(dispatchFn, IGM.Int8PtrTy));
1209+
arguments.push_back(
1210+
Builder.CreateBitOrPointerCast(newFnPtr.getRawPointer(), IGM.Int8PtrTy));
1211+
if (auto authInfo = newFnPtr.getAuthInfo()) {
1212+
arguments.push_back(newFnPtr.getAuthInfo().getDiscriminator());
1213+
}
1214+
for (auto arg : argValues)
1215+
arguments.push_back(arg);
1216+
auto resultTy =
1217+
cast<llvm::StructType>(signature.getType()->getReturnType());
1218+
return subIGF.emitSuspendAsyncCall(asyncContextIndex, resultTy, arguments);
11331219
}
11341220
void createReturn(llvm::CallInst *call) override {
1135-
call->setTailCallKind(IGM.AsyncTailCallKind);
1136-
subIGF.Builder.CreateRetVoid();
1221+
emitDeallocAsyncContext(subIGF, calleeContextBuffer);
1222+
auto numAsyncContextParams =
1223+
Signature::forAsyncReturn(IGM, outType).getAsyncContextIndex() + 1;
1224+
llvm::Value *result = call;
1225+
auto *suspendResultTy = cast<llvm::StructType>(result->getType());
1226+
Explosion resultExplosion;
1227+
Explosion errorExplosion;
1228+
SILFunctionConventions conv(outType, subIGF.getSILModule());
1229+
auto hasError = outType->hasErrorResult();
1230+
1231+
Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
1232+
SmallVector<llvm::Value *, 16> nativeResultsStorage;
1233+
1234+
if (suspendResultTy->getNumElements() == numAsyncContextParams) {
1235+
// no result to forward.
1236+
assert(!hasError);
1237+
} else {
1238+
auto &Builder = subIGF.Builder;
1239+
auto resultTys =
1240+
makeArrayRef(suspendResultTy->element_begin() + numAsyncContextParams,
1241+
suspendResultTy->element_end());
1242+
1243+
for (unsigned i = 0, e = resultTys.size(); i != e; ++i) {
1244+
llvm::Value *elt =
1245+
Builder.CreateExtractValue(result, numAsyncContextParams + i);
1246+
nativeResultsStorage.push_back(elt);
1247+
}
1248+
nativeResults = nativeResultsStorage;
1249+
}
1250+
emitAsyncReturn(subIGF, layout, origType, nativeResults);
11371251
}
11381252
void end() override {
11391253
assert(context.isValid());
@@ -1180,6 +1294,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
11801294
llvm::AttributeList outAttrs = outSig.getAttributes();
11811295
llvm::FunctionType *fwdTy = outSig.getType();
11821296
SILFunctionConventions outConv(outType, IGM.getSILModule());
1297+
Optional<AsyncContextLayout> asyncLayout;
11831298

11841299
StringRef FnName;
11851300
if (staticFnPtr)
@@ -1203,19 +1318,29 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
12031318

12041319
IRGenFunction subIGF(IGM, fwd);
12051320
if (origType->isAsync()) {
1206-
subIGF.setupAsync(
1207-
Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex());
1321+
auto asyncContextIdx =
1322+
Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex();
1323+
asyncLayout.emplace(irgen::getAsyncContextLayout(
1324+
IGM, origType, substType, subs, /*suppress generics*/ false,
1325+
FunctionPointer::Kind(
1326+
FunctionPointer::BasicKind::AsyncFunctionPointer)));
1327+
1328+
subIGF.setupAsync(asyncContextIdx);
12081329

1209-
auto *calleeAFP = staticFnPtr->getDirectPointer();
1330+
//auto *calleeAFP = staticFnPtr->getDirectPointer();
12101331
LinkEntity entity = LinkEntity::forPartialApplyForwarder(fwd);
1211-
auto size = Size(0);
12121332
assert(!asyncFunctionPtr &&
12131333
"already had an async function pointer to the forwarder?!");
1214-
asyncFunctionPtr = emitAsyncFunctionPointer(IGM, fwd, entity, size);
1334+
emitAsyncFunctionEntry(subIGF, *asyncLayout, entity, asyncContextIdx);
1335+
asyncFunctionPtr =
1336+
emitAsyncFunctionPointer(IGM, fwd, entity, asyncLayout->getSize());
1337+
// TODO: if calleeAFP is definition:
1338+
#if 0
12151339
subIGF.Builder.CreateIntrinsicCall(
12161340
llvm::Intrinsic::coro_async_size_replace,
12171341
{subIGF.Builder.CreateBitCast(asyncFunctionPtr, IGM.Int8PtrTy),
12181342
subIGF.Builder.CreateBitCast(calleeAFP, IGM.Int8PtrTy)});
1343+
#endif
12191344
}
12201345
if (IGM.DebugInfo)
12211346
IGM.DebugInfo->emitArtificialFunction(subIGF, fwd);
@@ -1679,7 +1804,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
16791804

16801805
llvm::CallInst *call = emission->createCall(fnPtr);
16811806

1682-
if (addressesToDeallocate.empty() && !needsAllocas &&
1807+
if (!origType->isAsync() && addressesToDeallocate.empty() && !needsAllocas &&
16831808
(!consumesContext || !dependsOnContextLifetime))
16841809
call->setTailCall();
16851810

test/IRGen/async/partial_apply.sil

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,12 @@ bb0(%thick : $@callee_guaranteed @async @convention(thick) (Int64, Int32) -> Int
506506
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s45indirect_guaranteed_captured_class_pair_paramTA.{{[0-9]+}}"(
507507
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw_"(
508508
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw0_"(
509+
510+
sil @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error)
511+
512+
sil @dont_crash : $@convention(thin) @async (Int) -> @owned @async @callee_guaranteed (Int) -> (Int, @error Error) {
513+
bb0(%0 : $Int):
514+
%2 = function_ref @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error)
515+
%3 = partial_apply [callee_guaranteed] %2(%0) : $@convention(thin) @async (Int, Int) -> (Int, @error Error)
516+
return %3 : $@async @callee_guaranteed (Int) -> (Int, @error Error)
517+
}

0 commit comments

Comments
 (0)