Skip to content

Commit 619a517

Browse files
committed
[CSSimplify] Check all conditional requirements for a type variable
1 parent 4bd4fa6 commit 619a517

File tree

3 files changed

+56
-25
lines changed

3 files changed

+56
-25
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
#include "swift/AST/ParameterList.h"
3030
#include "swift/AST/PropertyWrappers.h"
3131
#include "swift/AST/ProtocolConformance.h"
32+
#include "swift/AST/Requirement.h"
3233
#include "swift/AST/SourceFile.h"
34+
#include "swift/AST/Types.h"
3335
#include "swift/Basic/StringExtras.h"
3436
#include "swift/ClangImporter/ClangModule.h"
3537
#include "swift/Sema/CSFix.h"
@@ -9527,12 +9529,28 @@ performMemberLookup(ConstraintKind constraintKind, DeclNameRef memberName,
95279529
auto sendableProtocol =
95289530
DC->getParentModule()->getASTContext().getProtocol(
95299531
KnownProtocolKind::Sendable);
9530-
auto baseSendable = swift::TypeChecker::conformsToProtocol(
9531-
instanceTy, sendableProtocol, DC->getParentModule());
9532-
9533-
if (!baseSendable.isInvalid() &&
9534-
!baseSendable.getConditionalRequirements().empty() &&
9535-
instanceTy->hasTypeVariable()) {
9532+
auto baseConformance = DC->getParentModule()->lookupConformance(
9533+
instanceTy, sendableProtocol);
9534+
9535+
if (llvm::any_of(
9536+
baseConformance.getConditionalRequirements(),
9537+
[&](const auto &req) {
9538+
switch (req.getKind()) {
9539+
case RequirementKind::Conformance: {
9540+
if (auto secondType =
9541+
req.getSecondType()->template getAs<ProtocolType>()) {
9542+
return req.getFirstType()->hasTypeVariable() &&
9543+
secondType->getDecl()->isSpecificProtocol(
9544+
KnownProtocolKind::Sendable);
9545+
}
9546+
}
9547+
case RequirementKind::Superclass:
9548+
case RequirementKind::SameType:
9549+
case RequirementKind::SameShape:
9550+
case RequirementKind::Layout:
9551+
return false;
9552+
}
9553+
})) {
95369554
result.OverallResult = MemberLookupResult::Unsolved;
95379555
return result;
95389556
}

lib/Sema/ConstraintSystem.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,7 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
16651665
ConstraintLocatorBuilder locator,
16661666
DeclContext *useDC) {
16671667
auto &ctx = getASTContext();
1668-
1668+
16691669
if (value->getDeclContext()->isTypeContext() && isa<FuncDecl>(value)) {
16701670
// Unqualified lookup can find operator names within nominal types.
16711671
auto func = cast<FuncDecl>(value);
@@ -1711,14 +1711,16 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
17111711
auto funcType = funcDecl->getInterfaceType()->castTo<AnyFunctionType>();
17121712
auto numLabelsToRemove = getNumRemovedArgumentLabels(
17131713
funcDecl, /*isCurriedInstanceReference=*/false, functionRefKind);
1714-
1714+
17151715
if (ctx.LangOpts.hasFeature(Feature::InferSendableMethods)) {
17161716
// All global functions should be @Sendable
1717-
if(funcDecl->getDeclContext()->isLocalContext()) {
1718-
funcType = funcType->withExtInfo(funcType->getExtInfo().withConcurrent())->getAs<AnyFunctionType>();
1717+
if (funcDecl->getDeclContext()->isLocalContext()) {
1718+
funcType =
1719+
funcType->withExtInfo(funcType->getExtInfo().withConcurrent())
1720+
->getAs<AnyFunctionType>();
17191721
}
17201722
}
1721-
1723+
17221724
auto openedType = openFunctionType(funcType, locator, replacements,
17231725
funcDecl->getDeclContext())
17241726
->removeArgumentLabels(numLabelsToRemove);
@@ -2641,26 +2643,27 @@ ConstraintSystem::getTypeOfMemberReference(
26412643
if (inferredSendable) {
26422644
auto sendableProtocol = parentModule->getASTContext().getProtocol(
26432645
KnownProtocolKind::Sendable);
2644-
auto baseConformance = TypeChecker::conformsToProtocol(
2645-
baseOpenedTy, sendableProtocol, parentModule);
2646+
auto baseConformance =
2647+
parentModule->lookupConformance(baseOpenedTy, sendableProtocol);
26462648

26472649
if (baseTypeSendable) {
26482650
// Add @Sendable to functions without conditional conformances
2649-
if (baseConformance.getConditionalRequirements().empty()){
2651+
if (baseConformance.getConditionalRequirements().empty()) {
26502652
functionType = functionType->withExtInfo(functionType->getExtInfo().withConcurrent())->getAs<FunctionType>();
26512653
} else {
26522654
// Handle Conditional Conformances
26532655
auto substitutionMap = SubstitutionMap::getProtocolSubstitutions(
2654-
sendableProtocol, baseOpenedTy,
2655-
baseConformance);
2656-
2657-
auto result = TypeChecker::checkGenericArguments(parentModule, baseConformance.getConditionalRequirements(), QuerySubstitutionMap{substitutionMap} );
2658-
2656+
sendableProtocol, baseOpenedTy, baseConformance);
2657+
2658+
auto result = TypeChecker::checkGenericArguments(
2659+
parentModule, baseConformance.getConditionalRequirements(),
2660+
QuerySubstitutionMap{substitutionMap});
2661+
26592662
if (result == CheckGenericArgumentsResult::Success) {
26602663
functionType =
2661-
functionType
2662-
->withExtInfo(functionType->getExtInfo().withConcurrent())
2663-
->getAs<FunctionType>();
2664+
functionType
2665+
->withExtInfo(functionType->getExtInfo().withConcurrent())
2666+
->getAs<FunctionType>();
26642667
}
26652668
}
26662669
}
@@ -2672,10 +2675,10 @@ ConstraintSystem::getTypeOfMemberReference(
26722675
FunctionType::get(fullFunctionType->getParams(), functionType, info);
26732676

26742677
// Add @Sendable to openedType if possible
2675-
if (inferredSendable){
2678+
if (inferredSendable) {
26762679
auto origFnType = openedType->castTo<FunctionType>();
26772680
openedType =
2678-
origFnType->withExtInfo(origFnType->getExtInfo().withConcurrent());
2681+
origFnType->withExtInfo(origFnType->getExtInfo().withConcurrent());
26792682
}
26802683
}
26812684

test/Concurrency/sendable_methods.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ enum InferredSendableE: P {
4141
}
4242

4343
struct GenericS<T>: P {
44-
func f() { }
44+
init(_: T) { }
45+
46+
func f() { }
4547
}
4648

4749
final class GenericC<T>: P {
@@ -80,6 +82,9 @@ g(GenericC<Int>.f)
8082
g(GenericS<NonSendable>.f) // expected-warning{{converting non-sendable function value to '@Sendable () -> Void' may introduce data races
8183
g(GenericC<NonSendable>.f) // expected-warning{{converting non-sendable function value to '@Sendable () -> Void' may introduce data races
8284

85+
g(GenericS(NonSendable()).f) // expected-warning{{converting non-sendable function value to '@Sendable () -> Void' may introduce data races
86+
g(GenericS(1).f)
87+
8388
func executeAsTask (_ f: @escaping @Sendable () -> Void) {
8489
Task {
8590
f()
@@ -133,3 +138,8 @@ func doWork() -> Int {
133138
let work: @Sendable () -> Int = doWork
134139
Task<Int, Never>.detached(priority: nil, operation: doWork)
135140
Task<Int, Never>.detached(priority: nil, operation: work)
141+
142+
func generic<T>(_: T) {
143+
Task<Int, Never>.detached(priority: nil, operation: T)
144+
}
145+
generic(GenericS<Int>.f) // generic argument for `T` should be inferred as `@escaping @Sendable`

0 commit comments

Comments
 (0)