Skip to content

Commit baa1a73

Browse files
authored
Merge pull request swiftlang#28986 from DougGregor/for-each-pattern-cleanup
[Type checker] Move for-each pattern checking logic into the solver.
2 parents 1ee4403 + 2a13b1d commit baa1a73

File tree

6 files changed

+106
-144
lines changed

6 files changed

+106
-144
lines changed

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 92 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2892,13 +2892,15 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
28922892
return hadError;
28932893
}
28942894

2895-
auto TypeChecker::typeCheckForEachBinding(
2896-
DeclContext *dc, ForEachStmt *stmt) -> Optional<ForEachBinding> {
2895+
bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
28972896
/// Type checking listener for for-each binding.
28982897
class BindingListener : public ExprTypeCheckListener {
28992898
/// The for-each statement.
29002899
ForEachStmt *Stmt;
29012900

2901+
/// The declaration context in which this for-each statement resides.
2902+
DeclContext *DC;
2903+
29022904
/// The locator we're using.
29032905
ConstraintLocator *Locator;
29042906

@@ -2927,7 +2929,8 @@ auto TypeChecker::typeCheckForEachBinding(
29272929
Type IteratorType;
29282930

29292931
public:
2930-
explicit BindingListener(ForEachStmt *stmt) : Stmt(stmt) { }
2932+
explicit BindingListener(ForEachStmt *stmt, DeclContext *dc)
2933+
: Stmt(stmt), DC(dc) { }
29312934

29322935
bool builtConstraints(ConstraintSystem &cs, Expr *expr) override {
29332936
// Save the locator we're using for the expression.
@@ -2957,6 +2960,25 @@ auto TypeChecker::typeCheckForEachBinding(
29572960
auto elementLocator = cs.getConstraintLocator(
29582961
ContextualLocator, ConstraintLocator::SequenceElementType);
29592962

2963+
// Check the element pattern.
2964+
ASTContext &ctx = cs.getASTContext();
2965+
if (auto *P = TypeChecker::resolvePattern(Stmt->getPattern(), DC,
2966+
/*isStmtCondition*/false)) {
2967+
Stmt->setPattern(P);
2968+
} else {
2969+
Stmt->getPattern()->setType(ErrorType::get(ctx));
2970+
return true;
2971+
}
2972+
2973+
TypeResolutionOptions options(TypeResolverContext::InExpression);
2974+
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
2975+
options |= TypeResolutionFlags::AllowUnboundGenerics;
2976+
if (TypeChecker::typeCheckPattern(Stmt->getPattern(), DC, options)) {
2977+
// FIXME: Handle errors better.
2978+
Stmt->getPattern()->setType(ErrorType::get(ctx));
2979+
return true;
2980+
}
2981+
29602982
// Collect constraints from the element pattern.
29612983
auto pattern = Stmt->getPattern();
29622984
InitType = cs.generateConstraints(pattern, elementLocator);
@@ -2985,13 +3007,12 @@ auto TypeChecker::typeCheckForEachBinding(
29853007
}
29863008

29873009
// Reference the makeIterator witness.
2988-
ASTContext &ctx = cs.getASTContext();
29893010
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
29903011
Type makeIteratorType =
29913012
cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
29923013
cs.addValueWitnessConstraint(
29933014
LValueType::get(SequenceType), makeIterator,
2994-
makeIteratorType, cs.DC, FunctionRefKind::Compound,
3015+
makeIteratorType, DC, FunctionRefKind::Compound,
29953016
ContextualLocator);
29963017

29973018
Stmt->setSequence(expr);
@@ -3001,6 +3022,7 @@ auto TypeChecker::typeCheckForEachBinding(
30013022
Expr *appliedSolution(Solution &solution, Expr *expr) override {
30023023
// Figure out what types the constraints decided on.
30033024
auto &cs = solution.getConstraintSystem();
3025+
ASTContext &ctx = cs.getASTContext();
30043026
InitType = solution.simplifyType(InitType);
30053027
SequenceType = solution.simplifyType(SequenceType);
30063028
ElementType = solution.simplifyType(ElementType);
@@ -3012,17 +3034,17 @@ auto TypeChecker::typeCheckForEachBinding(
30123034

30133035
cs.cacheExprTypes(expr);
30143036
Stmt->setSequence(expr);
3037+
solution.setExprTypes(expr);
30153038

30163039
// Apply the solution to the iteration pattern as well.
30173040
Pattern *pattern = Stmt->getPattern();
30183041
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
30193042
options |= TypeResolutionFlags::OverrideType;
30203043
if (TypeChecker::coercePatternToType(pattern,
3021-
TypeResolution::forContextual(cs.DC),
3044+
TypeResolution::forContextual(DC),
30223045
InitType, options)) {
30233046
return nullptr;
30243047
}
3025-
30263048
Stmt->setPattern(pattern);
30273049

30283050
// Get the conformance of the sequence type to the Sequence protocol.
@@ -3032,24 +3054,79 @@ auto TypeChecker::typeCheckForEachBinding(
30323054
"Couldn't find sequence conformance");
30333055
Stmt->setSequenceConformance(SequenceConformance);
30343056

3035-
solution.setExprTypes(expr);
3036-
return expr;
3037-
}
3057+
// Check the filtering condition.
3058+
// FIXME: This should be pulled into the constraint system itself.
3059+
if (auto *Where = Stmt->getWhere()) {
3060+
if (!TypeChecker::typeCheckCondition(Where, DC))
3061+
Stmt->setWhere(Where);
3062+
}
30383063

3039-
ForEachBinding getBinding() const {
3040-
return { SequenceType, SequenceConformance, IteratorType, ElementType };
3064+
// Invoke iterator() to get an iterator from the sequence.
3065+
VarDecl *iterator;
3066+
Type nextResultType = OptionalType::get(ElementType);
3067+
{
3068+
// Create a local variable to capture the iterator.
3069+
std::string name;
3070+
if (auto np = dyn_cast_or_null<NamedPattern>(Stmt->getPattern()))
3071+
name = "$"+np->getBoundName().str().str();
3072+
name += "$generator";
3073+
3074+
iterator = new (ctx) VarDecl(
3075+
/*IsStatic*/ false, VarDecl::Introducer::Var,
3076+
/*IsCaptureList*/ false, Stmt->getInLoc(),
3077+
ctx.getIdentifier(name), DC);
3078+
iterator->setInterfaceType(IteratorType->mapTypeOutOfContext());
3079+
iterator->setImplicit();
3080+
Stmt->setIteratorVar(iterator);
3081+
3082+
auto genPat = new (ctx) NamedPattern(iterator);
3083+
genPat->setImplicit();
3084+
3085+
// TODO: test/DebugInfo/iteration.swift requires this extra info to
3086+
// be around.
3087+
PatternBindingDecl::createImplicit(
3088+
ctx, StaticSpellingKind::None, genPat,
3089+
new (ctx) OpaqueValueExpr(Stmt->getInLoc(), nextResultType),
3090+
DC, /*VarLoc*/ Stmt->getForLoc());
3091+
}
3092+
3093+
// Create the iterator variable.
3094+
auto *varRef = TypeChecker::buildCheckedRefExpr(
3095+
iterator, DC, DeclNameLoc(Stmt->getInLoc()), /*implicit*/ true);
3096+
if (varRef)
3097+
Stmt->setIteratorVarRef(varRef);
3098+
3099+
// Convert that Optional<Element> value to the type of the pattern.
3100+
auto optPatternType = OptionalType::get(Stmt->getPattern()->getType());
3101+
if (!optPatternType->isEqual(nextResultType)) {
3102+
OpaqueValueExpr *elementExpr =
3103+
new (ctx) OpaqueValueExpr(Stmt->getInLoc(), nextResultType,
3104+
/*isPlaceholder=*/true);
3105+
Expr *convertElementExpr = elementExpr;
3106+
if (TypeChecker::typeCheckExpression(
3107+
convertElementExpr, DC,
3108+
TypeLoc::withoutLoc(optPatternType),
3109+
CTP_CoerceOperand).isNull()) {
3110+
return nullptr;
3111+
}
3112+
elementExpr->setIsPlaceholder(false);
3113+
Stmt->setElementExpr(elementExpr);
3114+
Stmt->setConvertElementExpr(convertElementExpr);
3115+
}
3116+
3117+
return expr;
30413118
}
30423119
};
30433120

3044-
BindingListener listener(stmt);
3121+
BindingListener listener(stmt, dc);
30453122
Expr *seq = stmt->getSequence();
30463123
assert(seq && "type-checking an uninitialized for-each statement?");
30473124

30483125
// Type-check the for-each loop sequence and element pattern.
30493126
auto resultTy = TypeChecker::typeCheckExpression(seq, dc, &listener);
30503127
if (!resultTy)
3051-
return None;
3052-
return listener.getBinding();
3128+
return true;
3129+
return false;
30533130
}
30543131

30553132
bool TypeChecker::typeCheckCondition(Expr *&expr, DeclContext *dc) {

lib/Sema/TypeCheckPattern.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -741,23 +741,7 @@ bool TypeChecker::typeCheckPattern(Pattern *P, DeclContext *dc,
741741
case PatternKind::Typed: {
742742
auto resolution = TypeResolution::forContextual(dc);
743743
TypedPattern *TP = cast<TypedPattern>(P);
744-
bool hadError = validateTypedPattern(resolution, TP, options);
745-
746-
// If we have unbound generic types, don't apply them below; instead,
747-
// the caller will call typeCheckBinding() later.
748-
if (P->getType()->hasUnboundGenericType())
749-
return hadError;
750-
751-
Pattern *subPattern = TP->getSubPattern();
752-
if (TypeChecker::coercePatternToType(subPattern, resolution, P->getType(),
753-
options|TypeResolutionFlags::FromNonInferredPattern,
754-
TP->getTypeLoc()))
755-
hadError = true;
756-
else {
757-
TP->setSubPattern(subPattern);
758-
TP->setType(subPattern->getType());
759-
}
760-
return hadError;
744+
return validateTypedPattern(resolution, TP, options);
761745
}
762746

763747
// A wildcard or name pattern cannot appear by itself in a context

lib/Sema/TypeCheckStmt.cpp

Lines changed: 1 addition & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -731,108 +731,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
731731
}
732732

733733
Stmt *visitForEachStmt(ForEachStmt *S) {
734-
TypeResolutionOptions options(TypeResolverContext::InExpression);
735-
options |= TypeResolutionFlags::AllowUnspecifiedTypes;
736-
options |= TypeResolutionFlags::AllowUnboundGenerics;
737-
738-
if (auto *P = TypeChecker::resolvePattern(S->getPattern(), DC,
739-
/*isStmtCondition*/false)) {
740-
S->setPattern(P);
741-
} else {
742-
S->getPattern()->setType(ErrorType::get(getASTContext()));
743-
return nullptr;
744-
}
745-
746-
if (TypeChecker::typeCheckPattern(S->getPattern(), DC, options)) {
747-
// FIXME: Handle errors better.
748-
S->getPattern()->setType(ErrorType::get(getASTContext()));
749-
return nullptr;
750-
}
751-
752-
auto binding = TypeChecker::typeCheckForEachBinding(DC, S);
753-
if (!binding)
754-
return nullptr;
755-
756-
if (auto *Where = S->getWhere()) {
757-
if (TypeChecker::typeCheckCondition(Where, DC))
758-
return nullptr;
759-
S->setWhere(Where);
760-
}
761-
762-
763-
// Retrieve the 'Sequence' protocol.
764-
ProtocolDecl *sequenceProto = TypeChecker::getProtocol(
765-
getASTContext(), S->getForLoc(), KnownProtocolKind::Sequence);
766-
if (!sequenceProto) {
767-
return nullptr;
768-
}
769-
770-
// Retrieve the 'Iterator' protocol.
771-
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
772-
getASTContext(), S->getForLoc(), KnownProtocolKind::IteratorProtocol);
773-
if (!iteratorProto) {
734+
if (TypeChecker::typeCheckForEachBinding(DC, S))
774735
return nullptr;
775-
}
776-
777-
// Invoke iterator() to get an iterator from the sequence.
778-
Type iteratorTy = binding->iteratorType;
779-
VarDecl *iterator;
780-
Type nextResultType = OptionalType::get(binding->elementType);
781-
{
782-
// Create a local variable to capture the iterator.
783-
std::string name;
784-
if (auto np = dyn_cast_or_null<NamedPattern>(S->getPattern()))
785-
name = "$"+np->getBoundName().str().str();
786-
name += "$generator";
787-
788-
iterator = new (getASTContext()) VarDecl(
789-
/*IsStatic*/ false, VarDecl::Introducer::Var,
790-
/*IsCaptureList*/ false, S->getInLoc(),
791-
getASTContext().getIdentifier(name), DC);
792-
iterator->setInterfaceType(iteratorTy->mapTypeOutOfContext());
793-
iterator->setImplicit();
794-
S->setIteratorVar(iterator);
795-
796-
auto genPat = new (getASTContext()) NamedPattern(iterator);
797-
genPat->setImplicit();
798-
799-
// TODO: test/DebugInfo/iteration.swift requires this extra info to
800-
// be around.
801-
PatternBindingDecl::createImplicit(
802-
getASTContext(), StaticSpellingKind::None, genPat,
803-
new (getASTContext()) OpaqueValueExpr(S->getInLoc(), nextResultType),
804-
DC, /*VarLoc*/ S->getForLoc());
805-
}
806-
807-
// Working with iterators requires Optional.
808-
if (TypeChecker::requireOptionalIntrinsics(getASTContext(), S->getForLoc()))
809-
return nullptr;
810-
811-
// Create the iterator variable.
812-
auto *varRef = TypeChecker::buildCheckedRefExpr(iterator, DC,
813-
DeclNameLoc(S->getInLoc()),
814-
/*implicit*/ true);
815-
if (!varRef)
816-
return nullptr;
817-
S->setIteratorVarRef(varRef);
818-
819-
// Convert that Optional<Element> value to the type of the pattern.
820-
auto optPatternType = OptionalType::get(S->getPattern()->getType());
821-
if (!optPatternType->isEqual(nextResultType)) {
822-
OpaqueValueExpr *elementExpr =
823-
new (getASTContext()) OpaqueValueExpr(S->getInLoc(), nextResultType,
824-
/*isPlaceholder=*/true);
825-
Expr *convertElementExpr = elementExpr;
826-
if (TypeChecker::typeCheckExpression(
827-
convertElementExpr, DC,
828-
TypeLoc::withoutLoc(optPatternType),
829-
CTP_CoerceOperand).isNull()) {
830-
return nullptr;
831-
}
832-
elementExpr->setIsPlaceholder(false);
833-
S->setElementExpr(elementExpr);
834-
S->setConvertElementExpr(convertElementExpr);
835-
}
836736

837737
// Type-check the body of the loop.
838738
AddLabeledStmt loopNest(*this, S);

lib/Sema/TypeCheckStorage.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ PatternBindingEntryRequest::evaluate(Evaluator &eval,
251251
binding->diagnose(diag::inferred_opaque_type,
252252
binding->getInit(entryNumber)->getType());
253253
}
254+
} else {
255+
// Coerce the pattern to the computed type.
256+
auto resolution = TypeResolution::forContextual(binding->getDeclContext());
257+
if (TypeChecker::coercePatternToType(pattern, resolution,
258+
pattern->getType(), options)) {
259+
binding->setInvalid();
260+
pattern->setType(ErrorType::get(Context));
261+
return &pbe;
262+
}
254263
}
255264

256265
// If the pattern binding appears in a type or library file context, then

lib/Sema/TypeChecker.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,19 +1032,10 @@ class TypeChecker final {
10321032
static bool typeCheckBinding(Pattern *&P, Expr *&Init, DeclContext *DC);
10331033
static bool typeCheckPatternBinding(PatternBindingDecl *PBD, unsigned patternNumber);
10341034

1035-
/// Information about a type-checked for-each binding.
1036-
struct ForEachBinding {
1037-
Type sequenceType;
1038-
ProtocolConformanceRef sequenceConformance;
1039-
Type iteratorType;
1040-
Type elementType;
1041-
};
1042-
10431035
/// Type-check a for-each loop's pattern binding and sequence together.
10441036
///
1045-
/// \returns the binding, if successful.
1046-
static Optional<ForEachBinding> typeCheckForEachBinding(
1047-
DeclContext *dc, ForEachStmt *stmt);
1037+
/// \returns true if a failure occurred.
1038+
static bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt);
10481039

10491040
/// Compute the set of captures for the given function or closure.
10501041
static void computeCaptures(AnyFunctionRef AFR);

test/decl/var/variables.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func test21057425() -> (Int, Int) {
107107

108108
// rdar://problem/21081340
109109
func test21081340() {
110+
func foo() { }
110111
let (x: a, y: b): () = foo() // expected-error{{tuple pattern has the wrong length for tuple type '()'}}
111112
}
112113

0 commit comments

Comments
 (0)