Skip to content

Commit e780c19

Browse files
authored
Merge pull request swiftlang#30230 from DougGregor/function-builder-switch-checking
[Constraint system] Semantic checking for switch statements in function builders
2 parents 34353b8 + c0cf407 commit e780c19

File tree

5 files changed

+123
-42
lines changed

5 files changed

+123
-42
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,7 @@ class BuilderClosureRewriter
11141114
}
11151115

11161116
// Translate all of the cases.
1117+
bool limitExhaustivityChecks = false;
11171118
assert(target.kind == FunctionBuilderTarget::TemporaryVar);
11181119
auto temporaryVar = target.captured.first;
11191120
unsigned caseIndex = 0;
@@ -1124,10 +1125,18 @@ class BuilderClosureRewriter
11241125
temporaryVar, {target.captured.second[caseIndex]})))
11251126
return nullptr;
11261127

1128+
// Check restrictions on '@unknown'.
1129+
if (caseStmt->hasUnknownAttr()) {
1130+
checkUnknownAttrRestrictions(
1131+
cs.getASTContext(), caseStmt, /*fallthroughDest=*/nullptr,
1132+
limitExhaustivityChecks);
1133+
}
1134+
11271135
++caseIndex;
11281136
}
11291137

1130-
TypeChecker::checkSwitchExhaustiveness(switchStmt, dc, /*limited=*/false);
1138+
TypeChecker::checkSwitchExhaustiveness(
1139+
switchStmt, dc, limitExhaustivityChecks);
11311140

11321141
return switchStmt;
11331142
}

lib/Sema/TypeCheckStmt.cpp

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,40 +1035,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10351035
std::swap(*prevCaseDecls, *nextCaseDecls);
10361036
}
10371037

1038-
void checkUnknownAttrRestrictions(CaseStmt *caseBlock,
1039-
bool &limitExhaustivityChecks) {
1040-
if (caseBlock->getCaseLabelItems().size() != 1) {
1041-
assert(!caseBlock->getCaseLabelItems().empty() &&
1042-
"parser should not produce case blocks with no items");
1043-
getASTContext().Diags.diagnose(caseBlock->getLoc(),
1044-
diag::unknown_case_multiple_patterns)
1045-
.highlight(caseBlock->getCaseLabelItems()[1].getSourceRange());
1046-
limitExhaustivityChecks = true;
1047-
}
1048-
1049-
if (FallthroughDest != nullptr) {
1050-
if (!caseBlock->isDefault())
1051-
getASTContext().Diags.diagnose(caseBlock->getLoc(),
1052-
diag::unknown_case_must_be_last);
1053-
limitExhaustivityChecks = true;
1054-
}
1055-
1056-
const auto &labelItem = caseBlock->getCaseLabelItems().front();
1057-
if (labelItem.getGuardExpr() && !labelItem.isDefault()) {
1058-
getASTContext().Diags.diagnose(labelItem.getStartLoc(),
1059-
diag::unknown_case_where_clause)
1060-
.highlight(labelItem.getGuardExpr()->getSourceRange());
1061-
}
1062-
1063-
const Pattern *pattern =
1064-
labelItem.getPattern()->getSemanticsProvidingPattern();
1065-
if (!isa<AnyPattern>(pattern)) {
1066-
getASTContext().Diags.diagnose(labelItem.getStartLoc(),
1067-
diag::unknown_case_must_be_catchall)
1068-
.highlight(pattern->getSourceRange());
1069-
}
1070-
}
1071-
10721038
void checkFallthroughPatternBindingsAndTypes(CaseStmt *caseBlock,
10731039
CaseStmt *previousBlock) {
10741040
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
@@ -1236,7 +1202,9 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
12361202

12371203
// Check restrictions on '@unknown'.
12381204
if (caseBlock->hasUnknownAttr()) {
1239-
checkUnknownAttrRestrictions(caseBlock, limitExhaustivityChecks);
1205+
checkUnknownAttrRestrictions(
1206+
getASTContext(), caseBlock, FallthroughDest,
1207+
limitExhaustivityChecks);
12401208
}
12411209

12421210
// If the previous case fellthrough, similarly check that that case's
@@ -2067,23 +2035,81 @@ void TypeChecker::typeCheckTopLevelCodeDecl(TopLevelCodeDecl *TLCD) {
20672035
performTopLevelDeclDiagnostics(TLCD);
20682036
}
20692037

2038+
void swift::checkUnknownAttrRestrictions(
2039+
ASTContext &ctx, CaseStmt *caseBlock, CaseStmt *fallthroughDest,
2040+
bool &limitExhaustivityChecks) {
2041+
if (caseBlock->getCaseLabelItems().size() != 1) {
2042+
assert(!caseBlock->getCaseLabelItems().empty() &&
2043+
"parser should not produce case blocks with no items");
2044+
ctx.Diags.diagnose(caseBlock->getLoc(),
2045+
diag::unknown_case_multiple_patterns)
2046+
.highlight(caseBlock->getCaseLabelItems()[1].getSourceRange());
2047+
limitExhaustivityChecks = true;
2048+
}
2049+
2050+
if (fallthroughDest != nullptr) {
2051+
if (!caseBlock->isDefault())
2052+
ctx.Diags.diagnose(caseBlock->getLoc(),
2053+
diag::unknown_case_must_be_last);
2054+
limitExhaustivityChecks = true;
2055+
}
2056+
2057+
const auto &labelItem = caseBlock->getCaseLabelItems().front();
2058+
if (labelItem.getGuardExpr() && !labelItem.isDefault()) {
2059+
ctx.Diags.diagnose(labelItem.getStartLoc(),
2060+
diag::unknown_case_where_clause)
2061+
.highlight(labelItem.getGuardExpr()->getSourceRange());
2062+
}
2063+
2064+
const Pattern *pattern =
2065+
labelItem.getPattern()->getSemanticsProvidingPattern();
2066+
if (!isa<AnyPattern>(pattern)) {
2067+
ctx.Diags.diagnose(labelItem.getStartLoc(),
2068+
diag::unknown_case_must_be_catchall)
2069+
.highlight(pattern->getSourceRange());
2070+
}
2071+
}
2072+
20702073
void swift::bindSwitchCasePatternVars(CaseStmt *caseStmt) {
2071-
llvm::SmallDenseMap<Identifier, VarDecl *, 4> latestVars;
2074+
llvm::SmallDenseMap<Identifier, std::pair<VarDecl *, bool>, 4> latestVars;
20722075
auto recordVar = [&](VarDecl *var) {
20732076
if (!var->hasName())
20742077
return;
20752078

20762079
// If there is an existing variable with this name, set it as the
20772080
// parent of this new variable.
20782081
auto &entry = latestVars[var->getName()];
2079-
if (entry) {
2082+
if (entry.first) {
20802083
assert(!var->getParentVarDecl() ||
2081-
var->getParentVarDecl() == entry);
2082-
var->setParentVarDecl(entry);
2084+
var->getParentVarDecl() == entry.first);
2085+
var->setParentVarDecl(entry.first);
2086+
2087+
// Check for a mutability mismatch.
2088+
if (entry.second != var->isLet()) {
2089+
// Find the original declaration.
2090+
auto initialCaseVarDecl = entry.first;
2091+
while (auto parentVar = initialCaseVarDecl->getParentVarDecl())
2092+
initialCaseVarDecl = parentVar;
2093+
2094+
auto diag = var->diagnose(diag::mutability_mismatch_multiple_pattern_list,
2095+
var->isLet(), initialCaseVarDecl->isLet());
2096+
2097+
VarPattern *foundVP = nullptr;
2098+
var->getParentPattern()->forEachNode([&](Pattern *P) {
2099+
if (auto *VP = dyn_cast<VarPattern>(P))
2100+
if (VP->getSingleVar() == var)
2101+
foundVP = VP;
2102+
});
2103+
if (foundVP)
2104+
diag.fixItReplace(foundVP->getLoc(),
2105+
initialCaseVarDecl->isLet() ? "let" : "var");
2106+
}
2107+
} else {
2108+
entry.second = var->isLet();
20832109
}
20842110

20852111
// Record this variable as the latest with this name.
2086-
entry = var;
2112+
entry.first = var;
20872113
};
20882114

20892115
// Wire up the parent var decls for each variable that occurs within

lib/Sema/TypeChecker.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,12 @@ bool areGenericRequirementsSatisfied(const DeclContext *DC,
16541654
SubstitutionMap Substitutions,
16551655
bool isExtension);
16561656

1657+
/// Check for restrictions on the use of the @unknown attribute on a
1658+
/// case statement.
1659+
void checkUnknownAttrRestrictions(
1660+
ASTContext &ctx, CaseStmt *caseBlock, CaseStmt *fallthroughDest,
1661+
bool &limitExhaustivityChecks);
1662+
16571663
/// Bind all of the pattern variables that occur within a case statement and
16581664
/// all of its case items to their "parent" pattern variables, forming chains
16591665
/// of variables with the same name.

test/Constraints/function_builder.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,8 @@ testSwitch(getE(1))
593593
// CHECK-SAME: second("just 42")
594594
testSwitch(getE(2))
595595

596-
func testSwitchCombined(_ e: E) {
596+
func testSwitchCombined(_ eIn: E) {
597+
var e = eIn
597598
tuplify(true) { c in
598599
"testSwitchCombined"
599600
switch e {

test/Constraints/function_builder_diags.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,42 @@ struct TestConstraintGenerationErrors {
427427
}
428428
}
429429
}
430+
431+
// Check @unknown
432+
func testUnknownInSwitchSwitch(e: E) {
433+
tuplify(true) { c in
434+
"testSwitch"
435+
switch e {
436+
@unknown case .a: // expected-error{{'@unknown' is only supported for catch-all cases ("case _")}}
437+
"a"
438+
case .b(let i, let s?):
439+
i * 2
440+
s + "!"
441+
default:
442+
"nothing"
443+
}
444+
}
445+
}
446+
447+
// Check for mutability mismatches when there are multiple case items
448+
// referring to same-named variables.
449+
enum E3 {
450+
case a(Int, String)
451+
case b(String, Int)
452+
case c(String, Int)
453+
}
454+
455+
func testCaseMutabilityMismatches(e: E3) {
456+
tuplify(true) { c in
457+
"testSwitch"
458+
switch e {
459+
case .a(let x, var y),
460+
.b(let y, // expected-error{{'let' pattern binding must match previous 'var' pattern binding}}
461+
var x), // expected-error{{'var' pattern binding must match previous 'let' pattern binding}}
462+
.c(let y, // expected-error{{'let' pattern binding must match previous 'var' pattern binding}}
463+
var x): // expected-error{{'var' pattern binding must match previous 'let' pattern binding}}
464+
x
465+
y += "a"
466+
}
467+
}
468+
}

0 commit comments

Comments
 (0)