Skip to content

Commit 09fd729

Browse files
committed
[Statement checker] Factor out 'fallthrough' checking a bit more.
1 parent 57d5c4d commit 09fd729

File tree

1 file changed

+87
-74
lines changed

1 file changed

+87
-74
lines changed

lib/Sema/TypeCheckStmt.cpp

Lines changed: 87 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,92 @@ static bool typeCheckConditionForStatement(LabeledConditionalStmt *stmt,
560560
return false;
561561
}
562562

563+
/// Verify that the pattern bindings for the cases that we're falling through
564+
/// from and to are equivalent.
565+
static void checkFallthroughPatternBindingsAndTypes(
566+
ASTContext &ctx,
567+
CaseStmt *caseBlock, CaseStmt *previousBlock,
568+
FallthroughStmt *fallthrough) {
569+
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
570+
SmallVector<VarDecl *, 4> vars;
571+
firstPattern->collectVariables(vars);
572+
573+
// We know that the typechecker has already guaranteed that all of
574+
// the case label items in the fallthrough have the same var
575+
// decls. So if we match against the case body var decls,
576+
// transitively we will match all of the other case label items in
577+
// the fallthrough destination as well.
578+
auto previousVars = previousBlock->getCaseBodyVariablesOrEmptyArray();
579+
for (auto *expected : vars) {
580+
bool matched = false;
581+
if (!expected->hasName())
582+
continue;
583+
584+
for (auto *previous : previousVars) {
585+
if (!previous->hasName() ||
586+
expected->getName() != previous->getName()) {
587+
continue;
588+
}
589+
590+
if (!previous->getType()->isEqual(expected->getType())) {
591+
ctx.Diags.diagnose(
592+
previous->getLoc(), diag::type_mismatch_fallthrough_pattern_list,
593+
previous->getType(), expected->getType());
594+
previous->setInvalid();
595+
expected->setInvalid();
596+
}
597+
598+
// Ok, we found our match. Make the previous fallthrough statement var
599+
// decl our parent var decl.
600+
expected->setParentVarDecl(previous);
601+
matched = true;
602+
break;
603+
}
604+
605+
if (!matched) {
606+
ctx.Diags.diagnose(
607+
fallthrough->getLoc(), diag::fallthrough_into_case_with_var_binding,
608+
expected->getName());
609+
}
610+
}
611+
}
612+
613+
/// Check the correctness of a 'fallthrough' statement.
614+
///
615+
/// \returns true if an error occurred.
616+
static bool checkFallthroughStmt(
617+
DeclContext *dc, FallthroughStmt *stmt,
618+
CaseStmt *oldFallthroughSource, CaseStmt *oldFallthroughDest) {
619+
CaseStmt *fallthroughSource;
620+
CaseStmt *fallthroughDest;
621+
ASTContext &ctx = dc->getASTContext();
622+
if (ctx.LangOpts.EnableASTScopeLookup) {
623+
auto sourceFile = dc->getParentSourceFile();
624+
std::tie(fallthroughSource, fallthroughDest) =
625+
ASTScope::lookupFallthroughSourceAndDest(sourceFile, stmt->getLoc());
626+
assert(fallthroughSource == oldFallthroughSource);
627+
assert(fallthroughDest == oldFallthroughDest);
628+
} else {
629+
fallthroughSource = oldFallthroughSource;
630+
fallthroughDest = oldFallthroughDest;
631+
}
632+
633+
if (!fallthroughSource) {
634+
ctx.Diags.diagnose(stmt->getLoc(), diag::fallthrough_outside_switch);
635+
return true;
636+
}
637+
if (!fallthroughDest) {
638+
ctx.Diags.diagnose(stmt->getLoc(), diag::fallthrough_from_last_case);
639+
return true;
640+
}
641+
stmt->setFallthroughSource(fallthroughSource);
642+
stmt->setFallthroughDest(fallthroughDest);
643+
644+
checkFallthroughPatternBindingsAndTypes(
645+
ctx, fallthroughDest, fallthroughSource, stmt);
646+
return false;
647+
}
648+
563649
namespace {
564650
class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
565651
public:
@@ -992,34 +1078,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
9921078
}
9931079

9941080
Stmt *visitFallthroughStmt(FallthroughStmt *S) {
995-
CaseStmt *fallthroughSource;
996-
CaseStmt *fallthroughDest;
997-
if (getASTContext().LangOpts.EnableASTScopeLookup) {
998-
auto sourceFile = DC->getParentSourceFile();
999-
std::tie(fallthroughSource, fallthroughDest) =
1000-
ASTScope::lookupFallthroughSourceAndDest(sourceFile, S->getLoc());
1001-
assert(fallthroughSource == FallthroughSource);
1002-
assert(fallthroughDest == FallthroughDest);
1003-
} else {
1004-
fallthroughSource = FallthroughSource;
1005-
fallthroughDest = FallthroughDest;
1006-
}
1007-
1008-
if (!fallthroughSource) {
1009-
getASTContext().Diags.diagnose(S->getLoc(),
1010-
diag::fallthrough_outside_switch);
1081+
if (checkFallthroughStmt(DC, S, FallthroughSource, FallthroughDest))
10111082
return nullptr;
1012-
}
1013-
if (!fallthroughDest) {
1014-
getASTContext().Diags.diagnose(S->getLoc(),
1015-
diag::fallthrough_from_last_case);
1016-
return nullptr;
1017-
}
1018-
S->setFallthroughSource(fallthroughSource);
1019-
S->setFallthroughDest(fallthroughDest);
1020-
1021-
checkFallthroughPatternBindingsAndTypes(
1022-
fallthroughDest, fallthroughSource, S);
10231083

10241084
return S;
10251085
}
@@ -1166,53 +1226,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
11661226
std::swap(*prevCaseDecls, *nextCaseDecls);
11671227
}
11681228

1169-
void checkFallthroughPatternBindingsAndTypes(CaseStmt *caseBlock,
1170-
CaseStmt *previousBlock,
1171-
FallthroughStmt *fallthrough) {
1172-
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
1173-
SmallVector<VarDecl *, 4> vars;
1174-
firstPattern->collectVariables(vars);
1175-
1176-
// We know that the typechecker has already guaranteed that all of
1177-
// the case label items in the fallthrough have the same var
1178-
// decls. So if we match against the case body var decls,
1179-
// transitively we will match all of the other case label items in
1180-
// the fallthrough destination as well.
1181-
auto previousVars = previousBlock->getCaseBodyVariablesOrEmptyArray();
1182-
for (auto *expected : vars) {
1183-
bool matched = false;
1184-
if (!expected->hasName())
1185-
continue;
1186-
1187-
for (auto *previous : previousVars) {
1188-
if (!previous->hasName() ||
1189-
expected->getName() != previous->getName()) {
1190-
continue;
1191-
}
1192-
1193-
if (!previous->getType()->isEqual(expected->getType())) {
1194-
getASTContext().Diags.diagnose(previous->getLoc(),
1195-
diag::type_mismatch_fallthrough_pattern_list,
1196-
previous->getType(), expected->getType());
1197-
previous->setInvalid();
1198-
expected->setInvalid();
1199-
}
1200-
1201-
// Ok, we found our match. Make the previous fallthrough statement var
1202-
// decl our parent var decl.
1203-
expected->setParentVarDecl(previous);
1204-
matched = true;
1205-
break;
1206-
}
1207-
1208-
if (!matched) {
1209-
getASTContext().Diags.diagnose(fallthrough->getLoc(),
1210-
diag::fallthrough_into_case_with_var_binding,
1211-
expected->getName());
1212-
}
1213-
}
1214-
}
1215-
12161229
template <typename Iterator>
12171230
void checkSiblingCaseStmts(Iterator casesBegin, Iterator casesEnd,
12181231
CaseParentKind parentKind,

0 commit comments

Comments
 (0)