Skip to content

Commit 11723cc

Browse files
committed
[Async CC] Never map to native explosions.
Previously, when lowering the entry point of an async function, the parameters were lowered to explosions that matched those of sync functions, namely native explosions. That is incorrect for async functions where the structured values are within the async context. Here, that error is fixed, by adding a new customization point to NativeCCEntryPointArgumentEmission which behaves as before for sync functions but which simply extracts an argument from the async context for async functions. rdar://problem/71260972
1 parent 3a0ffbb commit 11723cc

File tree

4 files changed

+165
-35
lines changed

4 files changed

+165
-35
lines changed

lib/IRGen/EntryPointArgumentEmission.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ class Value;
1717
}
1818

1919
namespace swift {
20+
21+
class SILArgument;
22+
2023
namespace irgen {
2124

2225
class Explosion;
2326
struct GenericRequirement;
27+
class LoadableTypeInfo;
28+
class TypeInfo;
2429

2530
class EntryPointArgumentEmission {
2631

@@ -44,6 +49,12 @@ class NativeCCEntryPointArgumentEmission
4449
virtual llvm::Value *getSelfWitnessTable() = 0;
4550
virtual llvm::Value *getSelfMetadata() = 0;
4651
virtual llvm::Value *getCoroutineBuffer() = 0;
52+
virtual Explosion
53+
explosionForObject(IRGenFunction &IGF, unsigned index, SILArgument *param,
54+
SILType paramTy, const LoadableTypeInfo &loadableParamTI,
55+
const LoadableTypeInfo &loadableArgTI,
56+
std::function<Explosion(unsigned index, unsigned size)>
57+
explosionForArgument) = 0;
4758
};
4859

4960
} // end namespace irgen

lib/IRGen/IRGenSIL.cpp

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,42 @@ class SyncNativeCCEntryPointArgumentEmission final
12621262
llvm::Value *getCoroutineBuffer() override {
12631263
return allParamValues.claimNext();
12641264
}
1265+
Explosion
1266+
explosionForObject(IRGenFunction &IGF, unsigned index, SILArgument *param,
1267+
SILType paramTy, const LoadableTypeInfo &loadableParamTI,
1268+
const LoadableTypeInfo &loadableArgTI,
1269+
std::function<Explosion(unsigned index, unsigned size)>
1270+
explosionForArgument) override {
1271+
Explosion paramValues;
1272+
// If the explosion must be passed indirectly, load the value from the
1273+
// indirect address.
1274+
auto &nativeSchema = loadableArgTI.nativeParameterValueSchema(IGF.IGM);
1275+
if (nativeSchema.requiresIndirect()) {
1276+
Explosion paramExplosion = explosionForArgument(index, 1);
1277+
Address paramAddr =
1278+
loadableParamTI.getAddressForPointer(paramExplosion.claimNext());
1279+
if (loadableParamTI.getStorageType() != loadableArgTI.getStorageType())
1280+
paramAddr =
1281+
loadableArgTI.getAddressForPointer(IGF.Builder.CreateBitCast(
1282+
paramAddr.getAddress(),
1283+
loadableArgTI.getStorageType()->getPointerTo()));
1284+
loadableArgTI.loadAsTake(IGF, paramAddr, paramValues);
1285+
} else {
1286+
if (!nativeSchema.empty()) {
1287+
// Otherwise, we map from the native convention to the type's explosion
1288+
// schema.
1289+
Explosion nativeParam;
1290+
unsigned size = nativeSchema.size();
1291+
Explosion paramExplosion = explosionForArgument(index, size);
1292+
paramExplosion.transferInto(nativeParam, size);
1293+
paramValues = nativeSchema.mapFromNative(IGF.IGM, IGF, nativeParam,
1294+
param->getType());
1295+
} else {
1296+
assert(loadableParamTI.getSchema().empty());
1297+
}
1298+
}
1299+
return paramValues;
1300+
};
12651301

12661302
public:
12671303
using SyncEntryPointArgumentEmission::requiresIndirectResult;
@@ -1312,10 +1348,8 @@ class AsyncNativeCCEntryPointArgumentEmission final
13121348
return loadValue(contextLayout);
13131349
}
13141350
Explosion getArgumentExplosion(unsigned index, unsigned size) override {
1315-
assert(size > 0);
13161351
auto argumentLayout = layout.getArgumentLayout(index);
13171352
auto result = loadExplosion(argumentLayout);
1318-
assert(result.size() == size);
13191353
return result;
13201354
}
13211355
bool requiresIndirectResult(SILType retType) override { return false; }
@@ -1380,6 +1414,14 @@ class AsyncNativeCCEntryPointArgumentEmission final
13801414
llvm_unreachable(
13811415
"async functions do not use a fixed size coroutine buffer");
13821416
}
1417+
Explosion
1418+
explosionForObject(IRGenFunction &IGF, unsigned index, SILArgument *param,
1419+
SILType paramTy, const LoadableTypeInfo &loadableParamTI,
1420+
const LoadableTypeInfo &loadableArgTI,
1421+
std::function<Explosion(unsigned index, unsigned size)>
1422+
explosionForArgument) override {
1423+
return explosionForArgument(index, 1);
1424+
};
13831425
};
13841426

13851427
std::unique_ptr<NativeCCEntryPointArgumentEmission>
@@ -1654,8 +1696,9 @@ static ArrayRef<SILArgument *> emitEntryPointIndirectReturn(
16541696
}
16551697

16561698
template <typename ExplosionForArgument>
1657-
static void bindParameter(IRGenSILFunction &IGF, unsigned index,
1658-
SILArgument *param, SILType paramTy,
1699+
static void bindParameter(IRGenSILFunction &IGF,
1700+
NativeCCEntryPointArgumentEmission &emission,
1701+
unsigned index, SILArgument *param, SILType paramTy,
16591702
ExplosionForArgument explosionForArgument) {
16601703
// Pull out the parameter value and its formal type.
16611704
auto &paramTI = IGF.getTypeInfo(IGF.CurSILFn->mapTypeIntoContext(paramTy));
@@ -1664,36 +1707,11 @@ static void bindParameter(IRGenSILFunction &IGF, unsigned index,
16641707
// If the SIL parameter isn't passed indirectly, we need to map it
16651708
// to an explosion.
16661709
if (param->getType().isObject()) {
1667-
Explosion paramValues;
16681710
auto &loadableParamTI = cast<LoadableTypeInfo>(paramTI);
1669-
auto &loadableArgTI = cast<LoadableTypeInfo>(paramTI);
1670-
// If the explosion must be passed indirectly, load the value from the
1671-
// indirect address.
1672-
auto &nativeSchema = argTI.nativeParameterValueSchema(IGF.IGM);
1673-
if (nativeSchema.requiresIndirect()) {
1674-
Explosion paramExplosion = explosionForArgument(index, 1);
1675-
Address paramAddr =
1676-
loadableParamTI.getAddressForPointer(paramExplosion.claimNext());
1677-
if (paramTI.getStorageType() != argTI.getStorageType())
1678-
paramAddr =
1679-
loadableArgTI.getAddressForPointer(IGF.Builder.CreateBitCast(
1680-
paramAddr.getAddress(),
1681-
loadableArgTI.getStorageType()->getPointerTo()));
1682-
loadableArgTI.loadAsTake(IGF, paramAddr, paramValues);
1683-
} else {
1684-
if (!nativeSchema.empty()) {
1685-
// Otherwise, we map from the native convention to the type's explosion
1686-
// schema.
1687-
Explosion nativeParam;
1688-
unsigned size = nativeSchema.size();
1689-
Explosion paramExplosion = explosionForArgument(index, size);
1690-
paramExplosion.transferInto(nativeParam, size);
1691-
paramValues = nativeSchema.mapFromNative(IGF.IGM, IGF, nativeParam,
1692-
param->getType());
1693-
} else {
1694-
assert(paramTI.getSchema().empty());
1695-
}
1696-
}
1711+
auto &loadableArgTI = cast<LoadableTypeInfo>(argTI);
1712+
auto paramValues =
1713+
emission.explosionForObject(IGF, index, param, paramTy, loadableParamTI,
1714+
loadableArgTI, explosionForArgument);
16971715
IGF.setLoweredExplosion(param, paramValues);
16981716
return;
16991717
}
@@ -1760,7 +1778,7 @@ static void emitEntryPointArgumentsNativeCC(IRGenSILFunction &IGF,
17601778
params = params.drop_back();
17611779

17621780
bindParameter(
1763-
IGF, 0, selfParam,
1781+
IGF, *emission, 0, selfParam,
17641782
conv.getSILArgumentType(conv.getNumSILArguments() - 1,
17651783
IGF.IGM.getMaximalTypeExpansionContext()),
17661784
[&](unsigned startIndex, unsigned size) {
@@ -1784,7 +1802,7 @@ static void emitEntryPointArgumentsNativeCC(IRGenSILFunction &IGF,
17841802
unsigned i = 0;
17851803
for (SILArgument *param : params) {
17861804
auto argIdx = conv.getSILArgIndexOfFirstParam() + i;
1787-
bindParameter(IGF, i, param,
1805+
bindParameter(IGF, *emission, i, param,
17881806
conv.getSILArgumentType(
17891807
argIdx, IGF.IGM.getMaximalTypeExpansionContext()),
17901808
[&](unsigned index, unsigned size) {
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift-dylib(%t/%target-library-name(PrintShims)) %S/../../Inputs/print-shims.swift -module-name PrintShims -emit-module -emit-module-path %t/PrintShims.swiftmodule
3+
// RUN: %target-codesign %t/%target-library-name(PrintShims)
4+
// RUN: %target-build-swift -Xfrontend -enable-experimental-concurrency -parse-sil %s -emit-ir -I %t -L %t -lPrintShim | %FileCheck %s --check-prefix=CHECK-LL
5+
// RUN: %target-build-swift -Xfrontend -enable-experimental-concurrency -parse-sil %s -module-name main -o %t/main -I %t -L %t -lPrintShims %target-rpath(%t)
6+
// RUN: %target-codesign %t/main
7+
// RUN: %target-run %t/main %t/%target-library-name(PrintShims) | %FileCheck %s
8+
9+
// REQUIRES: executable_test
10+
// REQUIRES: swift_test_mode_optimize_none
11+
// REQUIRES: concurrency
12+
// UNSUPPORTED: use_os_stdlib
13+
// UNSUPPORTED: CPU=arm64e
14+
15+
16+
import Builtin
17+
import Swift
18+
import PrintShims
19+
import _Concurrency
20+
21+
sil public_external @printGeneric : $@convention(thin) <T> (@in_guaranteed T) -> ()
22+
23+
struct Pack {
24+
public let a: Bool
25+
public let b: Bool
26+
public let c: Bool
27+
public let d: Bool
28+
public let e: Bool
29+
}
30+
31+
// CHECK-LL: @test_caseAD =
32+
33+
sil @structPackToVoid : $@async @convention(thin) (Pack) -> () {
34+
entry(%pack : $Pack):
35+
%pack_addr = alloc_stack $Pack
36+
store %pack to %pack_addr : $*Pack
37+
%printGeneric = function_ref @printGeneric : $@convention(thin) <T> (@in_guaranteed T) -> ()
38+
%printGeneric_result1 = apply %printGeneric<Pack>(%pack_addr) : $@convention(thin) <T> (@in_guaranteed T) -> () //CHECK: Pack(a: true, b: false, c: true, d: false, e: true)
39+
dealloc_stack %pack_addr : $*Pack
40+
41+
return %printGeneric_result1 : $()
42+
}
43+
44+
// CHECK-LL: define{{( dllexport)?}}{{( protected)?}} swiftcc void @test_case(%swift.task* {{%[0-9]+}}, %swift.executor* {{%[0-9]+}}, %swift.context* {{%[0-9]+}}) {{#[0-9]*}} {
45+
sil @test_case : $@async @convention(thin) () -> () {
46+
47+
%a_literal = integer_literal $Builtin.Int1, -1
48+
%a = struct $Bool (%a_literal : $Builtin.Int1)
49+
%b_literal = integer_literal $Builtin.Int1, 0
50+
%b = struct $Bool (%b_literal : $Builtin.Int1)
51+
%c_literal = integer_literal $Builtin.Int1, -1
52+
%c = struct $Bool (%c_literal : $Builtin.Int1)
53+
%d_literal = integer_literal $Builtin.Int1, 0
54+
%d = struct $Bool (%d_literal : $Builtin.Int1)
55+
%e_literal = integer_literal $Builtin.Int1, -1
56+
%e = struct $Bool (%a_literal : $Builtin.Int1)
57+
58+
%pack = struct $Pack (%a : $Bool, %b : $Bool, %c : $Bool, %d : $Bool, %e : $Bool)
59+
60+
%structPackToVoid = function_ref @structPackToVoid : $@async @convention(thin) (Pack) -> ()
61+
%result = apply %structPackToVoid(%pack) : $@async @convention(thin) (Pack) -> ()
62+
return %result : $()
63+
}
64+
65+
sil @main : $@async @convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 {
66+
bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>):
67+
68+
%test_case = function_ref @test_case : $@convention(thin) @async () -> ()
69+
%result = apply %test_case() : $@convention(thin) @async () -> ()
70+
71+
%2 = integer_literal $Builtin.Int32, 0
72+
%3 = struct $Int32 (%2 : $Builtin.Int32)
73+
return %3 : $Int32
74+
}
75+
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %target-swift-frontend %s -emit-ir -enable-experimental-concurrency -parse-sil
2+
3+
import Swift
4+
5+
struct Pack {
6+
public let a: Bool
7+
public let b: Bool
8+
public let c: Bool
9+
public let d: Bool
10+
public let e: Bool
11+
}
12+
13+
public struct Strukt {
14+
public let rawValue: Int
15+
}
16+
17+
class Clazz { }
18+
19+
sil_vtable Clazz {
20+
}
21+
22+
sil @$foo : $@convention(method) @async (@guaranteed Pack, @guaranteed Array<Strukt>, @guaranteed Clazz) -> () {
23+
bb0(%0 : $Pack, %1 : $Array<Strukt>, %2 : $Clazz):
24+
%result = tuple ()
25+
return %result : $()
26+
}

0 commit comments

Comments
 (0)