Skip to content

Commit 43b6117

Browse files
authored
Merge pull request #68331 from hamishknight/linked-in
2 parents 9c3f2aa + 0760290 commit 43b6117

File tree

4 files changed

+124
-44
lines changed

4 files changed

+124
-44
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6238,18 +6238,29 @@ class ConjunctionElementProducer : public BindingProducer<ConjunctionElement> {
62386238
}
62396239
};
62406240

6241-
/// Find any references to not yet resolved outer VarDecls (including closure
6242-
/// parameters) used in the body of a conjunction element (e.g closures, taps,
6243-
/// if/switch expressions). This is required because isolated conjunctions, just
6244-
/// like single-expression closures, have to be connected to type variables they
6245-
/// are going to use, otherwise they'll get placed in a separate solver
6246-
/// component and would never produce a solution.
6247-
class VarRefCollector : public ASTWalker {
6241+
/// Find any references to external type variables used in the body of a
6242+
/// conjunction element (e.g closures, taps, if/switch expressions).
6243+
///
6244+
/// This includes:
6245+
/// - Not yet resolved outer VarDecls (including closure parameters)
6246+
/// - Return statements with a contextual type that has not yet been resolved
6247+
///
6248+
/// This is required because isolated conjunctions, just like single-expression
6249+
/// closures, have to be connected to type variables they are going to use,
6250+
/// otherwise they'll get placed in a separate solver component and would never
6251+
/// produce a solution.
6252+
class TypeVarRefCollector : public ASTWalker {
62486253
ConstraintSystem &CS;
6254+
DeclContext *DC;
6255+
ConstraintLocator *Locator;
6256+
62496257
llvm::SmallSetVector<TypeVariableType *, 4> TypeVars;
6258+
unsigned DCDepth = 0;
62506259

62516260
public:
6252-
VarRefCollector(ConstraintSystem &cs) : CS(cs) {}
6261+
TypeVarRefCollector(ConstraintSystem &cs, DeclContext *dc,
6262+
ConstraintLocator *locator)
6263+
: CS(cs), DC(dc), Locator(locator) {}
62536264

62546265
/// Infer the referenced type variables from a given decl.
62556266
void inferTypeVars(Decl *D);
@@ -6259,6 +6270,8 @@ class VarRefCollector : public ASTWalker {
62596270
}
62606271

62616272
PreWalkResult<Expr *> walkToExprPre(Expr *expr) override;
6273+
PostWalkResult<Expr *> walkToExprPost(Expr *expr) override;
6274+
PreWalkResult<Stmt *> walkToStmtPre(Stmt *stmt) override;
62626275

62636276
PreWalkAction walkToDeclPre(Decl *D) override {
62646277
// We only need to walk into PatternBindingDecls, other kinds of decls

lib/Sema/CSGen.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ namespace {
866866
};
867867
} // end anonymous namespace
868868

869-
void VarRefCollector::inferTypeVars(Decl *D) {
869+
void TypeVarRefCollector::inferTypeVars(Decl *D) {
870870
// We're only interested in VarDecls.
871871
if (!isa_and_nonnull<VarDecl>(D))
872872
return;
@@ -881,7 +881,10 @@ void VarRefCollector::inferTypeVars(Decl *D) {
881881
}
882882

883883
ASTWalker::PreWalkResult<Expr *>
884-
VarRefCollector::walkToExprPre(Expr *expr) {
884+
TypeVarRefCollector::walkToExprPre(Expr *expr) {
885+
if (isa<ClosureExpr>(expr))
886+
DCDepth += 1;
887+
885888
if (auto *DRE = dyn_cast<DeclRefExpr>(expr))
886889
inferTypeVars(DRE->getDecl());
887890

@@ -901,6 +904,34 @@ VarRefCollector::walkToExprPre(Expr *expr) {
901904
return Action::Continue(expr);
902905
}
903906

907+
ASTWalker::PostWalkResult<Expr *>
908+
TypeVarRefCollector::walkToExprPost(Expr *expr) {
909+
if (isa<ClosureExpr>(expr))
910+
DCDepth -= 1;
911+
912+
return Action::Continue(expr);
913+
}
914+
915+
ASTWalker::PreWalkResult<Stmt *>
916+
TypeVarRefCollector::walkToStmtPre(Stmt *stmt) {
917+
// If we have a return without any intermediate DeclContexts in a ClosureExpr,
918+
// we need to include any type variables in the closure's result type, since
919+
// the conjunction will generate constraints using that type. We don't need to
920+
// connect to returns in e.g nested closures since we'll connect those when we
921+
// generate constraints for those closures. We also don't need to bother if
922+
// we're generating constraints for the closure itself, since we'll connect
923+
// the conjunction to the closure type variable itself.
924+
if (auto *CE = dyn_cast<ClosureExpr>(DC)) {
925+
if (isa<ReturnStmt>(stmt) && DCDepth == 0 &&
926+
!Locator->directlyAt<ClosureExpr>()) {
927+
SmallPtrSet<TypeVariableType *, 4> typeVars;
928+
CS.getClosureType(CE)->getResult()->getTypeVariables(typeVars);
929+
TypeVars.insert(typeVars.begin(), typeVars.end());
930+
}
931+
}
932+
return Action::Continue(stmt);
933+
}
934+
904935
namespace {
905936
class ConstraintGenerator : public ExprVisitor<ConstraintGenerator, Type> {
906937
ConstraintSystem &CS;
@@ -1304,7 +1335,8 @@ namespace {
13041335
// in the tap body, otherwise tap expression is going
13051336
// to get disconnected from the context.
13061337
if (auto *body = tap->getBody()) {
1307-
VarRefCollector refCollector(CS);
1338+
TypeVarRefCollector refCollector(
1339+
CS, tap->getVar()->getDeclContext(), locator);
13081340

13091341
body->walk(refCollector);
13101342

@@ -2938,7 +2970,7 @@ namespace {
29382970
auto *locator = CS.getConstraintLocator(closure);
29392971
auto closureType = CS.createTypeVariable(locator, TVO_CanBindToNoEscape);
29402972

2941-
VarRefCollector refCollector(CS);
2973+
TypeVarRefCollector refCollector(CS, /*DC*/ closure, locator);
29422974
// Walk the capture list if this closure has one, because it could
29432975
// reference declarations from the outer closure.
29442976
if (auto *captureList =

lib/Sema/CSSyntacticElement.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,10 @@ static bool isViableElement(ASTNode element,
281281
using ElementInfo = std::tuple<ASTNode, ContextualTypeInfo,
282282
/*isDiscarded=*/bool, ConstraintLocator *>;
283283

284-
static void createConjunction(ConstraintSystem &cs,
284+
static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
285285
ArrayRef<ElementInfo> elements,
286-
ConstraintLocator *locator,
287-
bool isIsolated = false,
288-
ArrayRef<TypeVariableType *> extraTypeVars = {}) {
286+
ConstraintLocator *locator, bool isIsolated,
287+
ArrayRef<TypeVariableType *> extraTypeVars) {
289288
SmallVector<Constraint *, 4> constraints;
290289
SmallVector<TypeVariableType *, 2> referencedVars;
291290
referencedVars.append(extraTypeVars.begin(), extraTypeVars.end());
@@ -335,7 +334,7 @@ static void createConjunction(ConstraintSystem &cs,
335334
isIsolated = true;
336335
}
337336

338-
VarRefCollector paramCollector(cs);
337+
TypeVarRefCollector paramCollector(cs, dc, locator);
339338

340339
// Whether we're doing completion, and the conjunction is for a
341340
// SingleValueStmtExpr, or one of its braces.
@@ -517,6 +516,13 @@ class SyntacticElementConstraintGenerator
517516
ConstraintLocator *locator)
518517
: cs(cs), context(context), locator(locator) {}
519518

519+
void createConjunction(ArrayRef<ElementInfo> elements,
520+
ConstraintLocator *locator, bool isIsolated = false,
521+
ArrayRef<TypeVariableType *> extraTypeVars = {}) {
522+
::createConjunction(cs, context.getAsDeclContext(), elements, locator,
523+
isIsolated, extraTypeVars);
524+
}
525+
520526
void visitExprPattern(ExprPattern *EP) {
521527
auto target = SyntacticElementTarget::forExprPattern(EP);
522528

@@ -859,7 +865,7 @@ class SyntacticElementConstraintGenerator
859865
if (auto *join = context.ElementJoin.getPtrOrNull())
860866
elements.push_back(makeJoinElement(cs, join, locator));
861867

862-
createConjunction(cs, elements, locator);
868+
createConjunction(elements, locator);
863869
}
864870

865871
void visitGuardStmt(GuardStmt *guardStmt) {
@@ -868,7 +874,7 @@ class SyntacticElementConstraintGenerator
868874
visitStmtCondition(guardStmt, elements, locator);
869875
elements.push_back(makeElement(guardStmt->getBody(), locator));
870876

871-
createConjunction(cs, elements, locator);
877+
createConjunction(elements, locator);
872878
}
873879

874880
void visitWhileStmt(WhileStmt *whileStmt) {
@@ -877,16 +883,15 @@ class SyntacticElementConstraintGenerator
877883
visitStmtCondition(whileStmt, elements, locator);
878884
elements.push_back(makeElement(whileStmt->getBody(), locator));
879885

880-
createConjunction(cs, elements, locator);
886+
createConjunction(elements, locator);
881887
}
882888

883889
void visitDoStmt(DoStmt *doStmt) {
884890
visitBraceStmt(doStmt->getBody());
885891
}
886892

887893
void visitRepeatWhileStmt(RepeatWhileStmt *repeatWhileStmt) {
888-
createConjunction(cs,
889-
{makeElement(repeatWhileStmt->getCond(),
894+
createConjunction({makeElement(repeatWhileStmt->getCond(),
890895
cs.getConstraintLocator(
891896
locator, ConstraintLocator::Condition),
892897
getContextForCondition()),
@@ -895,8 +900,7 @@ class SyntacticElementConstraintGenerator
895900
}
896901

897902
void visitPoundAssertStmt(PoundAssertStmt *poundAssertStmt) {
898-
createConjunction(cs,
899-
{makeElement(poundAssertStmt->getCondition(),
903+
createConjunction({makeElement(poundAssertStmt->getCondition(),
900904
cs.getConstraintLocator(
901905
locator, ConstraintLocator::Condition),
902906
getContextForCondition())},
@@ -913,12 +917,10 @@ class SyntacticElementConstraintGenerator
913917
auto *errorExpr = throwStmt->getSubExpr();
914918

915919
createConjunction(
916-
cs,
917-
{makeElement(
918-
errorExpr,
919-
cs.getConstraintLocator(
920-
locator, LocatorPathElt::SyntacticElement(errorExpr)),
921-
{errType, CTP_ThrowStmt})},
920+
{makeElement(errorExpr,
921+
cs.getConstraintLocator(
922+
locator, LocatorPathElt::SyntacticElement(errorExpr)),
923+
{errType, CTP_ThrowStmt})},
922924
locator);
923925
}
924926

@@ -939,12 +941,10 @@ class SyntacticElementConstraintGenerator
939941
auto *selfExpr = discardStmt->getSubExpr();
940942

941943
createConjunction(
942-
cs,
943-
{makeElement(
944-
selfExpr,
945-
cs.getConstraintLocator(
946-
locator, LocatorPathElt::SyntacticElement(selfExpr)),
947-
{nominalType, CTP_DiscardStmt})},
944+
{makeElement(selfExpr,
945+
cs.getConstraintLocator(
946+
locator, LocatorPathElt::SyntacticElement(selfExpr)),
947+
{nominalType, CTP_DiscardStmt})},
948948
locator);
949949
}
950950

@@ -962,7 +962,7 @@ class SyntacticElementConstraintGenerator
962962
// Body of the `for-in` loop.
963963
elements.push_back(makeElement(forEachStmt->getBody(), stmtLoc));
964964

965-
createConjunction(cs, elements, locator);
965+
createConjunction(elements, locator);
966966
}
967967

968968
void visitSwitchStmt(SwitchStmt *switchStmt) {
@@ -990,7 +990,7 @@ class SyntacticElementConstraintGenerator
990990
if (auto *join = context.ElementJoin.getPtrOrNull())
991991
elements.push_back(makeJoinElement(cs, join, switchLoc));
992992

993-
createConjunction(cs, elements, switchLoc);
993+
createConjunction(elements, switchLoc);
994994
}
995995

996996
void visitDoCatchStmt(DoCatchStmt *doStmt) {
@@ -1007,7 +1007,7 @@ class SyntacticElementConstraintGenerator
10071007
for (auto *catchStmt : doStmt->getCatches())
10081008
elements.push_back(makeElement(catchStmt, doLoc));
10091009

1010-
createConjunction(cs, elements, doLoc);
1010+
createConjunction(elements, doLoc);
10111011
}
10121012

10131013
void visitCaseStmt(CaseStmt *caseStmt) {
@@ -1040,7 +1040,7 @@ class SyntacticElementConstraintGenerator
10401040

10411041
elements.push_back(makeElement(caseStmt->getBody(), caseLoc));
10421042

1043-
createConjunction(cs, elements, caseLoc);
1043+
createConjunction(elements, caseLoc);
10441044
}
10451045

10461046
void visitBraceStmt(BraceStmt *braceStmt) {
@@ -1127,7 +1127,7 @@ class SyntacticElementConstraintGenerator
11271127
// want to type-check captures to make sure that the context
11281128
// is valid.
11291129
if (captureList)
1130-
createConjunction(cs, elements, locator);
1130+
createConjunction(elements, locator);
11311131

11321132
return;
11331133
}
@@ -1195,7 +1195,7 @@ class SyntacticElementConstraintGenerator
11951195
contextInfo.value_or(ContextualTypeInfo()), isDiscarded));
11961196
}
11971197

1198-
createConjunction(cs, elements, locator);
1198+
createConjunction(elements, locator);
11991199
}
12001200

12011201
void visitReturnStmt(ReturnStmt *returnStmt) {
@@ -1282,7 +1282,7 @@ class SyntacticElementConstraintGenerator
12821282
auto resultElt = makeElement(resultExpr, locator,
12831283
contextInfo.value_or(ContextualTypeInfo()),
12841284
/*isDiscarded=*/false);
1285-
createConjunction(cs, {resultElt}, locator);
1285+
createConjunction({resultElt}, locator);
12861286
}
12871287

12881288
ContextualTypeInfo getContextualResultInfo() const {
@@ -1520,6 +1520,9 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15201520

15211521
void ConstraintSystem::generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
15221522
ConstraintLocatorBuilder locator) {
1523+
assert(!exprPatterns.empty());
1524+
auto *DC = exprPatterns.front()->getDeclContext();
1525+
15231526
// Form a conjunction of ExprPattern elements, isolated from the rest of the
15241527
// pattern.
15251528
SmallVector<ElementInfo> elements;
@@ -1532,7 +1535,7 @@ void ConstraintSystem::generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
15321535
elements.push_back(makeElement(EP, getConstraintLocator(EP), context));
15331536
}
15341537
auto *loc = getConstraintLocator(locator);
1535-
createConjunction(*this, elements, loc, /*isIsolated*/ true,
1538+
createConjunction(*this, DC, elements, loc, /*isIsolated*/ true,
15361539
referencedTypeVars);
15371540
}
15381541

test/Constraints/rdar114402042.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
// rdar://114402042 - Make sure we connect the SingleValueStmtExpr to the outer
4+
// closure's return type.
5+
func foo<T>(_: () -> T) {}
6+
func bar<T>(_ x: T) {}
7+
func test() {
8+
foo {
9+
bar(if true { return } else { return })
10+
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
11+
// expected-error@-2 2{{cannot 'return' in 'if' when used as expression}}
12+
}
13+
foo {
14+
bar(if true { { return } } else { { return } })
15+
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
16+
}
17+
}
18+
19+
func baz() -> String? {
20+
nil
21+
}
22+
23+
var x: Int? = {
24+
print( // expected-error {{cannot convert value of type '()' to closure result type 'Int?'}}
25+
// expected-note@-1 {{to match this opening '('}}
26+
switch baz() {
27+
case ""?:
28+
return nil
29+
default:
30+
return nil
31+
}
32+
}() // expected-error {{expected ')' in expression list}}

0 commit comments

Comments
 (0)