Skip to content

Commit 8191aa4

Browse files
authored
Merge pull request swiftlang#30174 from DougGregor/function-builder-switch
[Constraint system] Implement switch support for function builders.
2 parents 12fd864 + ea8d143 commit 8191aa4

File tree

9 files changed

+556
-33
lines changed

9 files changed

+556
-33
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 153 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,12 +626,107 @@ class BuilderClosureVisitor
626626
DeclNameLoc(endLoc), /*implicit=*/true);
627627
}
628628

629+
VarDecl *visitSwitchStmt(SwitchStmt *switchStmt) {
630+
// Generate constraints for the subject expression, and capture its
631+
// type for use in matching the various patterns.
632+
Expr *subjectExpr = switchStmt->getSubjectExpr();
633+
if (cs) {
634+
// Form a one-way constraint to prevent backward propagation.
635+
subjectExpr = new (ctx) OneWayExpr(subjectExpr);
636+
637+
// FIXME: Add contextual type purpose for switch subjects?
638+
SolutionApplicationTarget target(subjectExpr, dc, CTP_Unused, Type(),
639+
/*isDiscarded=*/false);
640+
if (cs->generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
641+
hadError = true;
642+
return nullptr;
643+
}
644+
645+
cs->setSolutionApplicationTarget(switchStmt, target);
646+
subjectExpr = target.getAsExpr();
647+
assert(subjectExpr && "Must have a subject expression here");
648+
}
649+
650+
// Generate constraints and capture variables for all of the cases.
651+
SmallVector<std::pair<CaseStmt *, VarDecl *>, 4> capturedCaseVars;
652+
for (auto *caseStmt : switchStmt->getCases()) {
653+
if (auto capturedCaseVar = visitCaseStmt(caseStmt, subjectExpr)) {
654+
capturedCaseVars.push_back({caseStmt, capturedCaseVar});
655+
}
656+
}
657+
658+
if (!cs)
659+
return nullptr;
660+
661+
// Form the expressions that inject the result of each case into the
662+
// appropriate
663+
llvm::TinyPtrVector<Expr *> injectedCaseExprs;
664+
SmallVector<std::pair<Type, ConstraintLocator *>, 4> injectedCaseTerms;
665+
for (unsigned idx : indices(capturedCaseVars)) {
666+
auto caseStmt = capturedCaseVars[idx].first;
667+
auto caseVar = capturedCaseVars[idx].second;
668+
669+
// Build the expression that injects the case variable into appropriate
670+
// buildEither(first:)/buildEither(second:) chain.
671+
Expr *caseVarRef = buildVarRef(caseVar, caseStmt->getEndLoc());
672+
Expr *injectedCaseExpr = buildWrappedChainPayload(
673+
caseVarRef, idx, capturedCaseVars.size(), /*isOptional=*/false);
674+
675+
// Generate constraints for this injected case result.
676+
injectedCaseExpr = cs->generateConstraints(injectedCaseExpr, dc);
677+
if (!injectedCaseExpr) {
678+
hadError = true;
679+
return nullptr;
680+
}
681+
682+
// Record this injected case expression.
683+
injectedCaseExprs.push_back(injectedCaseExpr);
684+
685+
// Record the type and locator for this injected case expression, to be
686+
// used in the "join" constraint later.
687+
injectedCaseTerms.push_back(
688+
{ cs->getType(injectedCaseExpr)->getRValueType(),
689+
cs->getConstraintLocator(injectedCaseExpr) });
690+
}
691+
692+
// Form the type of the switch itself.
693+
// FIXME: Need a locator for the "switch" statement.
694+
Type resultType = cs->addJoinConstraint(nullptr, injectedCaseTerms);
695+
if (!resultType) {
696+
hadError = true;
697+
return nullptr;
698+
}
699+
700+
// Create a variable to capture the result of evaluating the switch.
701+
auto switchVar = buildVar(switchStmt->getStartLoc());
702+
cs->setType(switchVar, resultType);
703+
applied.capturedStmts.insert(
704+
{switchStmt, { switchVar, std::move(injectedCaseExprs) } });
705+
return switchVar;
706+
}
707+
708+
VarDecl *visitCaseStmt(CaseStmt *caseStmt, Expr *subjectExpr) {
709+
// If needed, generate constraints for everything in the case statement.
710+
if (cs) {
711+
auto locator = cs->getConstraintLocator(
712+
subjectExpr, LocatorPathElt::ContextualType());
713+
Type subjectType = cs->getType(subjectExpr);
714+
715+
if (cs->generateConstraints(caseStmt, dc, subjectType, locator)) {
716+
hadError = true;
717+
return nullptr;
718+
}
719+
}
720+
721+
// Translate the body.
722+
return visit(caseStmt->getBody());
723+
}
724+
629725
CONTROL_FLOW_STMT(Guard)
630726
CONTROL_FLOW_STMT(While)
631727
CONTROL_FLOW_STMT(DoCatch)
632728
CONTROL_FLOW_STMT(RepeatWhile)
633729
CONTROL_FLOW_STMT(ForEach)
634-
CONTROL_FLOW_STMT(Switch)
635730
CONTROL_FLOW_STMT(Case)
636731
CONTROL_FLOW_STMT(Catch)
637732
CONTROL_FLOW_STMT(Break)
@@ -1000,6 +1095,63 @@ class BuilderClosureRewriter
10001095
return doStmt;
10011096
}
10021097

1098+
Stmt *visitSwitchStmt(SwitchStmt *switchStmt, FunctionBuilderTarget target) {
1099+
// Translate the subject expression.
1100+
ConstraintSystem &cs = solution.getConstraintSystem();
1101+
auto subjectTarget =
1102+
rewriteTarget(*cs.getSolutionApplicationTarget(switchStmt));
1103+
if (!subjectTarget)
1104+
return nullptr;
1105+
1106+
switchStmt->setSubjectExpr(subjectTarget->getAsExpr());
1107+
1108+
// Handle any declaration nodes within the case list first; we'll
1109+
// handle the cases in a second pass.
1110+
for (auto child : switchStmt->getRawCases()) {
1111+
if (auto decl = child.dyn_cast<Decl *>()) {
1112+
TypeChecker::typeCheckDecl(decl);
1113+
}
1114+
}
1115+
1116+
// Translate all of the cases.
1117+
assert(target.kind == FunctionBuilderTarget::TemporaryVar);
1118+
auto temporaryVar = target.captured.first;
1119+
unsigned caseIndex = 0;
1120+
for (auto caseStmt : switchStmt->getCases()) {
1121+
if (!visitCaseStmt(
1122+
caseStmt,
1123+
FunctionBuilderTarget::forAssign(
1124+
temporaryVar, {target.captured.second[caseIndex]})))
1125+
return nullptr;
1126+
1127+
++caseIndex;
1128+
}
1129+
1130+
return switchStmt;
1131+
}
1132+
1133+
Stmt *visitCaseStmt(CaseStmt *caseStmt, FunctionBuilderTarget target) {
1134+
// Translate the patterns and guard expressions for each case label item.
1135+
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
1136+
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
1137+
if (!rewriteTarget(caseLabelTarget))
1138+
return nullptr;
1139+
}
1140+
1141+
// Transform the body of the case.
1142+
auto body = cast<BraceStmt>(caseStmt->getBody());
1143+
auto captured = takeCapturedStmt(body);
1144+
auto newInnerBody = cast<BraceStmt>(
1145+
visitBraceStmt(
1146+
body,
1147+
target,
1148+
FunctionBuilderTarget::forAssign(
1149+
captured.first, {captured.second.front()})));
1150+
caseStmt->setBody(newInnerBody);
1151+
1152+
return caseStmt;
1153+
}
1154+
10031155
#define UNHANDLED_FUNCTION_BUILDER_STMT(STMT) \
10041156
Stmt *visit##STMT##Stmt(STMT##Stmt *stmt, FunctionBuilderTarget target) { \
10051157
llvm_unreachable("Function builders do not allow statement of kind " \
@@ -1014,8 +1166,6 @@ class BuilderClosureRewriter
10141166
UNHANDLED_FUNCTION_BUILDER_STMT(DoCatch)
10151167
UNHANDLED_FUNCTION_BUILDER_STMT(RepeatWhile)
10161168
UNHANDLED_FUNCTION_BUILDER_STMT(ForEach)
1017-
UNHANDLED_FUNCTION_BUILDER_STMT(Switch)
1018-
UNHANDLED_FUNCTION_BUILDER_STMT(Case)
10191169
UNHANDLED_FUNCTION_BUILDER_STMT(Catch)
10201170
UNHANDLED_FUNCTION_BUILDER_STMT(Break)
10211171
UNHANDLED_FUNCTION_BUILDER_STMT(Continue)

lib/Sema/CSApply.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7461,7 +7461,7 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
74617461

74627462
case StmtConditionElement::CK_PatternBinding: {
74637463
ConstraintSystem &cs = solution.getConstraintSystem();
7464-
auto target = *cs.getStmtConditionTarget(&condElement);
7464+
auto target = *cs.getSolutionApplicationTarget(&condElement);
74657465
auto resolvedTarget = rewriteTarget(target);
74667466
if (!resolvedTarget)
74677467
return None;
@@ -7474,6 +7474,45 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
74747474
}
74757475
}
74767476

7477+
return target;
7478+
} else if (auto caseLabelItem = target.getAsCaseLabelItem()) {
7479+
ConstraintSystem &cs = solution.getConstraintSystem();
7480+
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
7481+
7482+
// Figure out the pattern type.
7483+
Type patternType = solution.simplifyType(solution.getType(info.pattern));
7484+
patternType = patternType->reconstituteSugar(/*recursive=*/false);
7485+
7486+
// Coerce the pattern to its appropriate type.
7487+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
7488+
patternOptions |= TypeResolutionFlags::OverrideType;
7489+
auto contextualPattern =
7490+
ContextualPattern::forRawPattern(info.pattern,
7491+
target.getDeclContext());
7492+
if (auto coercedPattern = TypeChecker::coercePatternToType(
7493+
contextualPattern, patternType, patternOptions)) {
7494+
(*caseLabelItem)->setPattern(coercedPattern);
7495+
} else {
7496+
return None;
7497+
}
7498+
7499+
// If there is a guard expression, coerce that.
7500+
if (auto guardExpr = info.guardExpr) {
7501+
guardExpr = guardExpr->walk(*this);
7502+
if (!guardExpr)
7503+
return None;
7504+
7505+
// FIXME: Feels like we could leverage existing code more.
7506+
Type boolType = cs.getASTContext().getBoolDecl()->getDeclaredType();
7507+
guardExpr = solution.coerceToType(
7508+
guardExpr, boolType, cs.getConstraintLocator(info.guardExpr));
7509+
if (!guardExpr)
7510+
return None;
7511+
7512+
(*caseLabelItem)->setGuardExpr(guardExpr);
7513+
solution.setExprTypes(guardExpr);
7514+
}
7515+
74777516
return target;
74787517
} else {
74797518
auto fn = *target.getAsFunction();
@@ -7748,5 +7787,17 @@ SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
77487787
condElement = *condElement.walk(walker);
77497788
}
77507789
return *this;
7790+
7791+
case Kind::caseLabelItem:
7792+
if (auto newPattern =
7793+
caseLabelItem.caseLabelItem->getPattern()->walk(walker)) {
7794+
caseLabelItem.caseLabelItem->setPattern(newPattern);
7795+
}
7796+
if (auto guardExpr = caseLabelItem.caseLabelItem->getGuardExpr()) {
7797+
if (auto newGuardExpr = guardExpr->walk(walker))
7798+
caseLabelItem.caseLabelItem->setGuardExpr(newGuardExpr);
7799+
}
7800+
7801+
return *this;
77517802
}
77527803
}

lib/Sema/CSGen.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4218,7 +4218,7 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition,
42184218
if (generateConstraints(target, FreeTypeVariableBinding::Disallow))
42194219
return true;
42204220

4221-
setStmtConditionTarget(&condElement, target);
4221+
setSolutionApplicationTarget(&condElement, target);
42224222
continue;
42234223
}
42244224
}
@@ -4227,6 +4227,62 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition,
42274227
return false;
42284228
}
42294229

4230+
bool ConstraintSystem::generateConstraints(
4231+
CaseStmt *caseStmt, DeclContext *dc, Type subjectType,
4232+
ConstraintLocator *locator) {
4233+
// Pre-bind all of the pattern variables within the case.
4234+
bindSwitchCasePatternVars(caseStmt);
4235+
4236+
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
4237+
// Resolve the pattern.
4238+
auto *pattern = TypeChecker::resolvePattern(
4239+
caseLabelItem.getPattern(), dc, /*isStmtCondition=*/false);
4240+
if (!pattern)
4241+
return true;
4242+
4243+
// Generate constraints for the pattern, including one-way bindings for
4244+
// any variables that show up in this pattern, because those variables
4245+
// can be referenced in the guard expressions and the body.
4246+
Type patternType = generateConstraints(
4247+
pattern, locator, /* bindPatternVarsOneWay=*/true);
4248+
4249+
// Convert the subject type to the pattern, which establishes the
4250+
// bindings.
4251+
addConstraint(
4252+
ConstraintKind::Conversion, subjectType, patternType, locator);
4253+
4254+
// Generate constraints for the guard expression, if there is one.
4255+
Expr *guardExpr = caseLabelItem.getGuardExpr();
4256+
if (guardExpr) {
4257+
guardExpr = generateConstraints(guardExpr, dc);
4258+
if (!guardExpr)
4259+
return true;
4260+
}
4261+
4262+
// Save this info.
4263+
setCaseLabelItemInfo(&caseLabelItem, {pattern, guardExpr});
4264+
4265+
// For any pattern variable that has a parent variable (i.e., another
4266+
// pattern variable with the same name in the same case), require that
4267+
// the types be equivalent.
4268+
pattern->forEachVariable([&](VarDecl *var) {
4269+
if (auto parentVar = var->getParentVarDecl()) {
4270+
addConstraint(
4271+
ConstraintKind::Equal, getType(parentVar), getType(var), locator);
4272+
}
4273+
});
4274+
}
4275+
4276+
// Bind the types of the case body variables.
4277+
for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
4278+
auto parentVar = caseBodyVar->getParentVarDecl();
4279+
assert(parentVar && "Case body variables always have parents");
4280+
setType(caseBodyVar, getType(parentVar));
4281+
}
4282+
4283+
return false;
4284+
}
4285+
42304286
void ConstraintSystem::optimizeConstraints(Expr *e) {
42314287
if (getASTContext().TypeCheckerOpts.DisableConstraintSolverPerformanceHacks)
42324288
return;

lib/Sema/CSSolver.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ Solution ConstraintSystem::finalize() {
170170
solution.contextualTypes.assign(
171171
contextualTypes.begin(), contextualTypes.end());
172172

173-
solution.stmtConditionTargets = stmtConditionTargets;
173+
solution.solutionApplicationTargets = solutionApplicationTargets;
174+
solution.caseLabelItems = caseLabelItems;
174175

175176
for (auto &e : CheckedConformances)
176177
solution.Conformances.push_back({e.first, e.second});
@@ -244,9 +245,15 @@ void ConstraintSystem::applySolution(const Solution &solution) {
244245
}
245246

246247
// Register the statement condition targets.
247-
for (const auto &stmtConditionTarget : solution.stmtConditionTargets) {
248-
if (!getStmtConditionTarget(stmtConditionTarget.first))
249-
setStmtConditionTarget(stmtConditionTarget.first, stmtConditionTarget.second);
248+
for (const auto &target : solution.solutionApplicationTargets) {
249+
if (!getSolutionApplicationTarget(target.first))
250+
setSolutionApplicationTarget(target.first, target.second);
251+
}
252+
253+
// Register the statement condition targets.
254+
for (const auto &info : solution.caseLabelItems) {
255+
if (!getCaseLabelItemInfo(info.first))
256+
setCaseLabelItemInfo(info.first, info.second);
250257
}
251258

252259
// Register the conformances checked along the way to arrive to solution.
@@ -354,6 +361,13 @@ void truncate(llvm::MapVector<K, V> &map, unsigned newSize) {
354361
map.pop_back();
355362
}
356363

364+
template <typename K, typename V, unsigned N>
365+
void truncate(llvm::SmallMapVector<K, V, N> &map, unsigned newSize) {
366+
assert(newSize <= map.size() && "Not a truncation!");
367+
for (unsigned i = 0, n = map.size() - newSize; i != n; ++i)
368+
map.pop_back();
369+
}
370+
357371
} // end anonymous namespace
358372

359373
ConstraintSystem::SolverState::SolverState(
@@ -461,7 +475,8 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
461475
numResolvedOverloads = cs.ResolvedOverloads.size();
462476
numInferredClosureTypes = cs.ClosureTypes.size();
463477
numContextualTypes = cs.contextualTypes.size();
464-
numStmtConditionTargets = cs.stmtConditionTargets.size();
478+
numSolutionApplicationTargets = cs.solutionApplicationTargets.size();
479+
numCaseLabelItems = cs.caseLabelItems.size();
465480

466481
PreviousScore = cs.CurrentScore;
467482

@@ -539,8 +554,11 @@ ConstraintSystem::SolverScope::~SolverScope() {
539554
// Remove any contextual types.
540555
truncate(cs.contextualTypes, numContextualTypes);
541556

542-
// Remove any statement condition types.
543-
truncate(cs.stmtConditionTargets, numStmtConditionTargets);
557+
// Remove any solution application targets.
558+
truncate(cs.solutionApplicationTargets, numSolutionApplicationTargets);
559+
560+
// Remove any case label item infos.
561+
truncate(cs.caseLabelItems, numCaseLabelItems);
544562

545563
// Reset the previous score.
546564
cs.CurrentScore = PreviousScore;

0 commit comments

Comments
 (0)