Skip to content

Commit a0f506a

Browse files
committed
Use the closure type during type checking for establishing use of new features
When evaluating whether code is within a closure that uses concurrency features, use the type of the closure as it's known during type checking, so that contextual information (e.g., it's passed to a `@Sendable` or `async` parameter of function type) can affect the result. This corrects the definition for doing strict checking within a minimal context for the end result of the type-check, rather than it's initial state, catching more issues. Fixes SR-15131 / rdar://problem/82535088. (cherry picked from commit cc7904c)
1 parent aeaea11 commit a0f506a

File tree

6 files changed

+111
-53
lines changed

6 files changed

+111
-53
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,13 @@ enum class SolutionApplicationToFunctionResult {
22172217
Delay,
22182218
};
22192219

2220+
/// Retrieve the closure type from the constraint system.
2221+
struct GetClosureType {
2222+
ConstraintSystem &cs;
2223+
2224+
Type operator()(const AbstractClosureExpr *expr) const;
2225+
};
2226+
22202227
/// Describes a system of constraints on type variables, the
22212228
/// solution of which assigns concrete types to each of the type variables.
22222229
/// Constraint systems are typically generated given an (untyped) expression.
@@ -3096,9 +3103,16 @@ class ConstraintSystem {
30963103
}
30973104

30983105
FunctionType *getClosureType(const ClosureExpr *closure) const {
3106+
auto result = getClosureTypeIfAvailable(closure);
3107+
assert(result);
3108+
return result;
3109+
}
3110+
3111+
FunctionType *getClosureTypeIfAvailable(const ClosureExpr *closure) const {
30993112
auto result = ClosureTypes.find(closure);
3100-
assert(result != ClosureTypes.end());
3101-
return result->second;
3113+
if (result != ClosureTypes.end())
3114+
return result->second;
3115+
return nullptr;
31023116
}
31033117

31043118
TypeBase* getFavoredType(Expr *E) {
@@ -4116,6 +4130,12 @@ class ConstraintSystem {
41164130
ConstraintLocatorBuilder locator,
41174131
const OpenedTypeMap &replacements);
41184132

4133+
/// Wrapper over swift::adjustFunctionTypeForConcurrency that passes along
4134+
/// the appropriate closure-type extraction function.
4135+
AnyFunctionType *adjustFunctionTypeForConcurrency(
4136+
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
4137+
unsigned numApplies, bool isMainDispatchQueue);
4138+
41194139
/// Retrieve the type of a reference to the given value declaration.
41204140
///
41214141
/// For references to polymorphic function types, this routine "opens up"
@@ -4164,10 +4184,15 @@ class ConstraintSystem {
41644184
///
41654185
/// \param getType Optional callback to extract a type for given declaration.
41664186
static Type
4167-
getUnopenedTypeOfReference(VarDecl *value, Type baseType, DeclContext *UseDC,
4168-
llvm::function_ref<Type(VarDecl *)> getType,
4169-
ConstraintLocator *memberLocator = nullptr,
4170-
bool wantInterfaceType = false);
4187+
getUnopenedTypeOfReference(
4188+
VarDecl *value, Type baseType, DeclContext *UseDC,
4189+
llvm::function_ref<Type(VarDecl *)> getType,
4190+
ConstraintLocator *memberLocator = nullptr,
4191+
bool wantInterfaceType = false,
4192+
llvm::function_ref<Type(const AbstractClosureExpr *)> getClosureType =
4193+
[](const AbstractClosureExpr *) {
4194+
return Type();
4195+
});
41714196

41724197
/// Retrieve the type of a reference to the given value declaration,
41734198
/// as a member with a base of the given type.

lib/Sema/CSSimplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2171,7 +2171,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
21712171

21722172
/// Whether to downgrade to a concurrency warning.
21732173
auto isConcurrencyWarning = [&] {
2174-
if (contextRequiresStrictConcurrencyChecking(DC))
2174+
if (contextRequiresStrictConcurrencyChecking(DC, GetClosureType{*this}))
21752175
return false;
21762176

21772177
switch (kind) {

lib/Sema/ConstraintSystem.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,21 @@ doesStorageProduceLValue(AbstractStorageDecl *storage, Type baseType,
10981098
!storage->isSetterMutating());
10991099
}
11001100

1101+
Type GetClosureType::operator()(const AbstractClosureExpr *expr) const {
1102+
if (auto closure = dyn_cast<ClosureExpr>(expr)) {
1103+
// Look through type bindings, if we have them.
1104+
auto mutableClosure = const_cast<ClosureExpr *>(closure);
1105+
if (cs.hasType(mutableClosure)) {
1106+
return cs.getFixedTypeRecursive(
1107+
cs.getType(mutableClosure), /*wantRValue=*/true);
1108+
}
1109+
1110+
return cs.getClosureTypeIfAvailable(closure);
1111+
}
1112+
1113+
return Type();
1114+
}
1115+
11011116
Type ConstraintSystem::getUnopenedTypeOfReference(
11021117
VarDecl *value, Type baseType, DeclContext *UseDC,
11031118
ConstraintLocator *memberLocator, bool wantInterfaceType) {
@@ -1113,18 +1128,20 @@ Type ConstraintSystem::getUnopenedTypeOfReference(
11131128

11141129
return wantInterfaceType ? var->getInterfaceType() : var->getType();
11151130
},
1116-
memberLocator, wantInterfaceType);
1131+
memberLocator, wantInterfaceType, GetClosureType{*this});
11171132
}
11181133

11191134
Type ConstraintSystem::getUnopenedTypeOfReference(
11201135
VarDecl *value, Type baseType, DeclContext *UseDC,
11211136
llvm::function_ref<Type(VarDecl *)> getType,
1122-
ConstraintLocator *memberLocator, bool wantInterfaceType) {
1137+
ConstraintLocator *memberLocator, bool wantInterfaceType,
1138+
llvm::function_ref<Type(const AbstractClosureExpr *)> getClosureType) {
11231139
Type requestedType =
11241140
getType(value)->getWithoutSpecifierType()->getReferenceStorageReferent();
11251141

11261142
// Adjust the type for concurrency.
1127-
requestedType = adjustVarTypeForConcurrency(requestedType, value, UseDC);
1143+
requestedType = adjustVarTypeForConcurrency(
1144+
requestedType, value, UseDC, getClosureType);
11281145

11291146
// If we're dealing with contextual types, and we referenced this type from
11301147
// a different context, map the type.
@@ -1298,6 +1315,13 @@ static bool isRequirementOrWitness(const ConstraintLocatorBuilder &locator) {
12981315
return false;
12991316
}
13001317

1318+
AnyFunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency(
1319+
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
1320+
unsigned numApplies, bool isMainDispatchQueue) {
1321+
return swift::adjustFunctionTypeForConcurrency(
1322+
fnType, decl, dc, numApplies, isMainDispatchQueue, GetClosureType{*this});
1323+
}
1324+
13011325
std::pair<Type, Type>
13021326
ConstraintSystem::getTypeOfReference(ValueDecl *value,
13031327
FunctionRefKind functionRefKind,
@@ -1314,7 +1338,8 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
13141338
->castTo<AnyFunctionType>();
13151339
if (!isRequirementOrWitness(locator)) {
13161340
unsigned numApplies = getNumApplications(value, false, functionRefKind);
1317-
funcType = adjustFunctionTypeForConcurrency(funcType, func, useDC, numApplies, false);
1341+
funcType = adjustFunctionTypeForConcurrency(
1342+
funcType, func, useDC, numApplies, false);
13181343
}
13191344
auto openedType = openFunctionType(
13201345
funcType, locator, replacements, func->getDeclContext());
@@ -2085,7 +2110,8 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator,
20852110
} else if (type->hasDynamicSelfType()) {
20862111
type = withDynamicSelfResultReplaced(type, /*uncurryLevel=*/0);
20872112
}
2088-
type = adjustVarTypeForConcurrency(type, var, useDC);
2113+
type = adjustVarTypeForConcurrency(
2114+
type, var, useDC, GetClosureType{*this});
20892115
} else if (isa<AbstractFunctionDecl>(decl) || isa<EnumElementDecl>(decl)) {
20902116
if (decl->isInstanceMember() &&
20912117
(!overload.getBaseType() ||
@@ -2536,8 +2562,14 @@ bool ConstraintSystem::isAsynchronousContext(DeclContext *dc) {
25362562
if (auto func = dyn_cast<AbstractFunctionDecl>(dc))
25372563
return func->isAsyncContext();
25382564

2539-
if (auto closure = dyn_cast<AbstractClosureExpr>(dc))
2540-
return closure->isBodyAsync();
2565+
if (auto abstractClosure = dyn_cast<AbstractClosureExpr>(dc)) {
2566+
if (Type type = GetClosureType{*this}(abstractClosure)) {
2567+
if (auto fnType = type->getAs<AnyFunctionType>())
2568+
return fnType->isAsync();
2569+
}
2570+
2571+
return abstractClosure->isBodyAsync();
2572+
}
25412573

25422574
return false;
25432575
}

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,9 @@ static bool shouldDiagnoseExistingDataRaces(const DeclContext *dc) {
615615
if (dc->getParentModule()->isConcurrencyChecked())
616616
return true;
617617

618-
return contextRequiresStrictConcurrencyChecking(dc);
618+
return contextRequiresStrictConcurrencyChecking(dc, [](const AbstractClosureExpr *) {
619+
return Type();
620+
});
619621
}
620622

621623
/// Determine the default diagnostic behavior for this language mode.
@@ -1843,20 +1845,6 @@ namespace {
18431845
}
18441846
}
18451847

1846-
bool isInAsynchronousContext() const {
1847-
auto *dc = getDeclContext();
1848-
1849-
if (auto func = dyn_cast<AbstractFunctionDecl>(dc)) {
1850-
return func->isAsyncContext();
1851-
}
1852-
1853-
if (auto closure = dyn_cast<AbstractClosureExpr>(dc)) {
1854-
return closure->isBodyAsync();
1855-
}
1856-
1857-
return false;
1858-
}
1859-
18601848
enum class AsyncMarkingResult {
18611849
FoundAsync, // successfully marked an implicitly-async operation
18621850
NotFound, // fail: no valid implicitly-async operation was found
@@ -1917,7 +1905,7 @@ namespace {
19171905

19181906
if (auto declRef = dyn_cast_or_null<DeclRefExpr>(context)) {
19191907
if (usageEnv(declRef) == VarRefUseEnv::Read) {
1920-
if (!isInAsynchronousContext())
1908+
if (!getDeclContext()->isAsyncContext())
19211909
return AsyncMarkingResult::SyncContext;
19221910

19231911
declRef->setImplicitlyAsync(target);
@@ -1926,7 +1914,7 @@ namespace {
19261914
} else if (auto lookupExpr = dyn_cast_or_null<LookupExpr>(context)) {
19271915
if (usageEnv(lookupExpr) == VarRefUseEnv::Read) {
19281916

1929-
if (!isInAsynchronousContext())
1917+
if (!getDeclContext()->isAsyncContext())
19301918
return AsyncMarkingResult::SyncContext;
19311919

19321920
lookupExpr->setImplicitlyAsync(target);
@@ -1939,7 +1927,7 @@ namespace {
19391927
// actor-isolated non-isolated-self calls are implicitly async
19401928
// and thus OK.
19411929

1942-
if (!isInAsynchronousContext())
1930+
if (!getDeclContext()->isAsyncContext())
19431931
return AsyncMarkingResult::SyncContext;
19441932

19451933
isAsyncCall = true;
@@ -1955,7 +1943,7 @@ namespace {
19551943
auto concDecl = memberRef->first;
19561944
if (decl == concDecl.getDecl() && !apply->isImplicitlyAsync()) {
19571945

1958-
if (!isInAsynchronousContext())
1946+
if (!getDeclContext()->isAsyncContext())
19591947
return AsyncMarkingResult::SyncContext;
19601948

19611949
// then this ValueDecl appears as the called value of the ApplyExpr.
@@ -2068,7 +2056,7 @@ namespace {
20682056
return false;
20692057

20702058
// If we are not in an asynchronous context, complain.
2071-
if (!isInAsynchronousContext()) {
2059+
if (!getDeclContext()->isAsyncContext()) {
20722060
if (auto calleeDecl = apply->getCalledValue()) {
20732061
ctx.Diags.diagnose(
20742062
apply->getLoc(), diag::actor_isolated_call_decl,
@@ -3562,7 +3550,9 @@ void swift::checkOverrideActorIsolation(ValueDecl *value) {
35623550
overridden->diagnose(diag::overridden_here);
35633551
}
35643552

3565-
bool swift::contextRequiresStrictConcurrencyChecking(const DeclContext *dc) {
3553+
bool swift::contextRequiresStrictConcurrencyChecking(
3554+
const DeclContext *dc,
3555+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType) {
35663556
// If Swift >= 6, everything uses strict concurrency checking.
35673557
if (dc->getASTContext().LangOpts.isSwiftVersionAtLeast(6))
35683558
return true;
@@ -3574,6 +3564,12 @@ bool swift::contextRequiresStrictConcurrencyChecking(const DeclContext *dc) {
35743564
if (auto explicitClosure = dyn_cast<ClosureExpr>(closure)) {
35753565
if (getExplicitGlobalActor(const_cast<ClosureExpr *>(explicitClosure)))
35763566
return true;
3567+
3568+
if (auto type = getType(closure)) {
3569+
if (auto fnType = type->getAs<AnyFunctionType>())
3570+
if (fnType->isAsync() || fnType->isSendable())
3571+
return true;
3572+
}
35773573
}
35783574

35793575
// Async and @Sendable closures use concurrency features.
@@ -4049,11 +4045,12 @@ static bool hasKnownUnsafeSendableFunctionParams(AbstractFunctionDecl *func) {
40494045
}
40504046

40514047
Type swift::adjustVarTypeForConcurrency(
4052-
Type type, VarDecl *var, DeclContext *dc) {
4048+
Type type, VarDecl *var, DeclContext *dc,
4049+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType) {
40534050
if (!var->predatesConcurrency())
40544051
return type;
40554052

4056-
if (contextRequiresStrictConcurrencyChecking(dc))
4053+
if (contextRequiresStrictConcurrencyChecking(dc, getType))
40574054
return type;
40584055

40594056
bool isLValue = false;
@@ -4172,11 +4169,12 @@ static AnyFunctionType *applyUnsafeConcurrencyToFunctionType(
41724169

41734170
AnyFunctionType *swift::adjustFunctionTypeForConcurrency(
41744171
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
4175-
unsigned numApplies, bool isMainDispatchQueue) {
4172+
unsigned numApplies, bool isMainDispatchQueue,
4173+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType) {
41764174
// Apply unsafe concurrency features to the given function type.
4175+
bool strictChecking = contextRequiresStrictConcurrencyChecking(dc, getType);
41774176
fnType = applyUnsafeConcurrencyToFunctionType(
4178-
fnType, decl, contextRequiresStrictConcurrencyChecking(dc), numApplies,
4179-
isMainDispatchQueue);
4177+
fnType, decl, strictChecking, numApplies, isMainDispatchQueue);
41804178

41814179
Type globalActorType;
41824180
if (decl) {
@@ -4190,7 +4188,7 @@ AnyFunctionType *swift::adjustFunctionTypeForConcurrency(
41904188
case ActorIsolation::GlobalActorUnsafe:
41914189
// Only treat as global-actor-qualified within code that has adopted
41924190
// Swift Concurrency features.
4193-
if (!contextRequiresStrictConcurrencyChecking(dc))
4191+
if (!strictChecking)
41944192
return fnType;
41954193

41964194
LLVM_FALLTHROUGH;
@@ -4232,7 +4230,10 @@ AnyFunctionType *swift::adjustFunctionTypeForConcurrency(
42324230
}
42334231

42344232
bool swift::completionContextUsesConcurrencyFeatures(const DeclContext *dc) {
4235-
return contextRequiresStrictConcurrencyChecking(dc);
4233+
return contextRequiresStrictConcurrencyChecking(
4234+
dc, [](const AbstractClosureExpr *) {
4235+
return Type();
4236+
});
42364237
}
42374238

42384239
AbstractFunctionDecl const *swift::isActorInitOrDeInitContext(

lib/Sema/TypeCheckConcurrency.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ void checkOverrideActorIsolation(ValueDecl *value);
214214
/// Determine whether the given context requires strict concurrency checking,
215215
/// e.g., because it uses concurrency features directly or because it's in
216216
/// code where strict checking has been enabled.
217-
bool contextRequiresStrictConcurrencyChecking(const DeclContext *dc);
217+
bool contextRequiresStrictConcurrencyChecking(
218+
const DeclContext *dc,
219+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType);
218220

219221
/// Diagnose the presence of any non-sendable types when referencing a
220222
/// given declaration from a particular declaration context.
@@ -337,22 +339,19 @@ checkGlobalActorAttributes(
337339
Type getExplicitGlobalActor(ClosureExpr *closure);
338340

339341
/// Adjust the type of the variable for concurrency.
340-
Type adjustVarTypeForConcurrency(Type type, VarDecl *var, DeclContext *dc);
341-
342-
/// Adjust the function type of a function / subscript / enum case for
343-
/// concurrency.
344-
AnyFunctionType *adjustFunctionTypeForConcurrency(
345-
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
346-
unsigned numApplies, bool isMainDispatchQueue);
342+
Type adjustVarTypeForConcurrency(
343+
Type type, VarDecl *var, DeclContext *dc,
344+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType);
347345

348346
/// Adjust the given function type to account for concurrency-specific
349347
/// attributes whose affect on the type might differ based on context.
350348
/// This includes adjustments for unsafe parameter attributes like
351349
/// `@_unsafeSendable` and `@_unsafeMainActor` as well as a global actor
352350
/// on the declaration itself.
353351
AnyFunctionType *adjustFunctionTypeForConcurrency(
354-
AnyFunctionType *fnType, ValueDecl *funcOrEnum, DeclContext *dc,
355-
unsigned numApplies, bool isMainDispatchQueue);
352+
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
353+
unsigned numApplies, bool isMainDispatchQueue,
354+
llvm::function_ref<Type(const AbstractClosureExpr *)> getType);
356355

357356
/// Determine whether the given name is that of a DispatchQueue operation that
358357
/// takes a closure to be executed on the queue.

test/Concurrency/global_actor_from_ordinary_context.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ class Sub: Super {
150150

151151
func g2() {
152152
Task.detached {
153-
self.f() // EXPECTED ERROR because f is on @SomeGlobalActor
153+
self.f() // expected-error{{expression is 'async' but is not marked with 'await'}}
154+
// expected-note@-1{{calls to instance method 'f()' from outside of its actor context are implicitly asynchronous}}
154155
}
155156
}
156157
}

0 commit comments

Comments
 (0)