Skip to content

Commit 87d86f3

Browse files
committed
[Constraint solver] Migrate for-each statement checking into SolutionApplicationTarget.
Pull the entirety of type checking for for-each statement headers (i.e., not the body) into the constraint system, using the normal SolutionApplicationTarget-based constraint generation and application facilities. Most of this was already handled in the constraint solver (although the `where` filtering condition was not), so this is a smaller change than it looks like.
1 parent 31e7873 commit 87d86f3

File tree

5 files changed

+384
-256
lines changed

5 files changed

+384
-256
lines changed

lib/Sema/CSApply.cpp

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7886,7 +7886,7 @@ bool ConstraintSystem::applySolutionFixes(const Solution &solution) {
78867886

78877887
/// Apply the given solution to the initialization target.
78887888
///
7889-
/// \returns the resulting initialiation expression.
7889+
/// \returns the resulting initialization expression.
78907890
static Optional<SolutionApplicationTarget> applySolutionToInitialization(
78917891
Solution &solution, SolutionApplicationTarget target,
78927892
Expr *initializer) {
@@ -7950,7 +7950,7 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
79507950
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
79517951

79527952
// Apply the solution to the pattern as well.
7953-
auto contextualPattern = target.getInitializationContextualPattern();
7953+
auto contextualPattern = target.getContextualPattern();
79547954
if (auto coercedPattern = TypeChecker::coercePatternToType(
79557955
contextualPattern, finalPatternType, options)) {
79567956
resultTarget.setPattern(coercedPattern);
@@ -7961,6 +7961,139 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
79617961
return resultTarget;
79627962
}
79637963

7964+
/// Apply the given solution to the for-each statement target.
7965+
///
7966+
/// \returns the resulting initialization expression.
7967+
static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
7968+
Solution &solution, SolutionApplicationTarget target, Expr *sequence) {
7969+
auto resultTarget = target;
7970+
auto &forEachStmtInfo = resultTarget.getForEachStmtInfo();
7971+
7972+
// Simplify the various types.
7973+
forEachStmtInfo.elementType =
7974+
solution.simplifyType(forEachStmtInfo.elementType);
7975+
forEachStmtInfo.iteratorType =
7976+
solution.simplifyType(forEachStmtInfo.iteratorType);
7977+
forEachStmtInfo.initType =
7978+
solution.simplifyType(forEachStmtInfo.initType);
7979+
forEachStmtInfo.sequenceType =
7980+
solution.simplifyType(forEachStmtInfo.sequenceType);
7981+
7982+
// Coerce the sequence to the sequence type.
7983+
auto &cs = solution.getConstraintSystem();
7984+
auto locator = cs.getConstraintLocator(target.getAsExpr());
7985+
sequence = solution.coerceToType(
7986+
sequence, forEachStmtInfo.sequenceType, locator);
7987+
if (!sequence)
7988+
return None;
7989+
7990+
resultTarget.setExpr(sequence);
7991+
7992+
// Get the conformance of the sequence type to the Sequence protocol.
7993+
auto stmt = forEachStmtInfo.stmt;
7994+
auto sequenceProto = TypeChecker::getProtocol(
7995+
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
7996+
auto contextualLocator = cs.getConstraintLocator(
7997+
target.getAsExpr(), LocatorPathElt::ContextualType());
7998+
auto sequenceConformance = solution.resolveConformance(
7999+
contextualLocator, sequenceProto);
8000+
assert(!sequenceConformance.isInvalid() &&
8001+
"Couldn't find sequence conformance");
8002+
8003+
// Coerce the pattern to the element type.
8004+
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
8005+
options |= TypeResolutionFlags::OverrideType;
8006+
8007+
// Apply the solution to the pattern as well.
8008+
auto contextualPattern = target.getContextualPattern();
8009+
if (auto coercedPattern = TypeChecker::coercePatternToType(
8010+
contextualPattern, forEachStmtInfo.initType, options)) {
8011+
resultTarget.setPattern(coercedPattern);
8012+
} else {
8013+
return None;
8014+
}
8015+
8016+
// Apply the solution to the filtering condition, if there is one.
8017+
auto dc = target.getDeclContext();
8018+
if (forEachStmtInfo.whereExpr) {
8019+
auto *boolDecl = dc->getASTContext().getBoolDecl();
8020+
assert(boolDecl);
8021+
Type boolType = boolDecl->getDeclaredType();
8022+
assert(boolType);
8023+
8024+
SolutionApplicationTarget whereTarget(
8025+
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
8026+
/*isDiscarded=*/false);
8027+
auto newWhereTarget = cs.applySolution(solution, whereTarget);
8028+
if (!newWhereTarget)
8029+
return None;
8030+
8031+
forEachStmtInfo.whereExpr = newWhereTarget->getAsExpr();
8032+
}
8033+
8034+
// Invoke iterator() to get an iterator from the sequence.
8035+
ASTContext &ctx = cs.getASTContext();
8036+
VarDecl *iterator;
8037+
Type nextResultType = OptionalType::get(forEachStmtInfo.elementType);
8038+
{
8039+
// Create a local variable to capture the iterator.
8040+
std::string name;
8041+
if (auto np = dyn_cast_or_null<NamedPattern>(stmt->getPattern()))
8042+
name = "$"+np->getBoundName().str().str();
8043+
name += "$generator";
8044+
8045+
iterator = new (ctx) VarDecl(
8046+
/*IsStatic*/ false, VarDecl::Introducer::Var,
8047+
/*IsCaptureList*/ false, stmt->getInLoc(),
8048+
ctx.getIdentifier(name), dc);
8049+
iterator->setInterfaceType(
8050+
forEachStmtInfo.iteratorType->mapTypeOutOfContext());
8051+
iterator->setImplicit();
8052+
8053+
auto genPat = new (ctx) NamedPattern(iterator);
8054+
genPat->setImplicit();
8055+
8056+
// TODO: test/DebugInfo/iteration.swift requires this extra info to
8057+
// be around.
8058+
PatternBindingDecl::createImplicit(
8059+
ctx, StaticSpellingKind::None, genPat,
8060+
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType),
8061+
dc, /*VarLoc*/ stmt->getForLoc());
8062+
}
8063+
8064+
// Create the iterator variable.
8065+
auto *varRef = TypeChecker::buildCheckedRefExpr(
8066+
iterator, dc, DeclNameLoc(stmt->getInLoc()), /*implicit*/ true);
8067+
8068+
// Convert that Optional<Element> value to the type of the pattern.
8069+
auto optPatternType = OptionalType::get(forEachStmtInfo.initType);
8070+
if (!optPatternType->isEqual(nextResultType)) {
8071+
OpaqueValueExpr *elementExpr =
8072+
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType,
8073+
/*isPlaceholder=*/true);
8074+
Expr *convertElementExpr = elementExpr;
8075+
if (TypeChecker::typeCheckExpression(
8076+
convertElementExpr, dc,
8077+
TypeLoc::withoutLoc(optPatternType),
8078+
CTP_CoerceOperand).isNull()) {
8079+
return None;
8080+
}
8081+
elementExpr->setIsPlaceholder(false);
8082+
stmt->setElementExpr(elementExpr);
8083+
stmt->setConvertElementExpr(convertElementExpr);
8084+
}
8085+
8086+
// Write the result back into the AST.
8087+
stmt->setSequence(resultTarget.getAsExpr());
8088+
stmt->setPattern(resultTarget.getContextualPattern().getPattern());
8089+
stmt->setSequenceConformance(sequenceConformance);
8090+
stmt->setWhere(forEachStmtInfo.whereExpr);
8091+
stmt->setIteratorVar(iterator);
8092+
stmt->setIteratorVarRef(varRef);
8093+
8094+
return resultTarget;
8095+
}
8096+
79648097
Optional<SolutionApplicationTarget>
79658098
ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
79668099
auto &solution = Rewriter.solution;
@@ -7972,16 +8105,50 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
79728105
if (!rewrittenExpr)
79738106
return None;
79748107

7975-
/// Handle application for initializations.
7976-
if (target.getExprContextualTypePurpose() == CTP_Initialization) {
8108+
/// Handle special cases for expressions.
8109+
switch (target.getExprContextualTypePurpose()) {
8110+
case CTP_Initialization: {
79778111
auto initResultTarget = applySolutionToInitialization(
79788112
solution, target, rewrittenExpr);
79798113
if (!initResultTarget)
79808114
return None;
79818115

79828116
result = *initResultTarget;
7983-
} else {
8117+
break;
8118+
}
8119+
8120+
case CTP_ForEachStmt: {
8121+
auto forEachResultTarget = applySolutionToForEachStmt(
8122+
solution, target, rewrittenExpr);
8123+
if (!forEachResultTarget)
8124+
return None;
8125+
8126+
result = *forEachResultTarget;
8127+
break;
8128+
}
8129+
8130+
case CTP_Unused:
8131+
case CTP_ReturnStmt:
8132+
case swift::CTP_ReturnSingleExpr:
8133+
case swift::CTP_YieldByValue:
8134+
case swift::CTP_YieldByReference:
8135+
case swift::CTP_ThrowStmt:
8136+
case swift::CTP_EnumCaseRawValue:
8137+
case swift::CTP_DefaultParameter:
8138+
case swift::CTP_AutoclosureDefaultParameter:
8139+
case swift::CTP_CalleeResult:
8140+
case swift::CTP_CallArgument:
8141+
case swift::CTP_ClosureResult:
8142+
case swift::CTP_ArrayElement:
8143+
case swift::CTP_DictionaryKey:
8144+
case swift::CTP_DictionaryValue:
8145+
case swift::CTP_CoerceOperand:
8146+
case swift::CTP_AssignSource:
8147+
case swift::CTP_SubscriptAssignSource:
8148+
case swift::CTP_Condition:
8149+
case swift::CTP_CannotFail:
79848150
result.setExpr(rewrittenExpr);
8151+
break;
79858152
}
79868153
} else if (auto stmtCondition = target.getAsStmtCondition()) {
79878154
for (auto &condElement : *stmtCondition) {

lib/Sema/CSGen.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4128,6 +4128,112 @@ static bool generateInitPatternConstraints(
41284128
return false;
41294129
}
41304130

4131+
/// Generate constraints for a for-each statement.
4132+
static Optional<SolutionApplicationTarget>
4133+
generateForEachStmtConstraints(
4134+
ConstraintSystem &cs, SolutionApplicationTarget target, Expr *sequence) {
4135+
auto forEachStmtInfo = target.getForEachStmtInfo();
4136+
ForEachStmt *stmt = forEachStmtInfo.stmt;
4137+
4138+
auto locator = cs.getConstraintLocator(sequence);
4139+
auto contextualLocator =
4140+
cs.getConstraintLocator(sequence, LocatorPathElt::ContextualType());
4141+
4142+
// The expression type must conform to the Sequence protocol.
4143+
auto sequenceProto = TypeChecker::getProtocol(
4144+
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
4145+
if (!sequenceProto) {
4146+
return None;
4147+
}
4148+
4149+
Type sequenceType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
4150+
cs.addConstraint(ConstraintKind::Conversion, cs.getType(sequence),
4151+
sequenceType, locator);
4152+
cs.addConstraint(ConstraintKind::ConformsTo, sequenceType,
4153+
sequenceProto->getDeclaredType(), contextualLocator);
4154+
4155+
// Check the element pattern.
4156+
ASTContext &ctx = cs.getASTContext();
4157+
auto dc = target.getDeclContext();
4158+
Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc,
4159+
/*isStmtCondition*/false);
4160+
if (!pattern)
4161+
return None;
4162+
4163+
auto contextualPattern =
4164+
ContextualPattern::forRawPattern(pattern, dc);
4165+
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
4166+
if (patternType->hasError()) {
4167+
return None;
4168+
}
4169+
4170+
// Collect constraints from the element pattern.
4171+
auto elementLocator = cs.getConstraintLocator(
4172+
contextualLocator, ConstraintLocator::SequenceElementType);
4173+
Type initType = cs.generateConstraints(
4174+
pattern, contextualLocator, target.shouldBindPatternVarsOneWay(),
4175+
nullptr, 0);
4176+
if (!initType)
4177+
return None;
4178+
4179+
// Add a conversion constraint between the element type of the sequence
4180+
// and the type of the element pattern.
4181+
auto elementAssocType =
4182+
sequenceProto->getAssociatedType(cs.getASTContext().Id_Element);
4183+
Type elementType = DependentMemberType::get(sequenceType, elementAssocType);
4184+
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4185+
elementLocator);
4186+
4187+
// Determine the iterator type.
4188+
auto iteratorAssocType =
4189+
sequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator);
4190+
Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType);
4191+
4192+
// The iterator type must conform to IteratorProtocol.
4193+
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
4194+
cs.getASTContext(), stmt->getForLoc(),
4195+
KnownProtocolKind::IteratorProtocol);
4196+
if (!iteratorProto)
4197+
return None;
4198+
4199+
// Reference the makeIterator witness.
4200+
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
4201+
Type makeIteratorType =
4202+
cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
4203+
cs.addValueWitnessConstraint(
4204+
LValueType::get(sequenceType), makeIterator,
4205+
makeIteratorType, dc, FunctionRefKind::Compound,
4206+
contextualLocator);
4207+
4208+
// Generate constraints for the "where" expression, if there is one.
4209+
if (forEachStmtInfo.whereExpr) {
4210+
auto *boolDecl = dc->getASTContext().getBoolDecl();
4211+
if (!boolDecl)
4212+
return None;
4213+
4214+
Type boolType = boolDecl->getDeclaredType();
4215+
if (!boolType)
4216+
return None;
4217+
4218+
SolutionApplicationTarget whereTarget(
4219+
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
4220+
/*isDiscarded=*/false);
4221+
if (cs.generateConstraints(whereTarget, FreeTypeVariableBinding::Disallow))
4222+
return None;
4223+
4224+
forEachStmtInfo.whereExpr = whereTarget.getAsExpr();
4225+
}
4226+
4227+
// Populate all of the information for a for-each loop.
4228+
forEachStmtInfo.elementType = elementType;
4229+
forEachStmtInfo.iteratorType = iteratorType;
4230+
forEachStmtInfo.initType = initType;
4231+
forEachStmtInfo.sequenceType = sequenceType;
4232+
target.setPattern(pattern);
4233+
target.getForEachStmtInfo() = forEachStmtInfo;
4234+
return target;
4235+
}
4236+
41314237
bool ConstraintSystem::generateConstraints(
41324238
SolutionApplicationTarget &target,
41334239
FreeTypeVariableBinding allowFreeTypeVariables) {
@@ -4186,6 +4292,16 @@ bool ConstraintSystem::generateConstraints(
41864292
return true;
41874293
}
41884294

4295+
// For a for-each statement, generate constraints for the pattern, where
4296+
// clause, and sequence traversal.
4297+
if (target.getExprContextualTypePurpose() == CTP_ForEachStmt) {
4298+
auto resultTarget = generateForEachStmtConstraints(*this, target, expr);
4299+
if (!resultTarget)
4300+
return true;
4301+
4302+
target = *resultTarget;
4303+
}
4304+
41894305
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
41904306
auto &log = getASTContext().TypeCheckerDebug->getStream();
41914307
log << "---Initial constraints for the given expression---\n";

lib/Sema/ConstraintSystem.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4161,8 +4161,8 @@ SolutionApplicationTarget::SolutionApplicationTarget(
41614161
expression.wrappedVar = nullptr;
41624162
expression.isDiscarded = isDiscarded;
41634163
expression.bindPatternVarsOneWay = false;
4164-
expression.patternBinding = nullptr;
4165-
expression.patternBindingIndex = 0;
4164+
expression.initialization.patternBinding = nullptr;
4165+
expression.initialization.patternBindingIndex = 0;
41664166
}
41674167

41684168
void SolutionApplicationTarget::maybeApplyPropertyWrapper() {
@@ -4259,18 +4259,35 @@ SolutionApplicationTarget SolutionApplicationTarget::forInitialization(
42594259
auto result = forInitialization(
42604260
initializer, dc, patternType,
42614261
patternBinding->getPattern(patternBindingIndex), bindPatternVarsOneWay);
4262-
result.expression.patternBinding = patternBinding;
4263-
result.expression.patternBindingIndex = patternBindingIndex;
4262+
result.expression.initialization.patternBinding = patternBinding;
4263+
result.expression.initialization.patternBindingIndex = patternBindingIndex;
42644264
return result;
42654265
}
42664266

4267+
SolutionApplicationTarget SolutionApplicationTarget::forForEachStmt(
4268+
ForEachStmt *stmt, ProtocolDecl *sequenceProto, DeclContext *dc,
4269+
bool bindPatternVarsOneWay) {
4270+
SolutionApplicationTarget target(
4271+
stmt->getSequence(), dc, CTP_ForEachStmt,
4272+
sequenceProto->getDeclaredType(), /*isDiscarded=*/false);
4273+
target.expression.pattern = stmt->getPattern();
4274+
target.expression.bindPatternVarsOneWay =
4275+
bindPatternVarsOneWay || (stmt->getWhere() != nullptr);
4276+
target.expression.forEachStmt.stmt = stmt;
4277+
target.expression.forEachStmt.whereExpr = stmt->getWhere();
4278+
return target;
4279+
}
4280+
42674281
ContextualPattern
4268-
SolutionApplicationTarget::getInitializationContextualPattern() const {
4282+
SolutionApplicationTarget::getContextualPattern() const {
42694283
assert(kind == Kind::expression);
4270-
assert(expression.contextualPurpose == CTP_Initialization);
4271-
if (expression.patternBinding) {
4284+
assert(expression.contextualPurpose == CTP_Initialization ||
4285+
expression.contextualPurpose == CTP_ForEachStmt);
4286+
if (expression.contextualPurpose == CTP_Initialization &&
4287+
expression.initialization.patternBinding) {
42724288
return ContextualPattern::forPatternBindingDecl(
4273-
expression.patternBinding, expression.patternBindingIndex);
4289+
expression.initialization.patternBinding,
4290+
expression.initialization.patternBindingIndex);
42744291
}
42754292

42764293
return ContextualPattern::forRawPattern(expression.pattern, expression.dc);

0 commit comments

Comments
 (0)