Skip to content

Commit 49a63fc

Browse files
authored
Merge pull request #68787 from xedin/separate-solving-for-for-loop-condition
[ConstraintSystem] Solve `where` clauses of for-in loops separately
2 parents 8ca4314 + 3ca54ae commit 49a63fc

File tree

8 files changed

+74
-29
lines changed

8 files changed

+74
-29
lines changed

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class SyntacticElementTarget {
147147
ForEachStmt *stmt;
148148
DeclContext *dc;
149149
Pattern *pattern;
150-
bool bindPatternVarsOneWay;
150+
bool ignoreWhereClause;
151151
ForEachStmtInfo info;
152152
} forEachStmt;
153153

@@ -227,11 +227,11 @@ class SyntacticElementTarget {
227227
}
228228

229229
SyntacticElementTarget(ForEachStmt *stmt, DeclContext *dc,
230-
bool bindPatternVarsOneWay)
230+
bool ignoreWhereClause)
231231
: kind(Kind::forEachStmt) {
232232
forEachStmt.stmt = stmt;
233233
forEachStmt.dc = dc;
234-
forEachStmt.bindPatternVarsOneWay = bindPatternVarsOneWay;
234+
forEachStmt.ignoreWhereClause = ignoreWhereClause;
235235
}
236236

237237
/// Form a target for the initialization of a pattern from an expression.
@@ -249,7 +249,7 @@ class SyntacticElementTarget {
249249
/// Form a target for a for-in loop.
250250
static SyntacticElementTarget forForEachStmt(ForEachStmt *stmt,
251251
DeclContext *dc,
252-
bool bindPatternVarsOneWay);
252+
bool ignoreWhereClause = false);
253253

254254
/// Form a target for a property with an attached property wrapper that is
255255
/// initialized out-of-line.
@@ -469,10 +469,6 @@ class SyntacticElementTarget {
469469
bool shouldBindPatternVarsOneWay() const {
470470
if (kind == Kind::expression)
471471
return expression.bindPatternVarsOneWay;
472-
473-
if (kind == Kind::forEachStmt)
474-
return forEachStmt.bindPatternVarsOneWay;
475-
476472
return false;
477473
}
478474

@@ -521,6 +517,11 @@ class SyntacticElementTarget {
521517
return expression.initialization.patternBindingIndex;
522518
}
523519

520+
bool ignoreForEachWhereClause() const {
521+
assert(isForEachStmt());
522+
return forEachStmt.ignoreWhereClause;
523+
}
524+
524525
const ForEachStmtInfo &getForEachStmtInfo() const {
525526
assert(isForEachStmt());
526527
return forEachStmt.info;

lib/Sema/CSGen.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4579,22 +4579,16 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45794579
/*flags=*/0);
45804580
{
45814581
auto nextType = cs.getType(forEachStmtInfo.nextCall);
4582-
// Note that `OptionalObject` is not used here. This is due to inference
4583-
// behavior where it would bind `elementType` to the `initType` before
4584-
// resolving `optional object` constraint which is sometimes too eager.
4585-
cs.addConstraint(ConstraintKind::Conversion, nextType,
4586-
OptionalType::get(elementType), elementTypeLoc);
4582+
cs.addConstraint(ConstraintKind::OptionalObject, nextType, elementType,
4583+
elementTypeLoc);
45874584
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
45884585
elementLocator);
45894586
}
45904587

45914588
// Generate constraints for the "where" expression, if there is one.
4592-
if (auto *whereExpr = stmt->getWhere()) {
4593-
auto *boolDecl = dc->getASTContext().getBoolDecl();
4594-
if (!boolDecl)
4595-
return llvm::None;
4596-
4597-
Type boolType = boolDecl->getDeclaredInterfaceType();
4589+
auto *whereExpr = stmt->getWhere();
4590+
if (whereExpr && !target.ignoreForEachWhereClause()) {
4591+
Type boolType = dc->getASTContext().getBoolType();
45984592
if (!boolType)
45994593
return llvm::None;
46004594

lib/Sema/CSSimplify.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6303,10 +6303,22 @@ bool ConstraintSystem::repairFailures(
63036303
}
63046304

63056305
case ConstraintLocator::OptionalPayload: {
6306+
if (lhs->isPlaceholder() || rhs->isPlaceholder())
6307+
return true;
6308+
63066309
if (repairViaOptionalUnwrap(*this, lhs, rhs, matchKind, conversionsOrFixes,
63076310
locator))
63086311
return true;
63096312

6313+
if (path.size() > 1) {
6314+
path.pop_back();
6315+
if (path.back().is<LocatorPathElt::SequenceElementType>()) {
6316+
conversionsOrFixes.push_back(
6317+
CollectionElementContextualMismatch::create(
6318+
*this, lhs, rhs, getConstraintLocator(anchor, path)));
6319+
return true;
6320+
}
6321+
}
63106322
break;
63116323
}
63126324

lib/Sema/CSSyntacticElement.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,13 @@ class SyntacticElementConstraintGenerator
629629
///
630630
/// - From sequence to pattern, when pattern has no type information.
631631
void visitForEachPattern(Pattern *pattern, ForEachStmt *forEachStmt) {
632+
// The `where` clause should be ignored because \c visitForEachStmt
633+
// records it as a separate conjunction element to allow for a more
634+
// granular control over what contextual information is brought into
635+
// the scope during pattern + sequence and `where` clause solving.
632636
auto target = SyntacticElementTarget::forForEachStmt(
633637
forEachStmt, context.getAsDeclContext(),
634-
/*bindTypeVarsOneWay=*/false);
638+
/*ignoreWhereClause=*/true);
635639

636640
if (cs.generateConstraints(target)) {
637641
hadError = true;
@@ -961,10 +965,24 @@ class SyntacticElementConstraintGenerator
961965

962966
// For-each pattern.
963967
//
964-
// Note that we don't record a sequence or where clause here,
965-
// they would be handled together with pattern because pattern can
966-
// inform a type of sequence element e.g. `for i: Int8 in 0 ..< 8`
968+
// Note that we don't record a sequence here, it would be handled
969+
// together with pattern because pattern can inform a type of sequence
970+
// element e.g. `for i: Int8 in 0 ..< 8`
967971
elements.push_back(makeElement(forEachStmt->getPattern(), stmtLoc));
972+
973+
// Where clause if any.
974+
if (auto *where = forEachStmt->getWhere()) {
975+
Type boolType = cs.getASTContext().getBoolType();
976+
if (!boolType) {
977+
hadError = true;
978+
return;
979+
}
980+
981+
ContextualTypeInfo context(boolType, CTP_Condition);
982+
elements.push_back(
983+
makeElement(where, stmtLoc, context, /*isDiscarded=*/false));
984+
}
985+
968986
// Body of the `for-in` loop.
969987
elements.push_back(makeElement(forEachStmt->getBody(), stmtLoc));
970988

lib/Sema/SyntacticElementTarget.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,8 @@ SyntacticElementTarget SyntacticElementTarget::forInitialization(
178178

179179
SyntacticElementTarget
180180
SyntacticElementTarget::forForEachStmt(ForEachStmt *stmt, DeclContext *dc,
181-
bool bindPatternVarsOneWay) {
182-
SyntacticElementTarget target(
183-
stmt, dc, bindPatternVarsOneWay || bool(stmt->getWhere()));
181+
bool ignoreWhereClause) {
182+
SyntacticElementTarget target(stmt, dc, ignoreWhereClause);
184183
return target;
185184
}
186185

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,8 +901,7 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
901901
return true;
902902
};
903903

904-
auto target = SyntacticElementTarget::forForEachStmt(
905-
stmt, dc, /*bindPatternVarsOneWay=*/false);
904+
auto target = SyntacticElementTarget::forForEachStmt(stmt, dc);
906905
if (!typeCheckTarget(target))
907906
return failed();
908907

test/Constraints/rdar107651291.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ func foo(xs: [String: [String]], ys: [String: [String]]) {
77
for (a, b) in zip(xs, ys) {}
88
// expected-error@-1 {{type 'Dictionary<String, [String]>.Element' (aka '(key: String, value: Array<String>)') cannot conform to 'Sequence'}}
99
// expected-note@-2 {{only concrete types such as structs, enums and classes can conform to protocols}}
10-
// expected-note@-3 {{required by referencing instance method 'next()'}}
10+
// expected-note@-3 {{required by global function 'zip' where 'Sequence2' = 'Dictionary<String, [String]>.Element' (aka '(key: String, value: Array<String>)')}}
1111
}
1212
}

test/stmt/foreach.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,25 @@ do {
273273
// https://github.com/apple/swift/issues/65650 - Make sure we say 'String', not 'Any'.
274274
for (x, y) in [""] {} // expected-error {{tuple pattern cannot match values of non-tuple type 'String'}}
275275
}
276+
277+
do {
278+
class Base : Hashable {
279+
static func ==(_: Base, _: Base) -> Bool { false }
280+
281+
func hash(into hasher: inout Hasher) {}
282+
}
283+
284+
class Child : Base {
285+
var value: Int = 0
286+
}
287+
288+
struct Range {
289+
func contains(_: Base) -> Bool { return false }
290+
}
291+
292+
func test(data: Set<Child>, _ range: Range) {
293+
for v in data where range.contains(v) {
294+
_ = v.value // Ok (`v` is inferred from `data` and not from `range`)
295+
}
296+
}
297+
}

0 commit comments

Comments
 (0)