Skip to content

Commit 141f3e7

Browse files
committed
[Constraint system] Expand SolutionApplicationTarget to StmtConditions.
Handle StmtCondition as part of SolutionApplicationTarget, so we can generate constraints from it and rewrite directly as part of a solution, rather than open-coding the operation in the function builder transform.
1 parent 4830c48 commit 141f3e7

File tree

3 files changed

+136
-80
lines changed

3 files changed

+136
-80
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -920,38 +920,9 @@ class BuilderClosureRewriter
920920

921921
Stmt *visitIfStmt(IfStmt *ifStmt, FunctionBuilderTarget target) {
922922
// Rewrite the condition.
923-
auto condition = ifStmt->getCond();
924-
for (auto &condElement : condition) {
925-
switch (condElement.getKind()) {
926-
case StmtConditionElement::CK_Availability:
927-
continue;
928-
929-
case StmtConditionElement::CK_Boolean: {
930-
auto condExpr = condElement.getBoolean();
931-
auto finalCondExpr = rewriteExpr(condExpr);
932-
933-
// Load the condition if needed.
934-
if (finalCondExpr->getType()->hasLValueType()) {
935-
finalCondExpr = TypeChecker::addImplicitLoadExpr(ctx, finalCondExpr);
936-
}
937-
938-
condElement.setBoolean(finalCondExpr);
939-
continue;
940-
}
941-
942-
case StmtConditionElement::CK_PatternBinding: {
943-
ConstraintSystem &cs = solution.getConstraintSystem();
944-
auto target = *cs.getStmtConditionTarget(&condElement);
945-
auto resolvedTarget = rewriteTarget(target);
946-
if (resolvedTarget) {
947-
condElement.setInitializer(resolvedTarget->getAsExpr());
948-
condElement.setPattern(resolvedTarget->getInitializationPattern());
949-
}
950-
continue;
951-
}
952-
}
953-
}
954-
ifStmt->setCond(condition);
923+
if (auto condition = rewriteTarget(
924+
SolutionApplicationTarget(ifStmt->getCond(), dc)))
925+
ifStmt->setCond(*condition->getAsStmtCondition());
955926

956927
assert(target.kind == FunctionBuilderTarget::TemporaryVar);
957928
auto temporaryVar = target.captured.first;

lib/Sema/CSApply.cpp

Lines changed: 93 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7092,6 +7092,48 @@ bool swift::exprNeedsParensAfterAddingNilCoalescing(DeclContext *DC,
70927092
}
70937093

70947094
namespace {
7095+
class SetExprTypes : public ASTWalker {
7096+
const Solution &solution;
7097+
7098+
public:
7099+
explicit SetExprTypes(const Solution &solution)
7100+
: solution(solution) {}
7101+
7102+
Expr *walkToExprPost(Expr *expr) override {
7103+
auto &cs = solution.getConstraintSystem();
7104+
auto exprType = cs.getType(expr);
7105+
exprType = solution.simplifyType(exprType);
7106+
// assert((!expr->getType() || expr->getType()->isEqual(exprType)) &&
7107+
// "Mismatched types!");
7108+
assert(!exprType->hasTypeVariable() &&
7109+
"Should not write type variable into expression!");
7110+
expr->setType(exprType);
7111+
7112+
if (auto kp = dyn_cast<KeyPathExpr>(expr)) {
7113+
for (auto i : indices(kp->getComponents())) {
7114+
Type componentType;
7115+
if (cs.hasType(kp, i)) {
7116+
componentType = solution.simplifyType(cs.getType(kp, i));
7117+
assert(!componentType->hasTypeVariable() &&
7118+
"Should not write type variable into key-path component");
7119+
}
7120+
7121+
kp->getMutableComponents()[i].setComponentType(componentType);
7122+
}
7123+
}
7124+
7125+
return expr;
7126+
}
7127+
7128+
/// Ignore statements.
7129+
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
7130+
return { false, stmt };
7131+
}
7132+
7133+
/// Ignore declarations.
7134+
bool walkToDeclPre(Decl *decl) override { return false; }
7135+
};
7136+
70957137
class ExprWalker : public ASTWalker {
70967138
ExprRewriter &Rewriter;
70977139
SmallVector<ClosureExpr *, 4> ClosuresToTypeCheck;
@@ -7129,8 +7171,13 @@ namespace {
71297171
[&](SolutionApplicationTarget target) {
71307172
auto resultTarget = rewriteTarget(target);
71317173
if (resultTarget) {
7132-
if (auto expr = resultTarget->getAsExpr())
7174+
7175+
if (auto expr = resultTarget->getAsExpr()) {
71337176
Rewriter.solution.setExprTypes(expr);
7177+
} else if (auto stmtCondition =
7178+
resultTarget->getAsStmtCondition()) {
7179+
7180+
}
71347181
}
71357182

71367183
return resultTarget;
@@ -7394,6 +7441,43 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
73947441
} else {
73957442
result.setExpr(rewrittenExpr);
73967443
}
7444+
} else if (auto stmtCondition = target.getAsStmtCondition()) {
7445+
for (auto &condElement : *stmtCondition) {
7446+
switch (condElement.getKind()) {
7447+
case StmtConditionElement::CK_Availability:
7448+
continue;
7449+
7450+
case StmtConditionElement::CK_Boolean: {
7451+
auto condExpr = condElement.getBoolean();
7452+
auto finalCondExpr = condExpr->walk(*this);
7453+
if (!finalCondExpr)
7454+
return None;
7455+
7456+
// Load the condition if needed.
7457+
if (finalCondExpr->getType()->hasLValueType()) {
7458+
ASTContext &ctx = solution.getConstraintSystem().getASTContext();
7459+
finalCondExpr = TypeChecker::addImplicitLoadExpr(ctx, finalCondExpr);
7460+
}
7461+
7462+
condElement.setBoolean(finalCondExpr);
7463+
continue;
7464+
}
7465+
7466+
case StmtConditionElement::CK_PatternBinding: {
7467+
ConstraintSystem &cs = solution.getConstraintSystem();
7468+
auto target = *cs.getStmtConditionTarget(&condElement);
7469+
auto resolvedTarget = rewriteTarget(target);
7470+
if (!resolvedTarget)
7471+
return None;
7472+
7473+
condElement.setInitializer(resolvedTarget->getAsExpr());
7474+
condElement.setPattern(resolvedTarget->getInitializationPattern());
7475+
continue;
7476+
}
7477+
}
7478+
}
7479+
7480+
return target;
73977481
} else {
73987482
auto fn = *target.getAsFunction();
73997483

@@ -7406,8 +7490,8 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
74067490
[&](SolutionApplicationTarget target) {
74077491
auto resultTarget = rewriteTarget(target);
74087492
if (resultTarget) {
7409-
if (auto expr = resultTarget->getAsExpr())
7410-
Rewriter.solution.setExprTypes(expr);
7493+
SetExprTypes typeSetter(solution);
7494+
resultTarget->walk(typeSetter);
74117495
}
74127496

74137497
return resultTarget;
@@ -7546,50 +7630,6 @@ Expr *Solution::coerceToType(Expr *expr, Type toType,
75467630
return result;
75477631
}
75487632

7549-
namespace {
7550-
class SetExprTypes : public ASTWalker {
7551-
const Solution &solution;
7552-
7553-
public:
7554-
explicit SetExprTypes(const Solution &solution)
7555-
: solution(solution) {}
7556-
7557-
Expr *walkToExprPost(Expr *expr) override {
7558-
auto &cs = solution.getConstraintSystem();
7559-
auto exprType = cs.getType(expr);
7560-
exprType = solution.simplifyType(exprType);
7561-
// assert((!expr->getType() || expr->getType()->isEqual(exprType)) &&
7562-
// "Mismatched types!");
7563-
assert(!exprType->hasTypeVariable() &&
7564-
"Should not write type variable into expression!");
7565-
expr->setType(exprType);
7566-
7567-
if (auto kp = dyn_cast<KeyPathExpr>(expr)) {
7568-
for (auto i : indices(kp->getComponents())) {
7569-
Type componentType;
7570-
if (cs.hasType(kp, i)) {
7571-
componentType = solution.simplifyType(cs.getType(kp, i));
7572-
assert(!componentType->hasTypeVariable() &&
7573-
"Should not write type variable into key-path component");
7574-
}
7575-
7576-
kp->getMutableComponents()[i].setComponentType(componentType);
7577-
}
7578-
}
7579-
7580-
return expr;
7581-
}
7582-
7583-
/// Ignore statements.
7584-
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
7585-
return { false, stmt };
7586-
}
7587-
7588-
/// Ignore declarations.
7589-
bool walkToDeclPre(Decl *decl) override { return false; }
7590-
};
7591-
}
7592-
75937633
ProtocolConformanceRef Solution::resolveConformance(
75947634
ConstraintLocator *locator, ProtocolDecl *proto) {
75957635
for (const auto &conformance : Conformances) {
@@ -7704,5 +7744,11 @@ SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
77047744
return SolutionApplicationTarget(
77057745
*getAsFunction(),
77067746
cast_or_null<BraceStmt>(getFunctionBody()->walk(walker)));
7747+
7748+
case Kind::stmtCondition:
7749+
for (auto &condElement : stmtCondition.stmtCondition) {
7750+
condElement = *condElement.walk(walker);
7751+
}
7752+
return *this;
77077753
}
77087754
}

lib/Sema/ConstraintSystem.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,8 @@ struct DynamicCallableMethods {
11511151
class SolutionApplicationTarget {
11521152
enum class Kind {
11531153
expression,
1154-
function
1154+
function,
1155+
stmtCondition
11551156
} kind;
11561157

11571158
union {
@@ -1189,6 +1190,11 @@ class SolutionApplicationTarget {
11891190
AnyFunctionRef function;
11901191
BraceStmt *body;
11911192
} function;
1193+
1194+
struct {
1195+
StmtCondition stmtCondition;
1196+
DeclContext *dc;
1197+
} stmtCondition;
11921198
};
11931199

11941200
// If the pattern contains a single variable that has an attached
@@ -1211,6 +1217,12 @@ class SolutionApplicationTarget {
12111217
SolutionApplicationTarget(AnyFunctionRef fn)
12121218
: SolutionApplicationTarget(fn, fn.getBody()) { }
12131219

1220+
SolutionApplicationTarget(StmtCondition stmtCondition, DeclContext *dc) {
1221+
kind = Kind::stmtCondition;
1222+
this->stmtCondition.stmtCondition = stmtCondition;
1223+
this->stmtCondition.dc = dc;
1224+
}
1225+
12141226
SolutionApplicationTarget(AnyFunctionRef fn, BraceStmt *body) {
12151227
kind = Kind::function;
12161228
function.function = fn;
@@ -1228,6 +1240,7 @@ class SolutionApplicationTarget {
12281240
return expression.expression;
12291241

12301242
case Kind::function:
1243+
case Kind::stmtCondition:
12311244
return nullptr;
12321245
}
12331246
}
@@ -1239,6 +1252,9 @@ class SolutionApplicationTarget {
12391252

12401253
case Kind::function:
12411254
return function.function.getAsDeclContext();
1255+
1256+
case Kind::stmtCondition:
1257+
return stmtCondition.dc;
12421258
}
12431259
}
12441260

@@ -1346,13 +1362,25 @@ class SolutionApplicationTarget {
13461362
Optional<AnyFunctionRef> getAsFunction() const {
13471363
switch (kind) {
13481364
case Kind::expression:
1365+
case Kind::stmtCondition:
13491366
return None;
13501367

13511368
case Kind::function:
13521369
return function.function;
13531370
}
13541371
}
13551372

1373+
Optional<StmtCondition> getAsStmtCondition() const {
1374+
switch (kind) {
1375+
case Kind::expression:
1376+
case Kind::function:
1377+
return None;
1378+
1379+
case Kind::stmtCondition:
1380+
return stmtCondition.stmtCondition;
1381+
}
1382+
}
1383+
13561384
BraceStmt *getFunctionBody() const {
13571385
assert(kind == Kind::function);
13581386
return function.body;
@@ -1371,6 +1399,10 @@ class SolutionApplicationTarget {
13711399

13721400
case Kind::function:
13731401
return function.body->getSourceRange();
1402+
1403+
case Kind::stmtCondition:
1404+
return SourceRange(stmtCondition.stmtCondition.front().getStartLoc(),
1405+
stmtCondition.stmtCondition.back().getEndLoc());
13741406
}
13751407
}
13761408

@@ -1382,6 +1414,9 @@ class SolutionApplicationTarget {
13821414

13831415
case Kind::function:
13841416
return function.function.getLoc();
1417+
1418+
case Kind::stmtCondition:
1419+
return stmtCondition.stmtCondition.front().getStartLoc();
13851420
}
13861421
}
13871422

@@ -4339,6 +4374,10 @@ class ConstraintSystem {
43394374
Optional<SolutionApplicationTarget> applySolution(
43404375
Solution &solution, SolutionApplicationTarget target);
43414376

4377+
/// Apply the given solution to the given statement-condition.
4378+
Optional<StmtCondition> applySolution(
4379+
Solution &solution, StmtCondition condition, DeclContext *dc);
4380+
43424381
/// Reorder the disjunctive clauses for a given expression to
43434382
/// increase the likelihood that a favored constraint will be successfully
43444383
/// resolved before any others.

0 commit comments

Comments
 (0)