@@ -2892,7 +2892,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
2892
2892
return hadError;
2893
2893
}
2894
2894
2895
- bool TypeChecker::typeCheckForEachBinding (DeclContext *dc, ForEachStmt *stmt) {
2895
+ auto TypeChecker::typeCheckForEachBinding (
2896
+ DeclContext *dc, ForEachStmt *stmt) -> Optional<ForEachBinding> {
2896
2897
// / Type checking listener for for-each binding.
2897
2898
class BindingListener : public ExprTypeCheckListener {
2898
2899
// / The for-each statement.
@@ -2901,45 +2902,66 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2901
2902
// / The locator we're using.
2902
2903
ConstraintLocator *Locator;
2903
2904
2905
+ // / The contextual locator we're using.
2906
+ ConstraintLocator *ContextualLocator;
2907
+
2908
+ // / The Sequence protocol.
2909
+ ProtocolDecl *SequenceProto;
2910
+
2911
+ // / The IteratorProtocol.
2912
+ ProtocolDecl *IteratorProto;
2913
+
2904
2914
// / The type of the initializer.
2905
2915
Type InitType;
2906
2916
2907
2917
// / The type of the sequence.
2908
2918
Type SequenceType;
2909
2919
2920
+ // / The conformance of the sequence type to the Sequence protocol.
2921
+ ProtocolConformanceRef SequenceConformance;
2922
+
2923
+ // / The type of the element.
2924
+ Type ElementType;
2925
+
2926
+ // / The type of the iterator.
2927
+ Type IteratorType;
2928
+
2929
+ // / The conformance of the iterator type to IteratorProtocol.
2930
+ ProtocolConformanceRef IteratorConformance;
2931
+
2932
+ // / The type of makeIterator.
2933
+ Type MakeIteratorType;
2934
+
2910
2935
public:
2911
2936
explicit BindingListener (ForEachStmt *stmt) : Stmt(stmt) { }
2912
2937
2913
2938
bool builtConstraints (ConstraintSystem &cs, Expr *expr) override {
2914
2939
// Save the locator we're using for the expression.
2915
2940
Locator = cs.getConstraintLocator (expr);
2916
- auto *contextualLocator =
2941
+ ContextualLocator =
2917
2942
cs.getConstraintLocator (expr, LocatorPathElt::ContextualType ());
2918
2943
2919
- // The expression type must conform to the Sequence.
2920
- ProtocolDecl *sequenceProto = TypeChecker::getProtocol (
2944
+ // The expression type must conform to the Sequence protocol .
2945
+ SequenceProto = TypeChecker::getProtocol (
2921
2946
cs.getASTContext (), Stmt->getForLoc (), KnownProtocolKind::Sequence);
2922
- if (!sequenceProto ) {
2947
+ if (!SequenceProto ) {
2923
2948
return true ;
2924
2949
}
2925
2950
2926
- auto elementAssocType =
2927
- sequenceProto->getAssociatedType (cs.getASTContext ().Id_Element );
2928
-
2929
2951
SequenceType = cs.createTypeVariable (Locator, TVO_CanBindToNoEscape);
2930
2952
cs.addConstraint (ConstraintKind::Conversion, cs.getType (expr),
2931
2953
SequenceType, Locator);
2932
2954
cs.addConstraint (ConstraintKind::ConformsTo, SequenceType,
2933
- sequenceProto ->getDeclaredType (), contextualLocator );
2955
+ SequenceProto ->getDeclaredType (), ContextualLocator );
2934
2956
2935
2957
// Since we are using "contextual type" here, it has to be recorded
2936
2958
// in the constraint system for diagnostics to have access to "purpose".
2937
2959
cs.setContextualType (
2938
- expr, TypeLoc::withoutLoc (sequenceProto ->getDeclaredType ()),
2960
+ expr, TypeLoc::withoutLoc (SequenceProto ->getDeclaredType ()),
2939
2961
CTP_ForEachStmt);
2940
2962
2941
2963
auto elementLocator = cs.getConstraintLocator (
2942
- contextualLocator , ConstraintLocator::SequenceElementType);
2964
+ ContextualLocator , ConstraintLocator::SequenceElementType);
2943
2965
2944
2966
// Collect constraints from the element pattern.
2945
2967
auto pattern = Stmt->getPattern ();
@@ -2949,10 +2971,36 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2949
2971
2950
2972
// Add a conversion constraint between the element type of the sequence
2951
2973
// and the type of the element pattern.
2952
- auto elementType = DependentMemberType::get (SequenceType, elementAssocType);
2953
- cs.addConstraint (ConstraintKind::Conversion, elementType, InitType,
2974
+ auto elementAssocType =
2975
+ SequenceProto->getAssociatedType (cs.getASTContext ().Id_Element );
2976
+ ElementType = DependentMemberType::get (SequenceType, elementAssocType);
2977
+ cs.addConstraint (ConstraintKind::Conversion, ElementType, InitType,
2954
2978
elementLocator);
2955
2979
2980
+ // Determine the iterator type.
2981
+ auto iteratorAssocType =
2982
+ SequenceProto->getAssociatedType (cs.getASTContext ().Id_Iterator );
2983
+ IteratorType = DependentMemberType::get (SequenceType, iteratorAssocType);
2984
+
2985
+ // The iterator type must conform to IteratorProtocol.
2986
+ IteratorProto = TypeChecker::getProtocol (
2987
+ cs.getASTContext (), Stmt->getForLoc (),
2988
+ KnownProtocolKind::IteratorProtocol);
2989
+ if (!IteratorProto) {
2990
+ return true ;
2991
+ }
2992
+
2993
+ // Reference the makeIterator witness.
2994
+ // FIXME: Not tied to the actual witness.
2995
+ ASTContext &ctx = cs.getASTContext ();
2996
+ DeclName makeIteratorName (ctx, ctx.Id_makeIterator ,
2997
+ ArrayRef<Identifier>());
2998
+ MakeIteratorType = cs.createTypeVariable (Locator, TVO_CanBindToNoEscape);
2999
+ cs.addValueMemberConstraint (
3000
+ LValueType::get (SequenceType), DeclNameRef (makeIteratorName),
3001
+ MakeIteratorType, cs.DC , FunctionRefKind::Compound, { },
3002
+ ContextualLocator);
3003
+
2956
3004
Stmt->setSequence (expr);
2957
3005
return false ;
2958
3006
}
@@ -2962,13 +3010,24 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2962
3010
auto &cs = solution.getConstraintSystem ();
2963
3011
InitType = solution.simplifyType (InitType);
2964
3012
SequenceType = solution.simplifyType (SequenceType);
3013
+ ElementType = solution.simplifyType (ElementType);
3014
+ IteratorType = solution.simplifyType (IteratorType);
2965
3015
2966
3016
// Perform any necessary conversions of the sequence (e.g. [T]! -> [T]).
2967
- expr = solution.coerceToType (expr, SequenceType, cs. getConstraintLocator (expr) );
3017
+ expr = solution.coerceToType (expr, SequenceType, Locator );
2968
3018
2969
3019
if (!expr) return nullptr ;
2970
3020
3021
+ // Convert the sequence as appropriate for the makeIterator() call.
3022
+ auto makeIteratorOverload = solution.getOverloadChoice (ContextualLocator);
3023
+ auto makeIteratorSelfType = solution.simplifyType (
3024
+ makeIteratorOverload.openedFullType
3025
+ )->castTo <AnyFunctionType>()->getParams ()[0 ].getPlainType ();
3026
+ expr = solution.coerceToType (expr, makeIteratorSelfType,
3027
+ ContextualLocator);
3028
+
2971
3029
cs.cacheExprTypes (expr);
3030
+ Stmt->setSequence (expr);
2972
3031
2973
3032
// Apply the solution to the iteration pattern as well.
2974
3033
Pattern *pattern = Stmt->getPattern ();
@@ -2981,11 +3040,42 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2981
3040
}
2982
3041
2983
3042
Stmt->setPattern (pattern);
2984
- Stmt->setSequence (expr);
3043
+
3044
+ // Get the conformance of the sequence type to the Sequence protocol.
3045
+ // FIXME: Get this from the solution and substitute into that.
3046
+ SequenceConformance = TypeChecker::conformsToProtocol (
3047
+ SequenceType, SequenceProto, cs.DC ,
3048
+ ConformanceCheckFlags::InExpression,
3049
+ expr->getLoc ());
3050
+ assert (!SequenceConformance.isInvalid () &&
3051
+ " Couldn't find sequence conformance" );
3052
+ Stmt->setSequenceConformance (SequenceConformance);
3053
+
3054
+ // Retrieve the conformance of the iterator type to IteratorProtocol.
3055
+ // FIXME: Get this from the solution and substitute into that.
3056
+ IteratorConformance = TypeChecker::conformsToProtocol (
3057
+ IteratorType, IteratorProto, cs.DC ,
3058
+ ConformanceCheckFlags::InExpression,
3059
+ expr->getLoc ());
3060
+
3061
+ // Record the makeIterator declaration we used.
3062
+ auto makeIteratorDecl = makeIteratorOverload.choice .getDecl ();
3063
+ auto makeIteratorSubs = SequenceType->getMemberSubstitutionMap (
3064
+ cs.DC ->getParentModule (), makeIteratorDecl);
3065
+ auto makeIteratorDeclRef =
3066
+ ConcreteDeclRef (makeIteratorDecl, makeIteratorSubs);
3067
+ Stmt->setMakeIterator (makeIteratorDeclRef);
2985
3068
2986
3069
solution.setExprTypes (expr);
2987
3070
return expr;
2988
3071
}
3072
+
3073
+ ForEachBinding getBinding () const {
3074
+ return {
3075
+ SequenceType, SequenceConformance, IteratorType, IteratorConformance,
3076
+ ElementType
3077
+ };
3078
+ }
2989
3079
};
2990
3080
2991
3081
BindingListener listener (stmt);
@@ -2994,7 +3084,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
2994
3084
2995
3085
// Type-check the for-each loop sequence and element pattern.
2996
3086
auto resultTy = TypeChecker::typeCheckExpression (seq, dc, &listener);
2997
- return !resultTy;
3087
+ if (!resultTy)
3088
+ return None;
3089
+ return listener.getBinding ();
2998
3090
}
2999
3091
3000
3092
bool TypeChecker::typeCheckCondition (Expr *&expr, DeclContext *dc) {
@@ -3427,88 +3519,6 @@ TypeChecker::coerceToRValue(ASTContext &Context, Expr *expr,
3427
3519
return expr;
3428
3520
}
3429
3521
3430
- bool TypeChecker::convertToType (Expr *&expr, Type type, DeclContext *dc,
3431
- Optional<Pattern*> typeFromPattern) {
3432
- // TODO: need to add kind arg?
3433
- // Construct a constraint system from this expression.
3434
- ConstraintSystem cs (dc, ConstraintSystemFlags::AllowFixes);
3435
-
3436
- // Cache the expression type on the system to ensure it is available
3437
- // on diagnostics if the convertion fails.
3438
- cs.cacheExprTypes (expr);
3439
-
3440
- // If there is a type that we're expected to convert to, add the conversion
3441
- // constraint.
3442
- cs.addConstraint (ConstraintKind::Conversion, expr->getType (), type,
3443
- cs.getConstraintLocator (expr));
3444
-
3445
- auto &Context = dc->getASTContext ();
3446
- if (Context.TypeCheckerOpts .DebugConstraintSolver ) {
3447
- auto &log = Context.TypeCheckerDebug ->getStream ();
3448
- log << " ---Initial constraints for the given expression---\n " ;
3449
- expr->dump (log);
3450
- log << " \n " ;
3451
- cs.print (log);
3452
- }
3453
-
3454
- // Attempt to solve the constraint system.
3455
- SmallVector<Solution, 4 > viable;
3456
- if ((cs.solve (viable) || viable.size () != 1 )) {
3457
- // Try to fix the system or provide a decent diagnostic.
3458
- auto salvagedResult = cs.salvage ();
3459
- switch (salvagedResult.getKind ()) {
3460
- case SolutionResult::Kind::Success:
3461
- viable.clear ();
3462
- viable.push_back (std::move (salvagedResult).takeSolution ());
3463
- break ;
3464
-
3465
- case SolutionResult::Kind::Error:
3466
- case SolutionResult::Kind::Ambiguous:
3467
- return true ;
3468
-
3469
- case SolutionResult::Kind::UndiagnosedError:
3470
- cs.diagnoseFailureForExpr (expr);
3471
- salvagedResult.markAsDiagnosed ();
3472
- return true ;
3473
-
3474
- case SolutionResult::Kind::TooComplex:
3475
- Context.Diags .diagnose (expr->getLoc (), diag::expression_too_complex)
3476
- .highlight (expr->getSourceRange ());
3477
- salvagedResult.markAsDiagnosed ();
3478
- return true ;
3479
- }
3480
- }
3481
-
3482
- auto &solution = viable[0 ];
3483
- if (Context.TypeCheckerOpts .DebugConstraintSolver ) {
3484
- auto &log = Context.TypeCheckerDebug ->getStream ();
3485
- log << " ---Solution---\n " ;
3486
- solution.dump (log);
3487
- }
3488
-
3489
- cs.cacheExprTypes (expr);
3490
-
3491
- // Perform the conversion.
3492
- Expr *result = solution.coerceToType (expr, type,
3493
- cs.getConstraintLocator (expr),
3494
- typeFromPattern);
3495
- if (!result) {
3496
- return true ;
3497
- }
3498
-
3499
- solution.setExprTypes (expr);
3500
-
3501
- if (Context.TypeCheckerOpts .DebugConstraintSolver ) {
3502
- auto &log = Context.TypeCheckerDebug ->getStream ();
3503
- log << " ---Type-checked expression---\n " ;
3504
- result->dump (log);
3505
- log << " \n " ;
3506
- }
3507
-
3508
- expr = result;
3509
- return false ;
3510
- }
3511
-
3512
3522
// ===----------------------------------------------------------------------===//
3513
3523
// Debugging
3514
3524
// ===----------------------------------------------------------------------===//
0 commit comments