Skip to content

Commit 0760290

Browse files
committed
[CS] Connect conjunctions with ReturnStmts to return type
Augment the TypeVarRefCollector such that it picks up any type variables present in the result type for a closure DeclContext when visiting a ReturnStmt. This ensures we correctly handle if/switch expressions that contain `return` statements. rdar://114402042
1 parent 5851c59 commit 0760290

File tree

4 files changed

+86
-10
lines changed

4 files changed

+86
-10
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6243,17 +6243,24 @@ class ConjunctionElementProducer : public BindingProducer<ConjunctionElement> {
62436243
///
62446244
/// This includes:
62456245
/// - Not yet resolved outer VarDecls (including closure parameters)
6246+
/// - Return statements with a contextual type that has not yet been resolved
62466247
///
62476248
/// This is required because isolated conjunctions, just like single-expression
62486249
/// closures, have to be connected to type variables they are going to use,
62496250
/// otherwise they'll get placed in a separate solver component and would never
62506251
/// produce a solution.
62516252
class TypeVarRefCollector : public ASTWalker {
62526253
ConstraintSystem &CS;
6254+
DeclContext *DC;
6255+
ConstraintLocator *Locator;
6256+
62536257
llvm::SmallSetVector<TypeVariableType *, 4> TypeVars;
6258+
unsigned DCDepth = 0;
62546259

62556260
public:
6256-
TypeVarRefCollector(ConstraintSystem &cs) : CS(cs) {}
6261+
TypeVarRefCollector(ConstraintSystem &cs, DeclContext *dc,
6262+
ConstraintLocator *locator)
6263+
: CS(cs), DC(dc), Locator(locator) {}
62576264

62586265
/// Infer the referenced type variables from a given decl.
62596266
void inferTypeVars(Decl *D);
@@ -6263,6 +6270,8 @@ class TypeVarRefCollector : public ASTWalker {
62636270
}
62646271

62656272
PreWalkResult<Expr *> walkToExprPre(Expr *expr) override;
6273+
PostWalkResult<Expr *> walkToExprPost(Expr *expr) override;
6274+
PreWalkResult<Stmt *> walkToStmtPre(Stmt *stmt) override;
62666275

62676276
PreWalkAction walkToDeclPre(Decl *D) override {
62686277
// We only need to walk into PatternBindingDecls, other kinds of decls

lib/Sema/CSGen.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,9 @@ void TypeVarRefCollector::inferTypeVars(Decl *D) {
882882

883883
ASTWalker::PreWalkResult<Expr *>
884884
TypeVarRefCollector::walkToExprPre(Expr *expr) {
885+
if (isa<ClosureExpr>(expr))
886+
DCDepth += 1;
887+
885888
if (auto *DRE = dyn_cast<DeclRefExpr>(expr))
886889
inferTypeVars(DRE->getDecl());
887890

@@ -901,6 +904,34 @@ TypeVarRefCollector::walkToExprPre(Expr *expr) {
901904
return Action::Continue(expr);
902905
}
903906

907+
ASTWalker::PostWalkResult<Expr *>
908+
TypeVarRefCollector::walkToExprPost(Expr *expr) {
909+
if (isa<ClosureExpr>(expr))
910+
DCDepth -= 1;
911+
912+
return Action::Continue(expr);
913+
}
914+
915+
ASTWalker::PreWalkResult<Stmt *>
916+
TypeVarRefCollector::walkToStmtPre(Stmt *stmt) {
917+
// If we have a return without any intermediate DeclContexts in a ClosureExpr,
918+
// we need to include any type variables in the closure's result type, since
919+
// the conjunction will generate constraints using that type. We don't need to
920+
// connect to returns in e.g nested closures since we'll connect those when we
921+
// generate constraints for those closures. We also don't need to bother if
922+
// we're generating constraints for the closure itself, since we'll connect
923+
// the conjunction to the closure type variable itself.
924+
if (auto *CE = dyn_cast<ClosureExpr>(DC)) {
925+
if (isa<ReturnStmt>(stmt) && DCDepth == 0 &&
926+
!Locator->directlyAt<ClosureExpr>()) {
927+
SmallPtrSet<TypeVariableType *, 4> typeVars;
928+
CS.getClosureType(CE)->getResult()->getTypeVariables(typeVars);
929+
TypeVars.insert(typeVars.begin(), typeVars.end());
930+
}
931+
}
932+
return Action::Continue(stmt);
933+
}
934+
904935
namespace {
905936
class ConstraintGenerator : public ExprVisitor<ConstraintGenerator, Type> {
906937
ConstraintSystem &CS;
@@ -1304,7 +1335,8 @@ namespace {
13041335
// in the tap body, otherwise tap expression is going
13051336
// to get disconnected from the context.
13061337
if (auto *body = tap->getBody()) {
1307-
TypeVarRefCollector refCollector(CS);
1338+
TypeVarRefCollector refCollector(
1339+
CS, tap->getVar()->getDeclContext(), locator);
13081340

13091341
body->walk(refCollector);
13101342

@@ -2938,7 +2970,7 @@ namespace {
29382970
auto *locator = CS.getConstraintLocator(closure);
29392971
auto closureType = CS.createTypeVariable(locator, TVO_CanBindToNoEscape);
29402972

2941-
TypeVarRefCollector refCollector(CS);
2973+
TypeVarRefCollector refCollector(CS, /*DC*/ closure, locator);
29422974
// Walk the capture list if this closure has one, because it could
29432975
// reference declarations from the outer closure.
29442976
if (auto *captureList =

lib/Sema/CSSyntacticElement.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,10 @@ static bool isViableElement(ASTNode element,
281281
using ElementInfo = std::tuple<ASTNode, ContextualTypeInfo,
282282
/*isDiscarded=*/bool, ConstraintLocator *>;
283283

284-
static void createConjunction(ConstraintSystem &cs,
284+
static void createConjunction(ConstraintSystem &cs, DeclContext *dc,
285285
ArrayRef<ElementInfo> elements,
286-
ConstraintLocator *locator,
287-
bool isIsolated = false,
288-
ArrayRef<TypeVariableType *> extraTypeVars = {}) {
286+
ConstraintLocator *locator, bool isIsolated,
287+
ArrayRef<TypeVariableType *> extraTypeVars) {
289288
SmallVector<Constraint *, 4> constraints;
290289
SmallVector<TypeVariableType *, 2> referencedVars;
291290
referencedVars.append(extraTypeVars.begin(), extraTypeVars.end());
@@ -335,7 +334,7 @@ static void createConjunction(ConstraintSystem &cs,
335334
isIsolated = true;
336335
}
337336

338-
TypeVarRefCollector paramCollector(cs);
337+
TypeVarRefCollector paramCollector(cs, dc, locator);
339338

340339
// Whether we're doing completion, and the conjunction is for a
341340
// SingleValueStmtExpr, or one of its braces.
@@ -520,7 +519,8 @@ class SyntacticElementConstraintGenerator
520519
void createConjunction(ArrayRef<ElementInfo> elements,
521520
ConstraintLocator *locator, bool isIsolated = false,
522521
ArrayRef<TypeVariableType *> extraTypeVars = {}) {
523-
::createConjunction(cs, elements, locator, isIsolated, extraTypeVars);
522+
::createConjunction(cs, context.getAsDeclContext(), elements, locator,
523+
isIsolated, extraTypeVars);
524524
}
525525

526526
void visitExprPattern(ExprPattern *EP) {
@@ -1520,6 +1520,9 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15201520

15211521
void ConstraintSystem::generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
15221522
ConstraintLocatorBuilder locator) {
1523+
assert(!exprPatterns.empty());
1524+
auto *DC = exprPatterns.front()->getDeclContext();
1525+
15231526
// Form a conjunction of ExprPattern elements, isolated from the rest of the
15241527
// pattern.
15251528
SmallVector<ElementInfo> elements;
@@ -1532,7 +1535,7 @@ void ConstraintSystem::generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
15321535
elements.push_back(makeElement(EP, getConstraintLocator(EP), context));
15331536
}
15341537
auto *loc = getConstraintLocator(locator);
1535-
createConjunction(*this, elements, loc, /*isIsolated*/ true,
1538+
createConjunction(*this, DC, elements, loc, /*isIsolated*/ true,
15361539
referencedTypeVars);
15371540
}
15381541

test/Constraints/rdar114402042.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
// rdar://114402042 - Make sure we connect the SingleValueStmtExpr to the outer
4+
// closure's return type.
5+
func foo<T>(_: () -> T) {}
6+
func bar<T>(_ x: T) {}
7+
func test() {
8+
foo {
9+
bar(if true { return } else { return })
10+
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
11+
// expected-error@-2 2{{cannot 'return' in 'if' when used as expression}}
12+
}
13+
foo {
14+
bar(if true { { return } } else { { return } })
15+
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
16+
}
17+
}
18+
19+
func baz() -> String? {
20+
nil
21+
}
22+
23+
var x: Int? = {
24+
print( // expected-error {{cannot convert value of type '()' to closure result type 'Int?'}}
25+
// expected-note@-1 {{to match this opening '('}}
26+
switch baz() {
27+
case ""?:
28+
return nil
29+
default:
30+
return nil
31+
}
32+
}() // expected-error {{expected ')' in expression list}}

0 commit comments

Comments
 (0)