Skip to content

Commit c5a655e

Browse files
committed
[Type checker] Fold more for-each type checking into the constraint solver.
The type checking of the for-each loop is split between the constraint solver (which does most of the work) and the statement checker (which updates the for-each loop AST). Move more of the work into the constraint solver proper, so that the AST updates can happen in one place, making use of the solution produced by the solver. This allows a few things, some of which are short-term gains and others that are more future-facing: * `TypeChecker::convertToType` has been removed, because we can now either use the more general `typeCheckExpression` entry point or perform the appropriate operation within the constraint system. * Solving the constraint system ensures that everything related to the for-each loop full checks out * Additional refactoring will make it easier for the for-each loop to be checked as part of a larger constraint system, e.g., for processing entire closures or function bodies (that’s the futurist bit).
1 parent 746b58e commit c5a655e

File tree

5 files changed

+139
-162
lines changed

5 files changed

+139
-162
lines changed

include/swift/AST/Expr.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3924,6 +3924,10 @@ class OpaqueValueExpr : public Expr {
39243924
/// value to be specified later.
39253925
bool isPlaceholder() const { return Bits.OpaqueValueExpr.IsPlaceholder; }
39263926

3927+
void setIsPlaceholder(bool value) {
3928+
Bits.OpaqueValueExpr.IsPlaceholder = value;
3929+
}
3930+
39273931
SourceRange getSourceRange() const { return Range; }
39283932

39293933
static bool classof(const Expr *E) {

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 108 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
28922892
return hadError;
28932893
}
28942894

2895-
bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2895+
auto TypeChecker::typeCheckForEachBinding(
2896+
DeclContext *dc, ForEachStmt *stmt) -> Optional<ForEachBinding> {
28962897
/// Type checking listener for for-each binding.
28972898
class BindingListener : public ExprTypeCheckListener {
28982899
/// The for-each statement.
@@ -2901,45 +2902,66 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29012902
/// The locator we're using.
29022903
ConstraintLocator *Locator;
29032904

2905+
/// The contextual locator we're using.
2906+
ConstraintLocator *ContextualLocator;
2907+
2908+
/// The Sequence protocol.
2909+
ProtocolDecl *SequenceProto;
2910+
2911+
/// The IteratorProtocol.
2912+
ProtocolDecl *IteratorProto;
2913+
29042914
/// The type of the initializer.
29052915
Type InitType;
29062916

29072917
/// The type of the sequence.
29082918
Type SequenceType;
29092919

2920+
/// The conformance of the sequence type to the Sequence protocol.
2921+
ProtocolConformanceRef SequenceConformance;
2922+
2923+
/// The type of the element.
2924+
Type ElementType;
2925+
2926+
/// The type of the iterator.
2927+
Type IteratorType;
2928+
2929+
/// The conformance of the iterator type to IteratorProtocol.
2930+
ProtocolConformanceRef IteratorConformance;
2931+
2932+
/// The type of makeIterator.
2933+
Type MakeIteratorType;
2934+
29102935
public:
29112936
explicit BindingListener(ForEachStmt *stmt) : Stmt(stmt) { }
29122937

29132938
bool builtConstraints(ConstraintSystem &cs, Expr *expr) override {
29142939
// Save the locator we're using for the expression.
29152940
Locator = cs.getConstraintLocator(expr);
2916-
auto *contextualLocator =
2941+
ContextualLocator =
29172942
cs.getConstraintLocator(expr, LocatorPathElt::ContextualType());
29182943

2919-
// The expression type must conform to the Sequence.
2920-
ProtocolDecl *sequenceProto = TypeChecker::getProtocol(
2944+
// The expression type must conform to the Sequence protocol.
2945+
SequenceProto = TypeChecker::getProtocol(
29212946
cs.getASTContext(), Stmt->getForLoc(), KnownProtocolKind::Sequence);
2922-
if (!sequenceProto) {
2947+
if (!SequenceProto) {
29232948
return true;
29242949
}
29252950

2926-
auto elementAssocType =
2927-
sequenceProto->getAssociatedType(cs.getASTContext().Id_Element);
2928-
29292951
SequenceType = cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
29302952
cs.addConstraint(ConstraintKind::Conversion, cs.getType(expr),
29312953
SequenceType, Locator);
29322954
cs.addConstraint(ConstraintKind::ConformsTo, SequenceType,
2933-
sequenceProto->getDeclaredType(), contextualLocator);
2955+
SequenceProto->getDeclaredType(), ContextualLocator);
29342956

29352957
// Since we are using "contextual type" here, it has to be recorded
29362958
// in the constraint system for diagnostics to have access to "purpose".
29372959
cs.setContextualType(
2938-
expr, TypeLoc::withoutLoc(sequenceProto->getDeclaredType()),
2960+
expr, TypeLoc::withoutLoc(SequenceProto->getDeclaredType()),
29392961
CTP_ForEachStmt);
29402962

29412963
auto elementLocator = cs.getConstraintLocator(
2942-
contextualLocator, ConstraintLocator::SequenceElementType);
2964+
ContextualLocator, ConstraintLocator::SequenceElementType);
29432965

29442966
// Collect constraints from the element pattern.
29452967
auto pattern = Stmt->getPattern();
@@ -2949,10 +2971,36 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29492971

29502972
// Add a conversion constraint between the element type of the sequence
29512973
// and the type of the element pattern.
2952-
auto elementType = DependentMemberType::get(SequenceType, elementAssocType);
2953-
cs.addConstraint(ConstraintKind::Conversion, elementType, InitType,
2974+
auto elementAssocType =
2975+
SequenceProto->getAssociatedType(cs.getASTContext().Id_Element);
2976+
ElementType = DependentMemberType::get(SequenceType, elementAssocType);
2977+
cs.addConstraint(ConstraintKind::Conversion, ElementType, InitType,
29542978
elementLocator);
29552979

2980+
// Determine the iterator type.
2981+
auto iteratorAssocType =
2982+
SequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator);
2983+
IteratorType = DependentMemberType::get(SequenceType, iteratorAssocType);
2984+
2985+
// The iterator type must conform to IteratorProtocol.
2986+
IteratorProto = TypeChecker::getProtocol(
2987+
cs.getASTContext(), Stmt->getForLoc(),
2988+
KnownProtocolKind::IteratorProtocol);
2989+
if (!IteratorProto) {
2990+
return true;
2991+
}
2992+
2993+
// Reference the makeIterator witness.
2994+
// FIXME: Not tied to the actual witness.
2995+
ASTContext &ctx = cs.getASTContext();
2996+
DeclName makeIteratorName(ctx, ctx.Id_makeIterator,
2997+
ArrayRef<Identifier>());
2998+
MakeIteratorType = cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
2999+
cs.addValueMemberConstraint(
3000+
LValueType::get(SequenceType), DeclNameRef(makeIteratorName),
3001+
MakeIteratorType, cs.DC, FunctionRefKind::Compound, { },
3002+
ContextualLocator);
3003+
29563004
Stmt->setSequence(expr);
29573005
return false;
29583006
}
@@ -2962,13 +3010,24 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29623010
auto &cs = solution.getConstraintSystem();
29633011
InitType = solution.simplifyType(InitType);
29643012
SequenceType = solution.simplifyType(SequenceType);
3013+
ElementType = solution.simplifyType(ElementType);
3014+
IteratorType = solution.simplifyType(IteratorType);
29653015

29663016
// Perform any necessary conversions of the sequence (e.g. [T]! -> [T]).
2967-
expr = solution.coerceToType(expr, SequenceType, cs.getConstraintLocator(expr));
3017+
expr = solution.coerceToType(expr, SequenceType, Locator);
29683018

29693019
if (!expr) return nullptr;
29703020

3021+
// Convert the sequence as appropriate for the makeIterator() call.
3022+
auto makeIteratorOverload = solution.getOverloadChoice(ContextualLocator);
3023+
auto makeIteratorSelfType = solution.simplifyType(
3024+
makeIteratorOverload.openedFullType
3025+
)->castTo<AnyFunctionType>()->getParams()[0].getPlainType();
3026+
expr = solution.coerceToType(expr, makeIteratorSelfType,
3027+
ContextualLocator);
3028+
29713029
cs.cacheExprTypes(expr);
3030+
Stmt->setSequence(expr);
29723031

29733032
// Apply the solution to the iteration pattern as well.
29743033
Pattern *pattern = Stmt->getPattern();
@@ -2981,11 +3040,42 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29813040
}
29823041

29833042
Stmt->setPattern(pattern);
2984-
Stmt->setSequence(expr);
3043+
3044+
// Get the conformance of the sequence type to the Sequence protocol.
3045+
// FIXME: Get this from the solution and substitute into that.
3046+
SequenceConformance = TypeChecker::conformsToProtocol(
3047+
SequenceType, SequenceProto, cs.DC,
3048+
ConformanceCheckFlags::InExpression,
3049+
expr->getLoc());
3050+
assert(!SequenceConformance.isInvalid() &&
3051+
"Couldn't find sequence conformance");
3052+
Stmt->setSequenceConformance(SequenceConformance);
3053+
3054+
// Retrieve the conformance of the iterator type to IteratorProtocol.
3055+
// FIXME: Get this from the solution and substitute into that.
3056+
IteratorConformance = TypeChecker::conformsToProtocol(
3057+
IteratorType, IteratorProto, cs.DC,
3058+
ConformanceCheckFlags::InExpression,
3059+
expr->getLoc());
3060+
3061+
// Record the makeIterator declaration we used.
3062+
auto makeIteratorDecl = makeIteratorOverload.choice.getDecl();
3063+
auto makeIteratorSubs = SequenceType->getMemberSubstitutionMap(
3064+
cs.DC->getParentModule(), makeIteratorDecl);
3065+
auto makeIteratorDeclRef =
3066+
ConcreteDeclRef(makeIteratorDecl, makeIteratorSubs);
3067+
Stmt->setMakeIterator(makeIteratorDeclRef);
29853068

29863069
solution.setExprTypes(expr);
29873070
return expr;
29883071
}
3072+
3073+
ForEachBinding getBinding() const {
3074+
return {
3075+
SequenceType, SequenceConformance, IteratorType, IteratorConformance,
3076+
ElementType
3077+
};
3078+
}
29893079
};
29903080

29913081
BindingListener listener(stmt);
@@ -2994,7 +3084,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
29943084

29953085
// Type-check the for-each loop sequence and element pattern.
29963086
auto resultTy = TypeChecker::typeCheckExpression(seq, dc, &listener);
2997-
return !resultTy;
3087+
if (!resultTy)
3088+
return None;
3089+
return listener.getBinding();
29983090
}
29993091

30003092
bool TypeChecker::typeCheckCondition(Expr *&expr, DeclContext *dc) {
@@ -3427,88 +3519,6 @@ TypeChecker::coerceToRValue(ASTContext &Context, Expr *expr,
34273519
return expr;
34283520
}
34293521

3430-
bool TypeChecker::convertToType(Expr *&expr, Type type, DeclContext *dc,
3431-
Optional<Pattern*> typeFromPattern) {
3432-
// TODO: need to add kind arg?
3433-
// Construct a constraint system from this expression.
3434-
ConstraintSystem cs(dc, ConstraintSystemFlags::AllowFixes);
3435-
3436-
// Cache the expression type on the system to ensure it is available
3437-
// on diagnostics if the convertion fails.
3438-
cs.cacheExprTypes(expr);
3439-
3440-
// If there is a type that we're expected to convert to, add the conversion
3441-
// constraint.
3442-
cs.addConstraint(ConstraintKind::Conversion, expr->getType(), type,
3443-
cs.getConstraintLocator(expr));
3444-
3445-
auto &Context = dc->getASTContext();
3446-
if (Context.TypeCheckerOpts.DebugConstraintSolver) {
3447-
auto &log = Context.TypeCheckerDebug->getStream();
3448-
log << "---Initial constraints for the given expression---\n";
3449-
expr->dump(log);
3450-
log << "\n";
3451-
cs.print(log);
3452-
}
3453-
3454-
// Attempt to solve the constraint system.
3455-
SmallVector<Solution, 4> viable;
3456-
if ((cs.solve(viable) || viable.size() != 1)) {
3457-
// Try to fix the system or provide a decent diagnostic.
3458-
auto salvagedResult = cs.salvage();
3459-
switch (salvagedResult.getKind()) {
3460-
case SolutionResult::Kind::Success:
3461-
viable.clear();
3462-
viable.push_back(std::move(salvagedResult).takeSolution());
3463-
break;
3464-
3465-
case SolutionResult::Kind::Error:
3466-
case SolutionResult::Kind::Ambiguous:
3467-
return true;
3468-
3469-
case SolutionResult::Kind::UndiagnosedError:
3470-
cs.diagnoseFailureForExpr(expr);
3471-
salvagedResult.markAsDiagnosed();
3472-
return true;
3473-
3474-
case SolutionResult::Kind::TooComplex:
3475-
Context.Diags.diagnose(expr->getLoc(), diag::expression_too_complex)
3476-
.highlight(expr->getSourceRange());
3477-
salvagedResult.markAsDiagnosed();
3478-
return true;
3479-
}
3480-
}
3481-
3482-
auto &solution = viable[0];
3483-
if (Context.TypeCheckerOpts.DebugConstraintSolver) {
3484-
auto &log = Context.TypeCheckerDebug->getStream();
3485-
log << "---Solution---\n";
3486-
solution.dump(log);
3487-
}
3488-
3489-
cs.cacheExprTypes(expr);
3490-
3491-
// Perform the conversion.
3492-
Expr *result = solution.coerceToType(expr, type,
3493-
cs.getConstraintLocator(expr),
3494-
typeFromPattern);
3495-
if (!result) {
3496-
return true;
3497-
}
3498-
3499-
solution.setExprTypes(expr);
3500-
3501-
if (Context.TypeCheckerOpts.DebugConstraintSolver) {
3502-
auto &log = Context.TypeCheckerDebug->getStream();
3503-
log << "---Type-checked expression---\n";
3504-
result->dump(log);
3505-
log << "\n";
3506-
}
3507-
3508-
expr = result;
3509-
return false;
3510-
}
3511-
35123522
//===----------------------------------------------------------------------===//
35133523
// Debugging
35143524
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)