Skip to content

Commit 6999c31

Browse files
authored
Merge pull request swiftlang#30924 from DougGregor/for-each-solution-application-target
[Constraint solver] Migrate for-each statement checking into SolutionApplicationTarget
2 parents e9ed2d5 + 1f232f7 commit 6999c31

File tree

7 files changed

+395
-339
lines changed

7 files changed

+395
-339
lines changed

lib/Sema/CSApply.cpp

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7887,7 +7887,7 @@ bool ConstraintSystem::applySolutionFixes(const Solution &solution) {
78877887

78887888
/// Apply the given solution to the initialization target.
78897889
///
7890-
/// \returns the resulting initialiation expression.
7890+
/// \returns the resulting initialization expression.
78917891
static Optional<SolutionApplicationTarget> applySolutionToInitialization(
78927892
Solution &solution, SolutionApplicationTarget target,
78937893
Expr *initializer) {
@@ -7951,7 +7951,7 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
79517951
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
79527952

79537953
// Apply the solution to the pattern as well.
7954-
auto contextualPattern = target.getInitializationContextualPattern();
7954+
auto contextualPattern = target.getContextualPattern();
79557955
if (auto coercedPattern = TypeChecker::coercePatternToType(
79567956
contextualPattern, finalPatternType, options)) {
79577957
resultTarget.setPattern(coercedPattern);
@@ -7962,6 +7962,139 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
79627962
return resultTarget;
79637963
}
79647964

7965+
/// Apply the given solution to the for-each statement target.
7966+
///
7967+
/// \returns the resulting initialization expression.
7968+
static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
7969+
Solution &solution, SolutionApplicationTarget target, Expr *sequence) {
7970+
auto resultTarget = target;
7971+
auto &forEachStmtInfo = resultTarget.getForEachStmtInfo();
7972+
7973+
// Simplify the various types.
7974+
forEachStmtInfo.elementType =
7975+
solution.simplifyType(forEachStmtInfo.elementType);
7976+
forEachStmtInfo.iteratorType =
7977+
solution.simplifyType(forEachStmtInfo.iteratorType);
7978+
forEachStmtInfo.initType =
7979+
solution.simplifyType(forEachStmtInfo.initType);
7980+
forEachStmtInfo.sequenceType =
7981+
solution.simplifyType(forEachStmtInfo.sequenceType);
7982+
7983+
// Coerce the sequence to the sequence type.
7984+
auto &cs = solution.getConstraintSystem();
7985+
auto locator = cs.getConstraintLocator(target.getAsExpr());
7986+
sequence = solution.coerceToType(
7987+
sequence, forEachStmtInfo.sequenceType, locator);
7988+
if (!sequence)
7989+
return None;
7990+
7991+
resultTarget.setExpr(sequence);
7992+
7993+
// Get the conformance of the sequence type to the Sequence protocol.
7994+
auto stmt = forEachStmtInfo.stmt;
7995+
auto sequenceProto = TypeChecker::getProtocol(
7996+
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
7997+
auto contextualLocator = cs.getConstraintLocator(
7998+
target.getAsExpr(), LocatorPathElt::ContextualType());
7999+
auto sequenceConformance = solution.resolveConformance(
8000+
contextualLocator, sequenceProto);
8001+
assert(!sequenceConformance.isInvalid() &&
8002+
"Couldn't find sequence conformance");
8003+
8004+
// Coerce the pattern to the element type.
8005+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
8006+
options |= TypeResolutionFlags::OverrideType;
8007+
8008+
// Apply the solution to the pattern as well.
8009+
auto contextualPattern = target.getContextualPattern();
8010+
if (auto coercedPattern = TypeChecker::coercePatternToType(
8011+
contextualPattern, forEachStmtInfo.initType, options)) {
8012+
resultTarget.setPattern(coercedPattern);
8013+
} else {
8014+
return None;
8015+
}
8016+
8017+
// Apply the solution to the filtering condition, if there is one.
8018+
auto dc = target.getDeclContext();
8019+
if (forEachStmtInfo.whereExpr) {
8020+
auto *boolDecl = dc->getASTContext().getBoolDecl();
8021+
assert(boolDecl);
8022+
Type boolType = boolDecl->getDeclaredType();
8023+
assert(boolType);
8024+
8025+
SolutionApplicationTarget whereTarget(
8026+
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
8027+
/*isDiscarded=*/false);
8028+
auto newWhereTarget = cs.applySolution(solution, whereTarget);
8029+
if (!newWhereTarget)
8030+
return None;
8031+
8032+
forEachStmtInfo.whereExpr = newWhereTarget->getAsExpr();
8033+
}
8034+
8035+
// Invoke iterator() to get an iterator from the sequence.
8036+
ASTContext &ctx = cs.getASTContext();
8037+
VarDecl *iterator;
8038+
Type nextResultType = OptionalType::get(forEachStmtInfo.elementType);
8039+
{
8040+
// Create a local variable to capture the iterator.
8041+
std::string name;
8042+
if (auto np = dyn_cast_or_null<NamedPattern>(stmt->getPattern()))
8043+
name = "$"+np->getBoundName().str().str();
8044+
name += "$generator";
8045+
8046+
iterator = new (ctx) VarDecl(
8047+
/*IsStatic*/ false, VarDecl::Introducer::Var,
8048+
/*IsCaptureList*/ false, stmt->getInLoc(),
8049+
ctx.getIdentifier(name), dc);
8050+
iterator->setInterfaceType(
8051+
forEachStmtInfo.iteratorType->mapTypeOutOfContext());
8052+
iterator->setImplicit();
8053+
8054+
auto genPat = new (ctx) NamedPattern(iterator);
8055+
genPat->setImplicit();
8056+
8057+
// TODO: test/DebugInfo/iteration.swift requires this extra info to
8058+
// be around.
8059+
PatternBindingDecl::createImplicit(
8060+
ctx, StaticSpellingKind::None, genPat,
8061+
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType),
8062+
dc, /*VarLoc*/ stmt->getForLoc());
8063+
}
8064+
8065+
// Create the iterator variable.
8066+
auto *varRef = TypeChecker::buildCheckedRefExpr(
8067+
iterator, dc, DeclNameLoc(stmt->getInLoc()), /*implicit*/ true);
8068+
8069+
// Convert that Optional<Element> value to the type of the pattern.
8070+
auto optPatternType = OptionalType::get(forEachStmtInfo.initType);
8071+
if (!optPatternType->isEqual(nextResultType)) {
8072+
OpaqueValueExpr *elementExpr =
8073+
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType,
8074+
/*isPlaceholder=*/true);
8075+
Expr *convertElementExpr = elementExpr;
8076+
if (TypeChecker::typeCheckExpression(
8077+
convertElementExpr, dc,
8078+
TypeLoc::withoutLoc(optPatternType),
8079+
CTP_CoerceOperand).isNull()) {
8080+
return None;
8081+
}
8082+
elementExpr->setIsPlaceholder(false);
8083+
stmt->setElementExpr(elementExpr);
8084+
stmt->setConvertElementExpr(convertElementExpr);
8085+
}
8086+
8087+
// Write the result back into the AST.
8088+
stmt->setSequence(resultTarget.getAsExpr());
8089+
stmt->setPattern(resultTarget.getContextualPattern().getPattern());
8090+
stmt->setSequenceConformance(sequenceConformance);
8091+
stmt->setWhere(forEachStmtInfo.whereExpr);
8092+
stmt->setIteratorVar(iterator);
8093+
stmt->setIteratorVarRef(varRef);
8094+
8095+
return resultTarget;
8096+
}
8097+
79658098
Optional<SolutionApplicationTarget>
79668099
ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
79678100
auto &solution = Rewriter.solution;
@@ -7973,16 +8106,50 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
79738106
if (!rewrittenExpr)
79748107
return None;
79758108

7976-
/// Handle application for initializations.
7977-
if (target.getExprContextualTypePurpose() == CTP_Initialization) {
8109+
/// Handle special cases for expressions.
8110+
switch (target.getExprContextualTypePurpose()) {
8111+
case CTP_Initialization: {
79788112
auto initResultTarget = applySolutionToInitialization(
79798113
solution, target, rewrittenExpr);
79808114
if (!initResultTarget)
79818115
return None;
79828116

79838117
result = *initResultTarget;
7984-
} else {
8118+
break;
8119+
}
8120+
8121+
case CTP_ForEachStmt: {
8122+
auto forEachResultTarget = applySolutionToForEachStmt(
8123+
solution, target, rewrittenExpr);
8124+
if (!forEachResultTarget)
8125+
return None;
8126+
8127+
result = *forEachResultTarget;
8128+
break;
8129+
}
8130+
8131+
case CTP_Unused:
8132+
case CTP_ReturnStmt:
8133+
case swift::CTP_ReturnSingleExpr:
8134+
case swift::CTP_YieldByValue:
8135+
case swift::CTP_YieldByReference:
8136+
case swift::CTP_ThrowStmt:
8137+
case swift::CTP_EnumCaseRawValue:
8138+
case swift::CTP_DefaultParameter:
8139+
case swift::CTP_AutoclosureDefaultParameter:
8140+
case swift::CTP_CalleeResult:
8141+
case swift::CTP_CallArgument:
8142+
case swift::CTP_ClosureResult:
8143+
case swift::CTP_ArrayElement:
8144+
case swift::CTP_DictionaryKey:
8145+
case swift::CTP_DictionaryValue:
8146+
case swift::CTP_CoerceOperand:
8147+
case swift::CTP_AssignSource:
8148+
case swift::CTP_SubscriptAssignSource:
8149+
case swift::CTP_Condition:
8150+
case swift::CTP_CannotFail:
79858151
result.setExpr(rewrittenExpr);
8152+
break;
79868153
}
79878154
} else if (auto stmtCondition = target.getAsStmtCondition()) {
79888155
for (auto &condElement : *stmtCondition) {

lib/Sema/CSGen.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4133,6 +4133,112 @@ static bool generateInitPatternConstraints(
41334133
return false;
41344134
}
41354135

4136+
/// Generate constraints for a for-each statement.
4137+
static Optional<SolutionApplicationTarget>
4138+
generateForEachStmtConstraints(
4139+
ConstraintSystem &cs, SolutionApplicationTarget target, Expr *sequence) {
4140+
auto forEachStmtInfo = target.getForEachStmtInfo();
4141+
ForEachStmt *stmt = forEachStmtInfo.stmt;
4142+
4143+
auto locator = cs.getConstraintLocator(sequence);
4144+
auto contextualLocator =
4145+
cs.getConstraintLocator(sequence, LocatorPathElt::ContextualType());
4146+
4147+
// The expression type must conform to the Sequence protocol.
4148+
auto sequenceProto = TypeChecker::getProtocol(
4149+
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
4150+
if (!sequenceProto) {
4151+
return None;
4152+
}
4153+
4154+
Type sequenceType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
4155+
cs.addConstraint(ConstraintKind::Conversion, cs.getType(sequence),
4156+
sequenceType, locator);
4157+
cs.addConstraint(ConstraintKind::ConformsTo, sequenceType,
4158+
sequenceProto->getDeclaredType(), contextualLocator);
4159+
4160+
// Check the element pattern.
4161+
ASTContext &ctx = cs.getASTContext();
4162+
auto dc = target.getDeclContext();
4163+
Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc,
4164+
/*isStmtCondition*/false);
4165+
if (!pattern)
4166+
return None;
4167+
4168+
auto contextualPattern =
4169+
ContextualPattern::forRawPattern(pattern, dc);
4170+
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
4171+
if (patternType->hasError()) {
4172+
return None;
4173+
}
4174+
4175+
// Collect constraints from the element pattern.
4176+
auto elementLocator = cs.getConstraintLocator(
4177+
contextualLocator, ConstraintLocator::SequenceElementType);
4178+
Type initType = cs.generateConstraints(
4179+
pattern, contextualLocator, target.shouldBindPatternVarsOneWay(),
4180+
nullptr, 0);
4181+
if (!initType)
4182+
return None;
4183+
4184+
// Add a conversion constraint between the element type of the sequence
4185+
// and the type of the element pattern.
4186+
auto elementAssocType =
4187+
sequenceProto->getAssociatedType(cs.getASTContext().Id_Element);
4188+
Type elementType = DependentMemberType::get(sequenceType, elementAssocType);
4189+
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4190+
elementLocator);
4191+
4192+
// Determine the iterator type.
4193+
auto iteratorAssocType =
4194+
sequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator);
4195+
Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType);
4196+
4197+
// The iterator type must conform to IteratorProtocol.
4198+
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
4199+
cs.getASTContext(), stmt->getForLoc(),
4200+
KnownProtocolKind::IteratorProtocol);
4201+
if (!iteratorProto)
4202+
return None;
4203+
4204+
// Reference the makeIterator witness.
4205+
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
4206+
Type makeIteratorType =
4207+
cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
4208+
cs.addValueWitnessConstraint(
4209+
LValueType::get(sequenceType), makeIterator,
4210+
makeIteratorType, dc, FunctionRefKind::Compound,
4211+
contextualLocator);
4212+
4213+
// Generate constraints for the "where" expression, if there is one.
4214+
if (forEachStmtInfo.whereExpr) {
4215+
auto *boolDecl = dc->getASTContext().getBoolDecl();
4216+
if (!boolDecl)
4217+
return None;
4218+
4219+
Type boolType = boolDecl->getDeclaredType();
4220+
if (!boolType)
4221+
return None;
4222+
4223+
SolutionApplicationTarget whereTarget(
4224+
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
4225+
/*isDiscarded=*/false);
4226+
if (cs.generateConstraints(whereTarget, FreeTypeVariableBinding::Disallow))
4227+
return None;
4228+
4229+
forEachStmtInfo.whereExpr = whereTarget.getAsExpr();
4230+
}
4231+
4232+
// Populate all of the information for a for-each loop.
4233+
forEachStmtInfo.elementType = elementType;
4234+
forEachStmtInfo.iteratorType = iteratorType;
4235+
forEachStmtInfo.initType = initType;
4236+
forEachStmtInfo.sequenceType = sequenceType;
4237+
target.setPattern(pattern);
4238+
target.getForEachStmtInfo() = forEachStmtInfo;
4239+
return target;
4240+
}
4241+
41364242
bool ConstraintSystem::generateConstraints(
41374243
SolutionApplicationTarget &target,
41384244
FreeTypeVariableBinding allowFreeTypeVariables) {
@@ -4191,6 +4297,16 @@ bool ConstraintSystem::generateConstraints(
41914297
return true;
41924298
}
41934299

4300+
// For a for-each statement, generate constraints for the pattern, where
4301+
// clause, and sequence traversal.
4302+
if (target.getExprContextualTypePurpose() == CTP_ForEachStmt) {
4303+
auto resultTarget = generateForEachStmtConstraints(*this, target, expr);
4304+
if (!resultTarget)
4305+
return true;
4306+
4307+
target = *resultTarget;
4308+
}
4309+
41944310
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
41954311
auto &log = getASTContext().TypeCheckerDebug->getStream();
41964312
log << "---Initial constraints for the given expression---\n";

lib/Sema/CSSolver.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,6 @@ static bool debugConstraintSolverForTarget(
11411141

11421142
Optional<std::vector<Solution>> ConstraintSystem::solve(
11431143
SolutionApplicationTarget &target,
1144-
ExprTypeCheckListener *listener,
11451144
FreeTypeVariableBinding allowFreeTypeVariables
11461145
) {
11471146
llvm::SaveAndRestore<bool> debugForExpr(
@@ -1171,7 +1170,7 @@ Optional<std::vector<Solution>> ConstraintSystem::solve(
11711170
// when there is an error and attempts to salvage an ill-formed program.
11721171
for (unsigned stage = 0; stage != 2; ++stage) {
11731172
auto solution = (stage == 0)
1174-
? solveImpl(target, listener, allowFreeTypeVariables)
1173+
? solveImpl(target, allowFreeTypeVariables)
11751174
: salvage();
11761175

11771176
switch (solution.getKind()) {
@@ -1237,7 +1236,6 @@ Optional<std::vector<Solution>> ConstraintSystem::solve(
12371236

12381237
SolutionResult
12391238
ConstraintSystem::solveImpl(SolutionApplicationTarget &target,
1240-
ExprTypeCheckListener *listener,
12411239
FreeTypeVariableBinding allowFreeTypeVariables) {
12421240
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
12431241
auto &log = getASTContext().TypeCheckerDebug->getStream();
@@ -1260,13 +1258,6 @@ ConstraintSystem::solveImpl(SolutionApplicationTarget &target,
12601258
if (generateConstraints(target, allowFreeTypeVariables))
12611259
return SolutionResult::forError();;
12621260

1263-
// Notify the listener that we've built the constraint system.
1264-
if (Expr *expr = target.getAsExpr()) {
1265-
if (listener && listener->builtConstraints(*this, expr)) {
1266-
return SolutionResult::forError();
1267-
}
1268-
}
1269-
12701261
// Try to solve the constraint system using computed suggestions.
12711262
SmallVector<Solution, 4> solutions;
12721263
solve(solutions, allowFreeTypeVariables);

0 commit comments

Comments
 (0)