Skip to content

Commit 1d4b451

Browse files
authored
Merge pull request #84589 from MaxDesiatov/for-expressions
AST/Sema: `ForExpressions` experimental feature
2 parents b0a2bd6 + b545a28 commit 1d4b451

14 files changed

+215
-27
lines changed

include/swift/AST/Expr.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6498,16 +6498,20 @@ class KeyPathDotExpr : public Expr {
64986498
}
64996499
};
65006500

6501+
struct ForCollectionInit {
6502+
VarDecl *ForAccumulatorDecl;
6503+
PatternBindingDecl *ForAccumulatorBinding;
6504+
};
6505+
65016506
/// An expression that may wrap a statement which produces a single value.
65026507
class SingleValueStmtExpr : public Expr {
65036508
public:
6504-
enum class Kind {
6505-
If, Switch, Do, DoCatch
6506-
};
6509+
enum class Kind { If, Switch, Do, DoCatch, For };
65076510

65086511
private:
65096512
Stmt *S;
65106513
DeclContext *DC;
6514+
std::optional<ForCollectionInit> ForExpressionPreamble;
65116515

65126516
SingleValueStmtExpr(Stmt *S, DeclContext *DC)
65136517
: Expr(ExprKind::SingleValueStmt, /*isImplicit*/ true), S(S), DC(DC) {}
@@ -6572,6 +6576,14 @@ class SingleValueStmtExpr : public Expr {
65726576

65736577
SourceRange getSourceRange() const;
65746578

6579+
std::optional<ForCollectionInit> getForExpressionPreamble() const {
6580+
return this->ForExpressionPreamble;
6581+
}
6582+
6583+
void setForExpressionPreamble(ForCollectionInit newPreamble) {
6584+
this->ForExpressionPreamble = newPreamble;
6585+
}
6586+
65756587
static bool classof(const Expr *E) {
65766588
return E->getKind() == ExprKind::SingleValueStmt;
65776589
}

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ IDENTIFIER(alloc)
2929
IDENTIFIER(allocWithZone)
3030
IDENTIFIER(allZeros)
3131
IDENTIFIER(accumulated)
32+
IDENTIFIER(append)
3233
IDENTIFIER(ActorType)
3334
IDENTIFIER(Any)
3435
IDENTIFIER(ArrayLiteralElement)

include/swift/Basic/Features.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,9 @@ EXPERIMENTAL_FEATURE(ThenStatements, false)
424424
/// Enable 'do' expressions.
425425
EXPERIMENTAL_FEATURE(DoExpressions, false)
426426

427+
/// Enable 'for' expressions.
428+
EXPERIMENTAL_FEATURE(ForExpressions, false)
429+
427430
/// Enable implicitly treating the last expression in a function, closure,
428431
/// and 'if'/'switch' expression as the result.
429432
EXPERIMENTAL_FEATURE(ImplicitLastExprResults, false)

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// This file defines the SyntacticElementTarget class.
13+
// This file defines the SyntacticElementTarget class (a unit of
14+
// type-checking).
1415
//
1516
//===----------------------------------------------------------------------===//
1617

@@ -59,8 +60,8 @@ struct PackIterationInfo {
5960
/// within the constraint system.
6061
using ForEachStmtInfo = TaggedUnion<SequenceIterationInfo, PackIterationInfo>;
6162

62-
/// Describes the target to which a constraint system's solution can be
63-
/// applied.
63+
/// Describes the target (a unit of type-checking) to which a constraint
64+
/// system's solution can be applied.
6465
class SyntacticElementTarget {
6566
public:
6667
enum class Kind {

lib/AST/ASTDumper.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4496,6 +4496,12 @@ class PrintExpr : public ExprVisitor<PrintExpr, void, Label>,
44964496
void visitSingleValueStmtExpr(SingleValueStmtExpr *E, Label label) {
44974497
printCommon(E, "single_value_stmt_expr", label);
44984498
printDeclContext(E);
4499+
if (auto preamble = E->getForExpressionPreamble()) {
4500+
printRec(preamble->ForAccumulatorDecl,
4501+
Label::optional("for_preamble_accumulator_decl"));
4502+
printRec(preamble->ForAccumulatorBinding,
4503+
Label::optional("for_preamble_accumulator_binding"));
4504+
}
44994505
printRec(E->getStmt(), &E->getDeclContext()->getASTContext(),
45004506
Label::optional("stmt"));
45014507
printFoot();

lib/AST/ASTWalker.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,16 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13851385
Expr *visitKeyPathDotExpr(KeyPathDotExpr *E) { return E; }
13861386

13871387
Expr *visitSingleValueStmtExpr(SingleValueStmtExpr *E) {
1388+
if (auto preamble = E->getForExpressionPreamble()) {
1389+
if (doIt(preamble->ForAccumulatorDecl)) {
1390+
return nullptr;
1391+
}
1392+
1393+
if (doIt(preamble->ForAccumulatorBinding)) {
1394+
return nullptr;
1395+
}
1396+
}
1397+
13881398
if (auto *S = doIt(E->getStmt())) {
13891399
E->setStmt(S);
13901400
} else {

lib/AST/Expr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,6 +2763,8 @@ SingleValueStmtExpr::Kind SingleValueStmtExpr::getStmtKind() const {
27632763
return Kind::Do;
27642764
case StmtKind::DoCatch:
27652765
return Kind::DoCatch;
2766+
case StmtKind::ForEach:
2767+
return Kind::For;
27662768
default:
27672769
llvm_unreachable("Unhandled kind!");
27682770
}
@@ -2781,6 +2783,9 @@ SingleValueStmtExpr::getBranches(SmallVectorImpl<Stmt *> &scratch) const {
27812783
return scratch;
27822784
case Kind::DoCatch:
27832785
return cast<DoCatchStmt>(getStmt())->getBranches(scratch);
2786+
case Kind::For:
2787+
scratch.push_back(cast<ForEachStmt>(getStmt())->getBody());
2788+
return scratch;
27842789
}
27852790
llvm_unreachable("Unhandled case in switch!");
27862791
}

lib/AST/FeatureSet.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ UNINTERESTING_FEATURE(RegionBasedIsolation)
116116
UNINTERESTING_FEATURE(PlaygroundExtendedCallbacks)
117117
UNINTERESTING_FEATURE(ThenStatements)
118118
UNINTERESTING_FEATURE(DoExpressions)
119+
UNINTERESTING_FEATURE(ForExpressions)
119120
UNINTERESTING_FEATURE(ImplicitLastExprResults)
120121
UNINTERESTING_FEATURE(RawLayout)
121122
UNINTERESTING_FEATURE(Embedded)

lib/SILGen/SILGenExpr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2515,6 +2515,17 @@ RValue RValueEmitter::visitEnumIsCaseExpr(EnumIsCaseExpr *E,
25152515

25162516
RValue RValueEmitter::visitSingleValueStmtExpr(SingleValueStmtExpr *E,
25172517
SGFContext C) {
2518+
if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) {
2519+
auto *decl = E->getForExpressionPreamble()->ForAccumulatorDecl;
2520+
auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding;
2521+
SGF.visit(decl);
2522+
SGF.visit(binding);
2523+
SGF.emitStmt(E->getStmt());
2524+
2525+
return SGF.emitRValueForDecl(E, ConcreteDeclRef(decl), E->getType(),
2526+
AccessSemantics::Ordinary);
2527+
}
2528+
25182529
auto emitStmt = [&]() {
25192530
SGF.emitStmt(E->getStmt());
25202531

lib/Sema/CSSyntacticElement.cpp

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,8 +1292,12 @@ class SyntacticElementConstraintGenerator
12921292

12931293
// First check to make sure the ThenStmt is in a valid position.
12941294
SmallVector<ThenStmt *, 4> validThenStmts;
1295-
if (auto SVE = context.getAsSingleValueStmtExpr())
1295+
if (auto SVE = context.getAsSingleValueStmtExpr()) {
12961296
(void)SVE.get()->getThenStmts(validThenStmts);
1297+
if (SVE.get()->getStmtKind() == SingleValueStmtExpr::Kind::For) {
1298+
contextInfo = std::nullopt;
1299+
}
1300+
}
12971301

12981302
if (!llvm::is_contained(validThenStmts, thenStmt)) {
12991303
auto *thenLoc = cs.getConstraintLocator(thenStmt);
@@ -1488,8 +1492,37 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
14881492
auto &ctx = getASTContext();
14891493

14901494
auto *loc = getConstraintLocator(E);
1491-
Type resultTy = createTypeVariable(loc, /*options*/ 0);
1492-
setType(E, resultTy);
1495+
Type resultType = createTypeVariable(loc, /*options*/ 0);
1496+
setType(E, resultType);
1497+
1498+
if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) {
1499+
auto *rrcProtocol =
1500+
ctx.getProtocol(KnownProtocolKind::RangeReplaceableCollection);
1501+
auto *sequenceProtocol = ctx.getProtocol(KnownProtocolKind::Sequence);
1502+
1503+
addConstraint(ConstraintKind::ConformsTo, resultType,
1504+
rrcProtocol->getDeclaredInterfaceType(), loc);
1505+
Type elementTypeVar = createTypeVariable(loc, /*options*/ 0);
1506+
Type elementType = DependentMemberType::get(
1507+
resultType, sequenceProtocol->getAssociatedType(ctx.Id_Element));
1508+
1509+
addConstraint(ConstraintKind::Bind, elementTypeVar, elementType, loc);
1510+
addConstraint(ConstraintKind::Defaultable, resultType,
1511+
ArraySliceType::get(elementTypeVar), loc);
1512+
1513+
auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding;
1514+
1515+
auto *initializer = binding->getInit(0);
1516+
auto target = SyntacticElementTarget::forInitialization(initializer, Type(),
1517+
binding, 0, false);
1518+
setTargetFor({binding, 0}, target);
1519+
1520+
if (generateConstraints(target)) {
1521+
return true;
1522+
}
1523+
1524+
addConstraint(ConstraintKind::Bind, getType(initializer), resultType, loc);
1525+
}
14931526

14941527
// Propagate the implied result kind from the if/switch expression itself
14951528
// into the branches.
@@ -1513,21 +1546,24 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15131546
auto *loc = getConstraintLocator(
15141547
E, {LocatorPathElt::SingleValueStmtResult(idx), ctpElt});
15151548

1516-
ContextualTypeInfo info(resultTy, CTP_SingleValueStmtBranch, loc);
1549+
ContextualTypeInfo info(resultType, CTP_SingleValueStmtBranch, loc);
15171550
setContextualInfo(result, info);
15181551
}
15191552

15201553
TypeJoinExpr *join = nullptr;
1521-
if (branches.empty()) {
1522-
// If we only have statement branches, the expression is typed as Void. This
1523-
// should only be the case for 'if' and 'switch' statements that must be
1524-
// expressions that have branches that all end in a throw, and we'll warn
1525-
// that we've inferred Void.
1526-
addConstraint(ConstraintKind::Bind, resultTy, ctx.getVoidType(), loc);
1527-
} else {
1528-
// Otherwise, we join the result types for each of the branches.
1529-
join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr(
1530-
ctx, resultTy, E, AllocationArena::ConstraintSolver);
1554+
1555+
if (E->getStmtKind() != SingleValueStmtExpr::Kind::For) {
1556+
if (branches.empty()) {
1557+
// If we only have statement branches, the expression is typed as Void.
1558+
// This should only be the case for 'if' and 'switch' statements that must
1559+
// be expressions that have branches that all end in a throw, and we'll
1560+
// warn that we've inferred Void.
1561+
addConstraint(ConstraintKind::Bind, resultType, ctx.getVoidType(), loc);
1562+
} else {
1563+
// Otherwise, we join the result types for each of the branches.
1564+
join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr(
1565+
ctx, resultType, E, AllocationArena::ConstraintSolver);
1566+
}
15311567
}
15321568

15331569
// If this is an implied return in a closure, we need to account for the fact
@@ -1568,11 +1604,11 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15681604
if (auto *closureTy = getClosureTypeIfAvailable(CE)) {
15691605
auto closureResultTy = closureTy->getResult();
15701606
auto *bindToClosure = Constraint::create(
1571-
*this, ConstraintKind::Bind, resultTy, closureResultTy, loc);
1607+
*this, ConstraintKind::Bind, resultType, closureResultTy, loc);
15721608
bindToClosure->setFavored();
15731609

1574-
auto *bindToVoid = Constraint::create(*this, ConstraintKind::Bind,
1575-
resultTy, ctx.getVoidType(), loc);
1610+
auto *bindToVoid = Constraint::create(
1611+
*this, ConstraintKind::Bind, resultType, ctx.getVoidType(), loc);
15761612

15771613
addDisjunctionConstraint({bindToClosure, bindToVoid}, loc);
15781614
}
@@ -2221,7 +2257,9 @@ class SyntacticElementSolutionApplication
22212257
// not the branch result type. This is necessary as there may be
22222258
// an additional conversion required for the branch.
22232259
auto target = solution.getTargetFor(thenStmt->getResult());
2224-
target->setExprConversionType(ty);
2260+
if (SVE.get()->getStmtKind() != SingleValueStmtExpr::Kind::For) {
2261+
target->setExprConversionType(ty);
2262+
}
22252263

22262264
auto *resultExpr = thenStmt->getResult();
22272265
if (auto newResultTarget = rewriter.rewriteTarget(*target))
@@ -2663,6 +2701,18 @@ bool ConstraintSystem::applySolutionToSingleValueStmt(
26632701
if (!stmt || application.hadError)
26642702
return true;
26652703

2704+
if (SVE->getStmtKind() == SingleValueStmtExpr::Kind::For) {
2705+
auto *binding = SVE->getForExpressionPreamble()->ForAccumulatorBinding;
2706+
auto target = getTargetFor({binding, 0}).value();
2707+
2708+
auto newTarget = rewriter.rewriteTarget(target);
2709+
if (!newTarget) {
2710+
return true;
2711+
}
2712+
2713+
binding->setInit(0, newTarget->getAsExpr());
2714+
}
2715+
26662716
SVE->setStmt(stmt);
26672717
return false;
26682718
}

0 commit comments

Comments
 (0)