Skip to content

Commit a22fd6f

Browse files
committed
[TypeChecker] Typecheck statement condition on demand
from NamingPatternRequest.
1 parent 3e200b7 commit a22fd6f

File tree

5 files changed

+68
-22
lines changed

5 files changed

+68
-22
lines changed

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,29 @@ bool TypeChecker::typeCheckStmtCondition(StmtCondition &cond, DeclContext *dc,
25642564
return false;
25652565
}
25662566

2567+
bool TypeChecker::typeCheckConditionForStatement(LabeledConditionalStmt *stmt,
2568+
DeclContext *dc) {
2569+
Diag<> diagnosticForAlwaysTrue = diag::invalid_diagnostic;
2570+
switch (stmt->getKind()) {
2571+
case StmtKind::If:
2572+
diagnosticForAlwaysTrue = diag::if_always_true;
2573+
break;
2574+
case StmtKind::While:
2575+
diagnosticForAlwaysTrue = diag::while_always_true;
2576+
break;
2577+
case StmtKind::Guard:
2578+
diagnosticForAlwaysTrue = diag::guard_always_succeeds;
2579+
break;
2580+
default:
2581+
llvm_unreachable("unknown LabeledConditionalStmt kind");
2582+
}
2583+
2584+
StmtCondition cond = stmt->getCond();
2585+
bool result = typeCheckStmtCondition(cond, dc, diagnosticForAlwaysTrue);
2586+
stmt->setCond(cond);
2587+
return result;
2588+
}
2589+
25672590
/// Find the '~=` operator that can compare an expression inside a pattern to a
25682591
/// value of a given type.
25692592
bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,

lib/Sema/TypeCheckDecl.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2360,23 +2360,33 @@ NamingPatternRequest::evaluate(Evaluator &evaluator, VarDecl *VD) const {
23602360
if (!namingPattern) {
23612361
auto *canVD = VD->getCanonicalVarDecl();
23622362
namingPattern = canVD->NamingPattern;
2363+
}
2364+
2365+
if (!namingPattern) {
2366+
// Try type checking parent conditional statement.
2367+
if (auto parentStmt = VD->getParentPatternStmt()) {
2368+
if (auto LCS = dyn_cast<LabeledConditionalStmt>(parentStmt)) {
2369+
TypeChecker::typeCheckConditionForStatement(LCS, VD->getDeclContext());
2370+
namingPattern = VD->NamingPattern;
2371+
}
2372+
}
2373+
}
23632374

2375+
if (!namingPattern) {
23642376
// HACK: If no other diagnostic applies, emit a generic diagnostic about
23652377
// a variable being unbound. We can't do better than this at the
23662378
// moment because TypeCheckPattern does not reliably invalidate parts of
23672379
// the pattern AST on failure.
23682380
//
23692381
// Once that's through, this will only fire during circular validation.
2370-
if (!namingPattern) {
2371-
if (VD->hasInterfaceType() &&
2372-
!VD->isInvalid() && !VD->getParentPattern()->isImplicit()) {
2373-
VD->diagnose(diag::variable_bound_by_no_pattern, VD->getName());
2374-
}
2375-
2376-
VD->getParentPattern()->setType(ErrorType::get(Context));
2377-
setBoundVarsTypeError(VD->getParentPattern(), Context);
2378-
return nullptr;
2382+
if (VD->hasInterfaceType() &&
2383+
!VD->isInvalid() && !VD->getParentPattern()->isImplicit()) {
2384+
VD->diagnose(diag::variable_bound_by_no_pattern, VD->getName());
23792385
}
2386+
2387+
VD->getParentPattern()->setType(ErrorType::get(Context));
2388+
setBoundVarsTypeError(VD->getParentPattern(), Context);
2389+
return nullptr;
23802390
}
23812391

23822392
if (!namingPattern->hasType()) {

lib/Sema/TypeCheckStmt.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
625625
}
626626

627627
Stmt *visitIfStmt(IfStmt *IS) {
628-
StmtCondition C = IS->getCond();
629-
TypeChecker::typeCheckStmtCondition(C, DC, diag::if_always_true);
630-
IS->setCond(C);
628+
TypeChecker::typeCheckConditionForStatement(IS, DC);
631629

632630
AddLabeledStmt ifNest(*this, IS);
633631

@@ -644,10 +642,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
644642
}
645643

646644
Stmt *visitGuardStmt(GuardStmt *GS) {
647-
StmtCondition C = GS->getCond();
648-
TypeChecker::typeCheckStmtCondition(C, DC, diag::guard_always_succeeds);
649-
GS->setCond(C);
650-
645+
TypeChecker::typeCheckConditionForStatement(GS, DC);
646+
651647
AddLabeledStmt ifNest(*this, GS);
652648

653649
Stmt *S = GS->getBody();
@@ -665,9 +661,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
665661
}
666662

667663
Stmt *visitWhileStmt(WhileStmt *WS) {
668-
StmtCondition C = WS->getCond();
669-
TypeChecker::typeCheckStmtCondition(C, DC, diag::while_always_true);
670-
WS->setCond(C);
664+
TypeChecker::typeCheckConditionForStatement(WS, DC);
671665

672666
AddLabeledStmt loopNest(*this, WS);
673667
Stmt *S = WS->getBody();

lib/Sema/TypeChecker.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,16 @@ void checkSwitchExhaustiveness(const SwitchStmt *stmt, const DeclContext *DC,
756756
/// \returns true if an error occurred, false otherwise.
757757
bool typeCheckCondition(Expr *&expr, DeclContext *dc);
758758

759-
/// Type check the given 'if' or 'while' statement condition, which
759+
/// Type check the given 'if', 'while', or 'guard' statement condition.
760+
///
761+
/// \param stmt The conditional statement to type-check, which will be modified
762+
/// in place.
763+
///
764+
/// \returns true if an error occurred, false otherwise.
765+
bool typeCheckConditionForStatement(LabeledConditionalStmt *stmt,
766+
DeclContext *dc);
767+
768+
/// Type check the given 'if', 'while', or 'guard' statement condition, which
760769
/// either converts an expression to a logic value or bind variables to the
761770
/// contents of an Optional.
762771
///

test/IDE/complete_skipbody.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,25 @@ struct MyStruct {
88
var y: Int { 1 }
99
}
1010

11-
func test(value: MyStruct) {
11+
func test(valueOptOpt: MyStruct??) {
1212

1313
let FORBIDDEN_localVar = 1
1414
let unrelated = FORBIDDEN_Struct()
15+
16+
let valueOpt = valueOptOpt!
17+
1518
guard let a = unrelated.FORBIDDEN_method() else {
1619
return
1720
}
1821

19-
_ = value.#^COMPLETE^#
22+
guard let value = valueOpt else {
23+
return
24+
}
25+
26+
if (value.x == 1) {
27+
let unrelated2 = FORBIDDEN_Struct()
28+
_ = value.#^COMPLETE^#
29+
}
2030
}
2131

2232
// CHECK: Begin completions, 3 items

0 commit comments

Comments
 (0)