Skip to content

Commit 844ce09

Browse files
committed
IRGen: Refer to dispatch thunk async function pointers when making calls
Part of rdar://problem/73625623.
1 parent 50032a6 commit 844ce09

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

lib/IRGen/IRGenSIL.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6133,14 +6133,22 @@ void IRGenSILFunction::visitWitnessMethodInst(swift::WitnessMethodInst *i) {
61336133
CanType baseTy = i->getLookupType();
61346134
ProtocolConformanceRef conformance = i->getConformance();
61356135
SILDeclRef member = i->getMember();
6136+
auto fnType = IGM.getSILTypes().getConstantFunctionType(
6137+
IGM.getMaximalTypeExpansionContext(), member);
61366138

61376139
assert(member.requiresNewWitnessTableEntry());
61386140

61396141
if (IGM.isResilient(conformance.getRequirement(),
61406142
ResilienceExpansion::Maximal)) {
6141-
auto *fnPtr = IGM.getAddrOfDispatchThunk(member, NotForDefinition);
6142-
auto fnType = IGM.getSILTypes().getConstantFunctionType(
6143-
IGM.getMaximalTypeExpansionContext(), member);
6143+
llvm::Constant *fnPtr = IGM.getAddrOfDispatchThunk(member, NotForDefinition);
6144+
6145+
if (fnType->isAsync()) {
6146+
auto *fnPtrType = fnPtr->getType();
6147+
fnPtr = IGM.getAddrOfAsyncFunctionPointer(
6148+
LinkEntity::forDispatchThunk(member));
6149+
fnPtr = llvm::ConstantExpr::getBitCast(fnPtr, fnPtrType);
6150+
}
6151+
61446152
auto sig = IGM.getSignature(fnType);
61456153
auto fn = FunctionPointer::forDirect(fnType, fnPtr, sig);
61466154

@@ -6340,7 +6348,15 @@ void IRGenSILFunction::visitClassMethodInst(swift::ClassMethodInst *i) {
63406348
auto *classDecl = cast<ClassDecl>(method.getDecl()->getDeclContext());
63416349
if (IGM.hasResilientMetadata(classDecl,
63426350
ResilienceExpansion::Maximal)) {
6343-
auto *fnPtr = IGM.getAddrOfDispatchThunk(method, NotForDefinition);
6351+
llvm::Constant *fnPtr = IGM.getAddrOfDispatchThunk(method, NotForDefinition);
6352+
6353+
if (methodType->isAsync()) {
6354+
auto *fnPtrType = fnPtr->getType();
6355+
fnPtr = IGM.getAddrOfAsyncFunctionPointer(
6356+
LinkEntity::forDispatchThunk(method));
6357+
fnPtr = llvm::ConstantExpr::getBitCast(fnPtr, fnPtrType);
6358+
}
6359+
63446360
auto sig = IGM.getSignature(methodType);
63456361
FunctionPointer fn(methodType, fnPtr, sig);
63466362

test/IRGen/async/class_resilience.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,11 @@ open class MyBaseClass<T> {
2727

2828
// CHECK-LABEL: @"$s16class_resilience11MyBaseClassC4waitxyYFTjTu" = {{(dllexport )?}}{{(protected )?}}global %swift.async_func_pointer
2929

30+
// CHECK-LABEL: define {{(dllexport )?}}{{(protected )?}}swiftcc void @"$s16class_resilience14callsAwaitableyx010resilient_A09BaseClassCyxGYlF"(%swift.task* %0, %swift.executor* %1, %swift.context* swiftasync %2)
31+
// CHECK: %swift.async_func_pointer* @"$s15resilient_class9BaseClassC4waitxyYFTjTu"
32+
// CHECK: ret void
33+
public func callsAwaitable<T>(_ c: BaseClass<T>) async -> T {
34+
return await c.wait()
35+
}
36+
3037
// CHECK-LABEL: define {{(dllexport )?}}{{(protected )?}}swiftcc void @"$s16class_resilience11MyBaseClassC4waitxyYFTj"(%swift.task* %0, %swift.executor* %1, %swift.context* swiftasync %2) #0 {

test/IRGen/async/protocol_resilience.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,11 @@ public protocol MyAwaitable {
1818

1919
// CHECK-LABEL: @"$s19protocol_resilience11MyAwaitableP4wait6ResultQzyYFTjTu" = {{(dllexport )?}}{{(protected )?}}global %swift.async_func_pointer
2020

21+
// CHECK-LABEL: define {{(dllexport )?}}{{(protected )?}}swiftcc void @"$s19protocol_resilience14callsAwaitabley6ResultQzxY010resilient_A00D0RzlF"(%swift.task* %0, %swift.executor* %1, %swift.context* swiftasync %2)
22+
// CHECK: %swift.async_func_pointer* @"$s18resilient_protocol9AwaitableP4wait6ResultQzyYFTjTu"
23+
// CHECK: ret void
24+
public func callsAwaitable<T : Awaitable>(_ t: T) async -> T.Result {
25+
return await t.wait()
26+
}
27+
2128
// CHECK-LABEL: define {{(dllexport )?}}{{(protected )?}}swiftcc void @"$s19protocol_resilience11MyAwaitableP4wait6ResultQzyYFTj"(%swift.task* %0, %swift.executor* %1, %swift.context* swiftasync %2)

0 commit comments

Comments
 (0)