@@ -446,6 +446,103 @@ static LabeledStmt *findBreakOrContinueStmtTarget(
446
446
return nullptr ;
447
447
}
448
448
449
+ // / Type check the given 'if', 'while', or 'guard' statement condition.
450
+ // /
451
+ // / \param stmt The conditional statement to type-check, which will be modified
452
+ // / in place.
453
+ // /
454
+ // / \returns true if an error occurred, false otherwise.
455
+ static bool typeCheckConditionForStatement (LabeledConditionalStmt *stmt,
456
+ DeclContext *dc) {
457
+ auto &Context = dc->getASTContext ();
458
+ bool hadError = false ;
459
+ bool hadAnyFalsable = false ;
460
+ auto cond = stmt->getCond ();
461
+ for (auto &elt : cond) {
462
+ if (elt.getKind () == StmtConditionElement::CK_Availability) {
463
+ hadAnyFalsable = true ;
464
+ continue ;
465
+ }
466
+
467
+ if (auto E = elt.getBooleanOrNull ()) {
468
+ assert (!E->getType () && " the bool condition is already type checked" );
469
+ hadError |= TypeChecker::typeCheckCondition (E, dc);
470
+ elt.setBoolean (E);
471
+ hadAnyFalsable = true ;
472
+ continue ;
473
+ }
474
+ assert (elt.getKind () != StmtConditionElement::CK_Boolean);
475
+
476
+ // This is cleanup goop run on the various paths where type checking of the
477
+ // pattern binding fails.
478
+ auto typeCheckPatternFailed = [&] {
479
+ hadError = true ;
480
+ elt.getPattern ()->setType (ErrorType::get (Context));
481
+ elt.getInitializer ()->setType (ErrorType::get (Context));
482
+
483
+ elt.getPattern ()->forEachVariable ([&](VarDecl *var) {
484
+ // Don't change the type of a variable that we've been able to
485
+ // compute a type for.
486
+ if (var->hasInterfaceType () && !var->isInvalid ())
487
+ return ;
488
+ var->setInvalid ();
489
+ });
490
+ };
491
+
492
+ // Resolve the pattern.
493
+ assert (!elt.getPattern ()->hasType () &&
494
+ " the pattern binding condition is already type checked" );
495
+ auto *pattern = TypeChecker::resolvePattern (elt.getPattern (), dc,
496
+ /* isStmtCondition*/ true );
497
+ if (!pattern) {
498
+ typeCheckPatternFailed ();
499
+ continue ;
500
+ }
501
+ elt.setPattern (pattern);
502
+
503
+ // Check the pattern, it allows unspecified types because the pattern can
504
+ // provide type information.
505
+ auto contextualPattern = ContextualPattern::forRawPattern (pattern, dc);
506
+ Type patternType = TypeChecker::typeCheckPattern (contextualPattern);
507
+ if (patternType->hasError ()) {
508
+ typeCheckPatternFailed ();
509
+ continue ;
510
+ }
511
+
512
+ // If the pattern didn't get a type, it's because we ran into some
513
+ // unknown types along the way. We'll need to check the initializer.
514
+ auto init = elt.getInitializer ();
515
+ hadError |= TypeChecker::typeCheckBinding (pattern, init, dc, patternType);
516
+ elt.setPattern (pattern);
517
+ elt.setInitializer (init);
518
+ hadAnyFalsable |= pattern->isRefutablePattern ();
519
+ }
520
+
521
+ // If the binding is not refutable, and there *is* an else, reject it as
522
+ // unreachable.
523
+ if (!hadAnyFalsable && !hadError) {
524
+ auto &diags = dc->getASTContext ().Diags ;
525
+ Diag<> msg = diag::invalid_diagnostic;
526
+ switch (stmt->getKind ()) {
527
+ case StmtKind::If:
528
+ msg = diag::if_always_true;
529
+ break ;
530
+ case StmtKind::While:
531
+ msg = diag::while_always_true;
532
+ break ;
533
+ case StmtKind::Guard:
534
+ msg = diag::guard_always_succeeds;
535
+ break ;
536
+ default :
537
+ llvm_unreachable (" unknown LabeledConditionalStmt kind" );
538
+ }
539
+ diags.diagnose (cond[0 ].getStartLoc (), msg);
540
+ }
541
+
542
+ stmt->setCond (cond);
543
+ return false ;
544
+ }
545
+
449
546
namespace {
450
547
class StmtChecker : public StmtVisitor <StmtChecker, Stmt*> {
451
548
public:
@@ -785,7 +882,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
785
882
}
786
883
787
884
Stmt *visitIfStmt (IfStmt *IS) {
788
- TypeChecker:: typeCheckConditionForStatement (IS, DC);
885
+ typeCheckConditionForStatement (IS, DC);
789
886
790
887
AddLabeledStmt ifNest (*this , IS);
791
888
@@ -802,7 +899,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
802
899
}
803
900
804
901
Stmt *visitGuardStmt (GuardStmt *GS) {
805
- TypeChecker:: typeCheckConditionForStatement (GS, DC);
902
+ typeCheckConditionForStatement (GS, DC);
806
903
807
904
Stmt *S = GS->getBody ();
808
905
typeCheckStmt (S);
@@ -819,7 +916,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
819
916
}
820
917
821
918
Stmt *visitWhileStmt (WhileStmt *WS) {
822
- TypeChecker:: typeCheckConditionForStatement (WS, DC);
919
+ typeCheckConditionForStatement (WS, DC);
823
920
824
921
AddLabeledStmt loopNest (*this , WS);
825
922
Stmt *S = WS->getBody ();
0 commit comments