Skip to content

Commit 0684c62

Browse files
authored
Merge pull request swiftlang#33278 from DougGregor/statement-checker-cleanups
2 parents cc5f3ae + f787e1b commit 0684c62

File tree

7 files changed

+130
-90
lines changed

7 files changed

+130
-90
lines changed

include/swift/AST/Stmt.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,10 @@ class CaseStmt final
10731073
return *CaseBodyVariables;
10741074
}
10751075

1076+
/// Find the next case statement within the same 'switch' or 'do-catch',
1077+
/// if there is one.
1078+
CaseStmt *findNextCaseStmt() const;
1079+
10761080
static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Case; }
10771081

10781082
size_t numTrailingObjects(OverloadToken<CaseLabelItem>) const {

lib/AST/ASTScopeLookup.cpp

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -914,9 +914,7 @@ std::pair<CaseStmt *, CaseStmt *> ASTScopeImpl::lookupFallthroughSourceAndDest(
914914
ASTScopeAssert(innermost->getWasExpanded(),
915915
"If looking in a scope, it must have been expanded.");
916916

917-
// Look for the enclosing case statement and its 'switch' statement.
918-
CaseStmt *fallthroughSource = nullptr;
919-
SwitchStmt *switchStmt = nullptr;
917+
// Look for the enclosing case statement of a 'switch'.
920918
for (auto scope = innermost; scope && !scope->isLabeledStmtLookupTerminator();
921919
scope = scope->getParent().getPtrOrNull()) {
922920
// If we have a case statement, record it.
@@ -926,34 +924,12 @@ std::pair<CaseStmt *, CaseStmt *> ASTScopeImpl::lookupFallthroughSourceAndDest(
926924
// If we've found the first case statement of a switch, record it as the
927925
// fallthrough source. do-catch statements don't support fallthrough.
928926
if (auto caseStmt = dyn_cast<CaseStmt>(stmt.get())) {
929-
if (!fallthroughSource &&
930-
caseStmt->getParentKind() == CaseParentKind::Switch)
931-
fallthroughSource = caseStmt;
927+
if (caseStmt->getParentKind() == CaseParentKind::Switch)
928+
return { caseStmt, caseStmt->findNextCaseStmt() };
932929

933930
continue;
934931
}
935-
936-
// If we've found the first switch statement, record it and we're done.
937-
switchStmt = dyn_cast<SwitchStmt>(stmt.get());
938-
if (switchStmt)
939-
break;
940932
}
941933

942-
// If we don't have both a fallthrough source and a switch statement
943-
// enclosing it, the 'fallthrough' statement is ill-formed.
944-
if (!fallthroughSource || !switchStmt)
945-
return { nullptr, nullptr };
946-
947-
// Find this case in the list of cases for the switch. If we don't find it
948-
// here, it means that the case isn't directly nested inside the switch, so
949-
// the case and fallthrough are both ill-formed.
950-
auto caseIter = llvm::find(switchStmt->getCases(), fallthroughSource);
951-
if (caseIter == switchStmt->getCases().end())
952-
return { nullptr, nullptr };
953-
954-
// Move along to the next case. This is the fallthrough destination.
955-
++caseIter;
956-
auto fallthroughDest = caseIter == switchStmt->getCases().end() ? nullptr
957-
: *caseIter;
958-
return { fallthroughSource, fallthroughDest };
934+
return { nullptr, nullptr };
959935
}

lib/AST/Stmt.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,40 @@ CaseStmt *CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind,
458458
body, caseVarDecls, implicit, fallthroughStmt);
459459
}
460460

461+
namespace {
462+
463+
template<typename CaseIterator>
464+
CaseStmt *findNextCaseStmt(
465+
CaseIterator first, CaseIterator last, const CaseStmt *caseStmt) {
466+
for(auto caseIter = first; caseIter != last; ++caseIter) {
467+
if (*caseIter == caseStmt) {
468+
++caseIter;
469+
return caseIter == last ? nullptr : *caseIter;
470+
}
471+
}
472+
473+
return nullptr;
474+
}
475+
476+
}
477+
478+
CaseStmt *CaseStmt::findNextCaseStmt() const {
479+
auto parent = getParentStmt();
480+
if (!parent)
481+
return nullptr;
482+
483+
if (auto switchParent = dyn_cast<SwitchStmt>(parent)) {
484+
return ::findNextCaseStmt(
485+
switchParent->getCases().begin(), switchParent->getCases().end(),
486+
this);
487+
}
488+
489+
auto doCatchParent = cast<DoCatchStmt>(parent);
490+
return ::findNextCaseStmt(
491+
doCatchParent->getCatches().begin(), doCatchParent->getCatches().end(),
492+
this);
493+
}
494+
461495
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
462496
Expr *SubjectExpr,
463497
SourceLoc LBraceLoc,

lib/Sema/BuilderTransform.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,8 +1348,7 @@ class BuilderClosureRewriter
13481348
// Check restrictions on '@unknown'.
13491349
if (caseStmt->hasUnknownAttr()) {
13501350
checkUnknownAttrRestrictions(
1351-
cs.getASTContext(), caseStmt, /*fallthroughDest=*/nullptr,
1352-
limitExhaustivityChecks);
1351+
cs.getASTContext(), caseStmt, limitExhaustivityChecks);
13531352
}
13541353

13551354
++caseIndex;

lib/Sema/TypeCheckStmt.cpp

Lines changed: 72 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,50 @@ emitUnresolvedLabelDiagnostics(DiagnosticEngine &DE,
352352
}
353353
}
354354

355+
/// Find the target of a break or continue statement without a label.
356+
///
357+
/// \returns the target, if one was found, or \c nullptr if no such target
358+
/// exists.
359+
static LabeledStmt *findUnlabeledBreakOrContinueStmtTarget(
360+
ASTContext &ctx, SourceFile *sourceFile, SourceLoc loc,
361+
bool isContinue, DeclContext *dc,
362+
ArrayRef<LabeledStmt *> activeLabeledStmts) {
363+
for (auto labeledStmt : activeLabeledStmts) {
364+
// 'break' with no label looks through non-loop structures
365+
// except 'switch'.
366+
// 'continue' ignores non-loop structures.
367+
if (!labeledStmt->requiresLabelOnJump() &&
368+
(!isContinue || labeledStmt->isPossibleContinueTarget())) {
369+
return labeledStmt;
370+
}
371+
}
372+
373+
// If we're in a defer, produce a tailored diagnostic.
374+
if (isDefer(dc)) {
375+
ctx.Diags.diagnose(
376+
loc, diag::jump_out_of_defer, isContinue ? "continue": "break");
377+
return nullptr;
378+
}
379+
380+
// If we're dealing with an unlabeled break inside of an 'if' or 'do'
381+
// statement, produce a more specific error.
382+
if (!isContinue &&
383+
llvm::any_of(activeLabeledStmts,
384+
[&](Stmt *S) -> bool {
385+
return isa<IfStmt>(S) || isa<DoStmt>(S);
386+
})) {
387+
ctx.Diags.diagnose(
388+
loc, diag::unlabeled_break_outside_loop);
389+
return nullptr;
390+
}
391+
392+
// Otherwise produce a generic error.
393+
ctx.Diags.diagnose(
394+
loc,
395+
isContinue ? diag::continue_outside_loop : diag::break_outside_loop);
396+
return nullptr;
397+
}
398+
355399
/// Find the target of a break or continue statement.
356400
///
357401
/// \returns the target, if one was found, or \c nullptr if no such target
@@ -361,7 +405,6 @@ static LabeledStmt *findBreakOrContinueStmtTarget(
361405
SourceLoc loc, Identifier targetName, SourceLoc targetLoc,
362406
bool isContinue, DeclContext *dc,
363407
ArrayRef<LabeledStmt *> oldActiveLabeledStmts) {
364-
TopCollection<unsigned, LabeledStmt *> labelCorrections(3);
365408

366409
// Retrieve the active set of labeled statements.
367410
// FIXME: Once everything uses ASTScope lookup, \c oldActiveLabeledStmts
@@ -375,43 +418,37 @@ static LabeledStmt *findBreakOrContinueStmtTarget(
375418
oldActiveLabeledStmts.rbegin(), oldActiveLabeledStmts.rend());
376419
}
377420

378-
// Pick the nearest break target that matches the specified name.
421+
// Handle an unlabeled break separately; that's the easy case.
379422
if (targetName.empty()) {
380-
for (auto labeledStmt : activeLabeledStmts) {
381-
// 'break' with no label looks through non-loop structures
382-
// except 'switch'.
383-
// 'continue' ignores non-loop structures.
384-
if (!labeledStmt->requiresLabelOnJump() &&
385-
(!isContinue || labeledStmt->isPossibleContinueTarget())) {
386-
return labeledStmt;
387-
}
388-
}
389-
} else {
390-
// Scan inside out until we find something with the right label.
391-
for (auto labeledStmt : activeLabeledStmts) {
392-
if (targetName == labeledStmt->getLabelInfo().Name) {
393-
// Continue cannot be used to repeat switches, use fallthrough instead.
394-
if (isContinue && !labeledStmt->isPossibleContinueTarget()) {
395-
ctx.Diags.diagnose(
396-
loc, diag::continue_not_in_this_stmt,
397-
isa<SwitchStmt>(labeledStmt) ? "switch" : "if");
398-
return nullptr;
399-
}
423+
return findUnlabeledBreakOrContinueStmtTarget(
424+
ctx, sourceFile, loc, isContinue, dc, activeLabeledStmts);
425+
}
400426

401-
return labeledStmt;
427+
// Scan inside out until we find something with the right label.
428+
TopCollection<unsigned, LabeledStmt *> labelCorrections(3);
429+
for (auto labeledStmt : activeLabeledStmts) {
430+
if (targetName == labeledStmt->getLabelInfo().Name) {
431+
// Continue cannot be used to repeat switches, use fallthrough instead.
432+
if (isContinue && !labeledStmt->isPossibleContinueTarget()) {
433+
ctx.Diags.diagnose(
434+
loc, diag::continue_not_in_this_stmt,
435+
isa<SwitchStmt>(labeledStmt) ? "switch" : "if");
436+
return nullptr;
402437
}
403438

404-
unsigned distance =
405-
TypeChecker::getCallEditDistance(
406-
DeclNameRef(targetName),
407-
labeledStmt->getLabelInfo().Name,
408-
TypeChecker::UnreasonableCallEditDistance);
409-
if (distance < TypeChecker::UnreasonableCallEditDistance)
410-
labelCorrections.insert(distance, std::move(labeledStmt));
439+
return labeledStmt;
411440
}
412-
labelCorrections.filterMaxScoreRange(
413-
TypeChecker::MaxCallEditDistanceFromBestCandidate);
441+
442+
unsigned distance =
443+
TypeChecker::getCallEditDistance(
444+
DeclNameRef(targetName),
445+
labeledStmt->getLabelInfo().Name,
446+
TypeChecker::UnreasonableCallEditDistance);
447+
if (distance < TypeChecker::UnreasonableCallEditDistance)
448+
labelCorrections.insert(distance, std::move(labeledStmt));
414449
}
450+
labelCorrections.filterMaxScoreRange(
451+
TypeChecker::MaxCallEditDistanceFromBestCandidate);
415452

416453
// If we're in a defer, produce a tailored diagnostic.
417454
if (isDefer(dc)) {
@@ -420,26 +457,6 @@ static LabeledStmt *findBreakOrContinueStmtTarget(
420457
return nullptr;
421458
}
422459

423-
if (targetName.empty()) {
424-
// If we're dealing with an unlabeled break inside of an 'if' or 'do'
425-
// statement, produce a more specific error.
426-
if (!isContinue &&
427-
llvm::any_of(activeLabeledStmts,
428-
[&](Stmt *S) -> bool {
429-
return isa<IfStmt>(S) || isa<DoStmt>(S);
430-
})) {
431-
ctx.Diags.diagnose(
432-
loc, diag::unlabeled_break_outside_loop);
433-
return nullptr;
434-
}
435-
436-
// Otherwise produce a generic error.
437-
ctx.Diags.diagnose(
438-
loc,
439-
isContinue ? diag::continue_outside_loop : diag::break_outside_loop);
440-
return nullptr;
441-
}
442-
443460
// Provide potential corrections for an incorrect label.
444461
emitUnresolvedLabelDiagnostics(
445462
ctx.Diags, targetLoc, targetName, labelCorrections);
@@ -1316,8 +1333,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
13161333
assert(parentKind == CaseParentKind::Switch &&
13171334
"'@unknown' can only appear on switch cases");
13181335
checkUnknownAttrRestrictions(
1319-
getASTContext(), caseBlock, FallthroughDest,
1320-
limitExhaustivityChecks);
1336+
getASTContext(), caseBlock, limitExhaustivityChecks);
13211337
}
13221338

13231339
Stmt *body = caseBlock->getBody();
@@ -2237,8 +2253,9 @@ void TypeChecker::typeCheckTopLevelCodeDecl(TopLevelCodeDecl *TLCD) {
22372253
}
22382254

22392255
void swift::checkUnknownAttrRestrictions(
2240-
ASTContext &ctx, CaseStmt *caseBlock, CaseStmt *fallthroughDest,
2256+
ASTContext &ctx, CaseStmt *caseBlock,
22412257
bool &limitExhaustivityChecks) {
2258+
CaseStmt *fallthroughDest = caseBlock->findNextCaseStmt();
22422259
if (caseBlock->getCaseLabelItems().size() != 1) {
22432260
assert(!caseBlock->getCaseLabelItems().empty() &&
22442261
"parser should not produce case blocks with no items");

lib/Sema/TypeChecker.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,8 +1368,7 @@ bool areGenericRequirementsSatisfied(const DeclContext *DC,
13681368
/// Check for restrictions on the use of the @unknown attribute on a
13691369
/// case statement.
13701370
void checkUnknownAttrRestrictions(
1371-
ASTContext &ctx, CaseStmt *caseBlock, CaseStmt *fallthroughDest,
1372-
bool &limitExhaustivityChecks);
1371+
ASTContext &ctx, CaseStmt *caseBlock, bool &limitExhaustivityChecks);
13731372

13741373
/// Bind all of the pattern variables that occur within a case statement and
13751374
/// all of its case items to their "parent" pattern variables, forming chains

test/Constraints/function_builder_diags.swift

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,16 +433,27 @@ struct TestConstraintGenerationErrors {
433433

434434
// Check @unknown
435435
func testUnknownInSwitchSwitch(e: E) {
436-
tuplify(true) { c in
436+
tuplify(true) { c in
437437
"testSwitch"
438438
switch e {
439-
@unknown case .a: // expected-error{{'@unknown' is only supported for catch-all cases ("case _")}}
440-
"a"
441439
case .b(let i, let s?):
442440
i * 2
443441
s + "!"
444442
default:
445443
"nothing"
444+
@unknown case .a: // expected-error{{'@unknown' is only supported for catch-all cases ("case _")}}
445+
// expected-error@-1{{'case' blocks cannot appear after the 'default' block of a 'switch'}}
446+
"a"
447+
}
448+
}
449+
450+
tuplify(true) { c in
451+
"testSwitch"
452+
switch e {
453+
@unknown case _: // expected-error{{'@unknown' can only be applied to the last case in a switch}}
454+
"a"
455+
default:
456+
"default"
446457
}
447458
}
448459
}

0 commit comments

Comments
 (0)