Skip to content

Commit b7860ea

Browse files
committed
[TypeChecker] Split for-in sequence into parsed and type-checked versions
1 parent 5f0dcb5 commit b7860ea

17 files changed

+64
-49
lines changed

include/swift/AST/Stmt.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ class ForEachStmt : public LabeledStmt {
743743

744744
// Set by Sema:
745745
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
746-
VarDecl *iteratorVar = nullptr;
746+
PatternBindingDecl *iteratorVar = nullptr;
747747
Expr *nextCall = nullptr;
748748
OpaqueValueExpr *elementExpr = nullptr;
749749
Expr *convertElementExpr = nullptr;
@@ -759,8 +759,8 @@ class ForEachStmt : public LabeledStmt {
759759
setPattern(Pat);
760760
}
761761

762-
void setIteratorVar(VarDecl *var) { iteratorVar = var; }
763-
VarDecl *getIteratorVar() const { return iteratorVar; }
762+
void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; }
763+
PatternBindingDecl *getIteratorVar() const { return iteratorVar; }
764764

765765
void setNextCall(Expr *next) { nextCall = next; }
766766
Expr *getNextCall() const { return nextCall; }
@@ -802,8 +802,12 @@ class ForEachStmt : public LabeledStmt {
802802
/// by this foreach loop, as it was written in the source code and
803803
/// subsequently type-checked. To determine the semantic behavior of this
804804
/// expression to extract a range, use \c getRangeInit().
805-
Expr *getSequence() const { return Sequence; }
806-
void setSequence(Expr *S) { Sequence = S; }
805+
Expr *getParsedSequence() const { return Sequence; }
806+
void setParsedSequence(Expr *S) { Sequence = S; }
807+
808+
/// Type-checked version of the sequence or nullptr if this statement
809+
/// yet to be type-checked.
810+
Expr *getTypeCheckedSequence() const;
807811

808812
/// getBody - Retrieve the body of the loop.
809813
BraceStmt *getBody() const { return Body; }

include/swift/Sema/ConstraintSystem.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ struct ForEachStmtInfo {
10111011
Type initType;
10121012

10131013
/// Implicit `$iterator = <sequence>.makeIterator()`
1014-
VarDecl *makeIteratorVar;
1014+
PatternBindingDecl *makeIteratorVar;
10151015

10161016
/// Implicit `$iterator.next()` call.
10171017
Expr *nextCall;
@@ -2434,7 +2434,7 @@ class SolutionApplicationTarget {
24342434
case Kind::forEachStmt:
24352435
auto *stmt = forEachStmt.stmt;
24362436
SourceLoc startLoc = stmt->getForLoc();
2437-
SourceLoc endLoc = stmt->getSequence()->getEndLoc();
2437+
SourceLoc endLoc = stmt->getParsedSequence()->getEndLoc();
24382438

24392439
if (auto *whereExpr = stmt->getWhere()) {
24402440
endLoc = whereExpr->getEndLoc();

lib/AST/ASTDumper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15721572
}
15731573
printRec(S->getPattern());
15741574
OS << '\n';
1575-
printRec(S->getSequence());
1575+
printRec(S->getParsedSequence());
15761576
OS << '\n';
15771577
if (S->getIteratorVar()) {
15781578
printRec(S->getIteratorVar());

lib/AST/ASTScopeCreation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ void SwitchStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
994994

995995
void ForEachStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
996996
ScopeCreator &scopeCreator) {
997-
scopeCreator.addToScopeTree(stmt->getSequence(), this);
997+
scopeCreator.addToScopeTree(stmt->getParsedSequence(), this);
998998

999999
// Add a child describing the scope of the pattern.
10001000
// In error cases such as:

lib/AST/ASTWalker.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,11 +1590,28 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
15901590

15911591
// The iterator decl is built directly on top of the sequence
15921592
// expression, so don't visit both.
1593-
if (Expr *Sequence = S->getSequence()) {
1594-
if ((Sequence = doIt(Sequence)))
1595-
S->setSequence(Sequence);
1596-
else
1597-
return nullptr;
1593+
//
1594+
// If for-in is already type-checked, the type-checked version
1595+
// of the sequence is going to be visited as part of `iteratorVar`.
1596+
if (S->getTypeCheckedSequence()) {
1597+
if (auto IteratorVar = S->getIteratorVar()) {
1598+
if (doIt(IteratorVar))
1599+
return nullptr;
1600+
}
1601+
1602+
if (auto NextCall = S->getNextCall()) {
1603+
if ((NextCall = doIt(NextCall)))
1604+
S->setNextCall(NextCall);
1605+
else
1606+
return nullptr;
1607+
}
1608+
} else {
1609+
if (Expr *Sequence = S->getParsedSequence()) {
1610+
if ((Sequence = doIt(Sequence)))
1611+
S->setParsedSequence(Sequence);
1612+
else
1613+
return nullptr;
1614+
}
15981615
}
15991616

16001617
if (Expr *Where = S->getWhere()) {
@@ -1611,18 +1628,6 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
16111628
return nullptr;
16121629
}
16131630

1614-
if (auto IteratorVar = S->getIteratorVar()) {
1615-
if (doIt(IteratorVar))
1616-
return nullptr;
1617-
}
1618-
1619-
if (auto NextCall = S->getNextCall()) {
1620-
if ((NextCall = doIt(NextCall)))
1621-
S->setNextCall(NextCall);
1622-
else
1623-
return nullptr;
1624-
}
1625-
16261631
if (Stmt *Body = S->getBody()) {
16271632
if ((Body = doIt(Body)))
16281633
S->setBody(cast<BraceStmt>(Body));

lib/AST/NameLookup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3110,7 +3110,7 @@ void FindLocalVal::visitForEachStmt(ForEachStmt *S) {
31103110
if (!isReferencePointInRange(S->getSourceRange()))
31113111
return;
31123112
visit(S->getBody());
3113-
if (!isReferencePointInRange(S->getSequence()->getSourceRange()))
3113+
if (!isReferencePointInRange(S->getParsedSequence()->getSourceRange()))
31143114
checkPattern(S->getPattern(), DeclVisibilityKind::LocalVariable);
31153115
}
31163116

lib/AST/Stmt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ void ForEachStmt::setPattern(Pattern *p) {
301301
Pat->markOwnedByStatement(this);
302302
}
303303

304+
Expr *ForEachStmt::getTypeCheckedSequence() const {
305+
return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
306+
}
307+
304308
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
305309
SourceLoc doLoc, Stmt *body,
306310
ArrayRef<CaseStmt *> catches,

lib/IDE/ExprContextAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ class ExprContextAnalyzer {
10921092
break;
10931093
}
10941094
case StmtKind::ForEach:
1095-
if (auto SEQ = cast<ForEachStmt>(Parent)->getSequence()) {
1095+
if (auto SEQ = cast<ForEachStmt>(Parent)->getParsedSequence()) {
10961096
if (containsTarget(SEQ)) {
10971097
recordPossibleType(Context.getSequenceType());
10981098
}

lib/IDE/Formatting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2246,7 +2246,7 @@ class FormatWalker : public ASTWalker {
22462246
if (Range.isValid() && overlapsTarget(Range))
22472247
return IndentContext {ForLoc, !OutdentChecker::hasOutdent(SM, P)};
22482248
}
2249-
if (auto *E = FS->getSequence()) {
2249+
if (auto *E = FS->getParsedSequence()) {
22502250
SourceRange Range = FS->getInLoc();
22512251
widenOrSet(Range, E->getSourceRange());
22522252
if (Range.isValid() && isTargetContext(Range)) {

lib/IDE/SyntaxModel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,8 @@ std::pair<bool, Stmt *> ModelASTWalker::walkToStmtPre(Stmt *S) {
747747
charSourceRangeFromSourceRange(SM, ElemRange));
748748
}
749749
}
750-
if (ForEachS->getSequence())
751-
addExprElem(SyntaxStructureElementKind::Expr, ForEachS->getSequence(),SN);
750+
if (auto *S = ForEachS->getParsedSequence())
751+
addExprElem(SyntaxStructureElementKind::Expr, S, SN);
752752
pushStructureNode(SN, S);
753753

754754
} else if (auto *WhileS = dyn_cast<WhileStmt>(S)) {

0 commit comments

Comments
 (0)