Skip to content

Commit 147274b

Browse files
authored
Merge pull request swiftlang#33399 from DougGregor/concurrency-async-call-restrictions
[Concurrency] Implement restrictions on calls to 'async' functions.
2 parents 4391f42 + 40c12cc commit 147274b

File tree

3 files changed

+153
-22
lines changed

3 files changed

+153
-22
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4038,8 +4038,24 @@ NOTE(note_disable_error_propagation,none,
40384038
"did you mean to disable error propagation?", ())
40394039
ERROR(async_call_without_await,none,
40404040
"call is 'async' but is not marked with 'await'", ())
4041+
ERROR(async_call_without_await_in_autoclosure,none,
4042+
"call is 'async' in an autoclosure argument is not marked with 'await'", ())
40414043
WARNING(no_async_in_await,none,
40424044
"no calls to 'async' functions occur within 'await' expression", ())
4045+
ERROR(async_call_in_illegal_context,none,
4046+
"'async' call cannot occur in "
4047+
"%select{<<ERROR>>|a default argument|a property initializer|a global variable initializer|an enum case raw value|a catch pattern|a catch guard expression|a defer body}0",
4048+
(unsigned))
4049+
ERROR(await_in_illegal_context,none,
4050+
"'await' operation cannot occur in "
4051+
"%select{<<ERROR>>|a default argument|a property initializer|a global variable initializer|an enum case raw value|a catch pattern|a catch guard expression|a defer body}0",
4052+
(unsigned))
4053+
ERROR(async_in_nonasync_function,none,
4054+
"%select{'async'|'await'}0 in %select{a function|an autoclosure}1 that "
4055+
"does not support concurrency",
4056+
(bool, bool))
4057+
NOTE(note_add_async_to_function,none,
4058+
"add 'async' to function %0 to make it asynchronous", (DeclName))
40434059

40444060
WARNING(no_throw_in_try,none,
40454061
"no calls to throwing functions occur within 'try' expression", ())

lib/Sema/TypeCheckEffects.cpp

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ class Context {
857857
Kind TheKind;
858858
Optional<AnyFunctionRef> Function;
859859
bool HandlesErrors = false;
860+
bool HandlesAsync = false;
860861

861862
/// Whether error-handling queries should ignore the function context, e.g.,
862863
/// for autoclosure and rethrows checks.
@@ -870,9 +871,10 @@ class Context {
870871
assert(TheKind != Kind::PotentiallyHandled);
871872
}
872873

873-
explicit Context(bool handlesErrors, Optional<AnyFunctionRef> function)
874+
explicit Context(bool handlesErrors, bool handlesAsync,
875+
Optional<AnyFunctionRef> function)
874876
: TheKind(Kind::PotentiallyHandled), Function(function),
875-
HandlesErrors(handlesErrors) { }
877+
HandlesErrors(handlesErrors), HandlesAsync(handlesAsync) { }
876878

877879
public:
878880
/// Whether this is a function that rethrows.
@@ -910,7 +912,7 @@ class Context {
910912

911913
static Context forTopLevelCode(TopLevelCodeDecl *D) {
912914
// Top-level code implicitly handles errors and 'async' calls.
913-
return Context(/*handlesErrors=*/true, None);
915+
return Context(/*handlesErrors=*/true, /*handlesAsync=*/true, None);
914916
}
915917

916918
static Context forFunction(AbstractFunctionDecl *D) {
@@ -930,8 +932,7 @@ class Context {
930932
}
931933
}
932934

933-
bool handlesErrors = D->hasThrows();
934-
return Context(handlesErrors, AnyFunctionRef(D));
935+
return Context(D->hasThrows(), D->hasAsync(), AnyFunctionRef(D));
935936
}
936937

937938
static Context forDeferBody() {
@@ -956,12 +957,15 @@ class Context {
956957
static Context forClosure(AbstractClosureExpr *E) {
957958
// Determine whether the closure has throwing function type.
958959
bool closureTypeThrows = true;
960+
bool closureTypeIsAsync = true;
959961
if (auto closureType = E->getType()) {
960-
if (auto fnType = closureType->getAs<AnyFunctionType>())
962+
if (auto fnType = closureType->getAs<AnyFunctionType>()) {
961963
closureTypeThrows = fnType->isThrowing();
964+
closureTypeIsAsync = fnType->isAsync();
965+
}
962966
}
963967

964-
return Context(closureTypeThrows, AnyFunctionRef(E));
968+
return Context(closureTypeThrows, closureTypeIsAsync, AnyFunctionRef(E));
965969
}
966970

967971
static Context forCatchPattern(CaseStmt *S) {
@@ -1013,6 +1017,10 @@ class Context {
10131017
llvm_unreachable("bad error kind");
10141018
}
10151019

1020+
bool handlesAsync() const {
1021+
return HandlesAsync;
1022+
}
1023+
10161024
DeclContext *getRethrowsDC() const {
10171025
if (!isRethrows())
10181026
return nullptr;
@@ -1182,7 +1190,6 @@ class Context {
11821190
case Kind::DeferBody:
11831191
diagnoseThrowInIllegalContext(Diags, E, getKind());
11841192
return;
1185-
11861193
}
11871194
llvm_unreachable("bad context kind");
11881195
}
@@ -1211,6 +1218,64 @@ class Context {
12111218
}
12121219
llvm_unreachable("bad context kind");
12131220
}
1221+
1222+
void diagnoseUncoveredAsyncSite(ASTContext &ctx, ASTNode node) {
1223+
SourceRange highlight;
1224+
1225+
// Generate more specific messages in some cases.
1226+
if (auto apply = dyn_cast_or_null<ApplyExpr>(node.dyn_cast<Expr*>()))
1227+
highlight = apply->getSourceRange();
1228+
1229+
auto diag = diag::async_call_without_await;
1230+
if (isAutoClosure())
1231+
diag = diag::async_call_without_await_in_autoclosure;
1232+
ctx.Diags.diagnose(node.getStartLoc(), diag)
1233+
.highlight(highlight);
1234+
}
1235+
1236+
void diagnoseAsyncInIllegalContext(DiagnosticEngine &Diags, ASTNode node) {
1237+
if (auto *e = node.dyn_cast<Expr*>()) {
1238+
if (isa<ApplyExpr>(e)) {
1239+
Diags.diagnose(e->getLoc(), diag::async_call_in_illegal_context,
1240+
static_cast<unsigned>(getKind()));
1241+
return;
1242+
}
1243+
}
1244+
1245+
Diags.diagnose(node.getStartLoc(), diag::await_in_illegal_context,
1246+
static_cast<unsigned>(getKind()));
1247+
}
1248+
1249+
void maybeAddAsyncNote(DiagnosticEngine &Diags) {
1250+
if (!Function)
1251+
return;
1252+
1253+
auto func = dyn_cast_or_null<FuncDecl>(Function->getAbstractFunctionDecl());
1254+
if (!func)
1255+
return;
1256+
1257+
func->diagnose(diag::note_add_async_to_function, func->getName());
1258+
}
1259+
1260+
void diagnoseUnhandledAsyncSite(DiagnosticEngine &Diags, ASTNode node) {
1261+
switch (getKind()) {
1262+
case Kind::PotentiallyHandled:
1263+
Diags.diagnose(node.getStartLoc(), diag::async_in_nonasync_function,
1264+
node.isExpr(ExprKind::Await), isAutoClosure());
1265+
maybeAddAsyncNote(Diags);
1266+
return;
1267+
1268+
case Kind::EnumElementInitializer:
1269+
case Kind::GlobalVarInitializer:
1270+
case Kind::IVarInitializer:
1271+
case Kind::DefaultArgument:
1272+
case Kind::CatchPattern:
1273+
case Kind::CatchGuard:
1274+
case Kind::DeferBody:
1275+
diagnoseAsyncInIllegalContext(Diags, node);
1276+
return;
1277+
}
1278+
}
12141279
};
12151280

12161281
/// A class to walk over a local context and validate the correctness
@@ -1322,6 +1387,12 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
13221387
Self.MaxThrowingKind = ThrowingKind::None;
13231388
}
13241389

1390+
void resetCoverageForAutoclosureBody() {
1391+
Self.Flags.clear(ContextFlags::IsAsyncCovered);
1392+
Self.Flags.clear(ContextFlags::HasAnyAsyncSite);
1393+
Self.Flags.clear(ContextFlags::HasAnyAwait);
1394+
}
1395+
13251396
void resetCoverageForDoCatch() {
13261397
Self.Flags.reset();
13271398
Self.MaxThrowingKind = ThrowingKind::None;
@@ -1409,6 +1480,7 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
14091480
ShouldRecurse_t checkAutoClosure(AutoClosureExpr *E) {
14101481
ContextScope scope(*this, Context::forClosure(E));
14111482
scope.enterSubFunction();
1483+
scope.resetCoverageForAutoclosureBody();
14121484
E->getBody()->walk(*this);
14131485
scope.preserveCoverageFromAutoclosureBody();
14141486
return ShouldNotRecurse;
@@ -1572,17 +1644,14 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
15721644
if (classification.isAsync()) {
15731645
// Remember that we've seen an async call.
15741646
Flags.set(ContextFlags::HasAnyAsyncSite);
1575-
1647+
1648+
// Diagnose async calls in a context that doesn't handle async.
1649+
if (!CurContext.handlesAsync()) {
1650+
CurContext.diagnoseUnhandledAsyncSite(Ctx.Diags, E);
1651+
}
15761652
// Diagnose async calls that are outside of an await context.
1577-
if (!Flags.has(ContextFlags::IsAsyncCovered)) {
1578-
SourceRange highlight;
1579-
1580-
// Generate more specific messages in some cases.
1581-
if (auto e = dyn_cast_or_null<ApplyExpr>(E.dyn_cast<Expr*>()))
1582-
highlight = e->getSourceRange();
1583-
1584-
Ctx.Diags.diagnose(E.getStartLoc(), diag::async_call_without_await)
1585-
.highlight(highlight);
1653+
else if (!Flags.has(ContextFlags::IsAsyncCovered)) {
1654+
CurContext.diagnoseUncoveredAsyncSite(Ctx, E);
15861655
}
15871656
}
15881657

@@ -1626,10 +1695,16 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
16261695
scope.enterAwait();
16271696

16281697
E->getSubExpr()->walk(*this);
1629-
1630-
// Warn about 'await' expressions that weren't actually needed.
1631-
if (!Flags.has(ContextFlags::HasAnyAsyncSite))
1632-
Ctx.Diags.diagnose(E->getAwaitLoc(), diag::no_async_in_await);
1698+
1699+
// Warn about 'await' expressions that weren't actually needed, unless of
1700+
// course we're in a context that could never handle an 'async'. Then, we
1701+
// produce an error.
1702+
if (!Flags.has(ContextFlags::HasAnyAsyncSite)) {
1703+
if (CurContext.handlesAsync())
1704+
Ctx.Diags.diagnose(E->getAwaitLoc(), diag::no_async_in_await);
1705+
else
1706+
CurContext.diagnoseUnhandledAsyncSite(Ctx.Diags, E);
1707+
}
16331708

16341709
// Inform the parent of the walk that an 'await' exists here.
16351710
scope.preserveCoverageFromAwaitOperand();

test/expr/unary/async_await.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,43 @@ func test1(asyncfp : () async -> Int, fp : () -> Int) async {
88
_ = asyncfp() // expected-error {{call is 'async' but is not marked with 'await'}}
99
}
1010

11+
func getInt() async -> Int { return 5 }
12+
13+
// Locations where "await" is prohibited.
14+
func test2(
15+
defaulted: Int = __await getInt() // expected-error{{'async' call cannot occur in a default argument}}
16+
) async {
17+
defer {
18+
_ = __await getInt() // expected-error{{'async' call cannot occur in a defer body}}
19+
}
20+
print("foo")
21+
}
22+
23+
func test3() { // expected-note{{add 'async' to function 'test3()' to make it asynchronous}}
24+
_ = __await getInt() // expected-error{{'async' in a function that does not support concurrency}}
25+
}
26+
27+
enum SomeEnum: Int {
28+
case foo = __await 5 // expected-error{{raw value for enum case must be a literal}}
29+
}
30+
31+
struct SomeStruct {
32+
var x = __await getInt() // expected-error{{'async' call cannot occur in a property initializer}}
33+
static var y = __await getInt() // expected-error{{'async' call cannot occur in a global variable initializer}}
34+
}
35+
36+
func acceptAutoclosureNonAsync(_: @autoclosure () -> Int) { }
37+
func acceptAutoclosureAsync(_: @autoclosure () async -> Int) { }
38+
39+
func testAutoclosure() async {
40+
acceptAutoclosureAsync(getInt()) // expected-error{{call is 'async' in an autoclosure argument is not marked with 'await'}}
41+
acceptAutoclosureNonAsync(getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
42+
43+
acceptAutoclosureAsync(__await getInt())
44+
acceptAutoclosureNonAsync(__await getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
45+
46+
__await acceptAutoclosureAsync(getInt()) // expected-error{{call is 'async' in an autoclosure argument is not marked with 'await'}}
47+
// expected-warning@-1{{no calls to 'async' functions occur within 'await' expression}}
48+
__await acceptAutoclosureNonAsync(getInt()) // expected-error{{'async' in an autoclosure that does not support concurrency}}
49+
// expected-warning@-1{{no calls to 'async' functions occur within 'await' expression}}
50+
}

0 commit comments

Comments
 (0)