@@ -246,21 +246,22 @@ static ProtocolDecl *getSequenceProtocol(ASTContext &ctx, SourceLoc loc,
246
246
}
247
247
248
248
// / Statement visitor that generates constraints for a given closure body.
249
- class ClosureConstraintGenerator
250
- : public StmtVisitor<ClosureConstraintGenerator , void > {
251
- friend StmtVisitor<ClosureConstraintGenerator , void >;
249
+ class SyntacticElementConstraintGenerator
250
+ : public StmtVisitor<SyntacticElementConstraintGenerator , void > {
251
+ friend StmtVisitor<SyntacticElementConstraintGenerator , void >;
252
252
253
253
ConstraintSystem &cs;
254
- ClosureExpr *closure;
254
+ AnyFunctionRef context;
255
+ Type resultType;
255
256
ConstraintLocator *locator;
256
257
257
258
public:
258
259
// / Whether an error was encountered while generating constraints.
259
260
bool hadError = false ;
260
261
261
- ClosureConstraintGenerator (ConstraintSystem &cs, ClosureExpr *closure ,
262
- ConstraintLocator *locator)
263
- : cs(cs), closure(closure ), locator(locator) {}
262
+ SyntacticElementConstraintGenerator (ConstraintSystem &cs, AnyFunctionRef fn ,
263
+ Type resultTy, ConstraintLocator *locator)
264
+ : cs(cs), context(fn), resultType(resultTy ), locator(locator) {}
264
265
265
266
void visitPattern (Pattern *pattern, ContextualTypeInfo context) {
266
267
auto parentElement =
@@ -286,13 +287,13 @@ class ClosureConstraintGenerator
286
287
llvm_unreachable (" Unsupported pattern" );
287
288
}
288
289
289
- void visitCaseItem (CaseLabelItem *caseItem, ContextualTypeInfo context ) {
290
- assert (context .purpose == CTP_CaseStmt);
290
+ void visitCaseItem (CaseLabelItem *caseItem, ContextualTypeInfo contextInfo ) {
291
+ assert (contextInfo .purpose == CTP_CaseStmt);
291
292
292
293
// Resolve the pattern.
293
294
auto *pattern = caseItem->getPattern ();
294
295
if (!caseItem->isPatternResolved ()) {
295
- pattern = TypeChecker::resolvePattern (pattern, closure ,
296
+ pattern = TypeChecker::resolvePattern (pattern, context. getAsDeclContext () ,
296
297
/* isStmtCondition=*/ false );
297
298
if (!pattern) {
298
299
hadError = true ;
@@ -306,13 +307,13 @@ class ClosureConstraintGenerator
306
307
// always be converted into a conjunction.
307
308
308
309
// Generate constraints for pattern.
309
- visitPattern (pattern, context );
310
+ visitPattern (pattern, contextInfo );
310
311
311
312
auto *guardExpr = caseItem->getGuardExpr ();
312
313
313
314
// Generate constraints for `where` clause (if any).
314
315
if (guardExpr) {
315
- guardExpr = cs.generateConstraints (guardExpr, closure );
316
+ guardExpr = cs.generateConstraints (guardExpr, context. getAsDeclContext () );
316
317
if (!guardExpr) {
317
318
hadError = true ;
318
319
return ;
@@ -341,7 +342,7 @@ class ClosureConstraintGenerator
341
342
// Verify pattern.
342
343
{
343
344
auto contextualPattern =
344
- ContextualPattern::forRawPattern (pattern, closure );
345
+ ContextualPattern::forRawPattern (pattern, context. getAsDeclContext () );
345
346
Type patternType = TypeChecker::typeCheckPattern (contextualPattern);
346
347
347
348
if (patternType->hasError ()) {
@@ -380,7 +381,7 @@ class ClosureConstraintGenerator
380
381
381
382
{
382
383
SolutionApplicationTarget target (
383
- sequenceExpr, closure , CTP_ForEachSequence,
384
+ sequenceExpr, context. getAsDeclContext () , CTP_ForEachSequence,
384
385
sequenceProto->getDeclaredInterfaceType (),
385
386
/* isDiscarded=*/ false );
386
387
@@ -426,14 +427,14 @@ class ClosureConstraintGenerator
426
427
Type makeIteratorType =
427
428
cs.createTypeVariable (locator, TVO_CanBindToNoEscape);
428
429
cs.addValueWitnessConstraint (LValueType::get (sequenceType), makeIterator,
429
- makeIteratorType, closure ,
430
+ makeIteratorType, context. getAsDeclContext () ,
430
431
FunctionRefKind::Compound, contextualLocator);
431
432
432
433
// After successful constraint generation, let's record
433
434
// solution application target with all relevant information.
434
435
{
435
436
auto target = SolutionApplicationTarget::forForEachStmt (
436
- forEachStmt, sequenceProto, closure ,
437
+ forEachStmt, sequenceProto, context. getAsDeclContext () ,
437
438
/* bindTypeVarsOneWay=*/ false ,
438
439
/* contextualPurpose=*/ CTP_ForEachSequence);
439
440
@@ -761,9 +762,9 @@ class ClosureConstraintGenerator
761
762
// be handled together with pattern because pattern can
762
763
// inform a type of sequence element e.g. `for i: Int8 in 0 ..< 8`
763
764
{
764
- Pattern *pattern =
765
- TypeChecker::resolvePattern (forEachStmt-> getPattern (), closure ,
766
- /* isStmtCondition=*/ false );
765
+ Pattern *pattern = TypeChecker::resolvePattern (forEachStmt-> getPattern (),
766
+ context. getAsDeclContext () ,
767
+ /* isStmtCondition=*/ false );
767
768
768
769
if (!pattern) {
769
770
hadError = true ;
@@ -798,7 +799,8 @@ class ClosureConstraintGenerator
798
799
{
799
800
elements.push_back (makeElement (subjectExpr, switchLoc));
800
801
801
- SolutionApplicationTarget target (subjectExpr, closure, CTP_Unused,
802
+ SolutionApplicationTarget target (subjectExpr,
803
+ context.getAsDeclContext (), CTP_Unused,
802
804
Type (), /* isDiscarded=*/ false );
803
805
804
806
cs.setSolutionApplicationTarget (switchStmt, target);
@@ -853,7 +855,7 @@ class ClosureConstraintGenerator
853
855
}
854
856
}
855
857
856
- bindSwitchCasePatternVars (closure , caseStmt);
858
+ bindSwitchCasePatternVars (context. getAsDeclContext () , caseStmt);
857
859
858
860
auto *caseLoc = cs.getConstraintLocator (
859
861
locator, LocatorPathElt::SyntacticElement (caseStmt));
@@ -912,7 +914,7 @@ class ClosureConstraintGenerator
912
914
for (auto node : braceStmt->getElements ()) {
913
915
if (auto expr = node.dyn_cast <Expr *>()) {
914
916
auto generatedExpr = cs.generateConstraints (
915
- expr, closure , /* isInputExpression=*/ false );
917
+ expr, context. getAsDeclContext () , /* isInputExpression=*/ false );
916
918
if (!generatedExpr) {
917
919
hadError = true ;
918
920
}
@@ -925,24 +927,27 @@ class ClosureConstraintGenerator
925
927
}
926
928
927
929
void visitReturnStmt (ReturnStmt *returnStmt) {
928
- auto contextualTy = cs.getClosureType (closure)->getResult ();
930
+ auto *closure =
931
+ dyn_cast_or_null<ClosureExpr>(context.getAbstractClosureExpr ());
929
932
930
933
// Single-expression closures are effectively a `return` statement,
931
934
// so let's give them a special locator as to indicate that.
932
935
// Return statements might not have a result if we have a closure whose
933
936
// implicit returned value is coerced to Void.
934
- if (closure->hasSingleExpressionBody () && returnStmt->hasResult ()) {
937
+ if (closure && closure->hasSingleExpressionBody () &&
938
+ returnStmt->hasResult ()) {
935
939
auto *expr = returnStmt->getResult ();
936
940
assert (expr && " single expression closure without expression?" );
937
941
938
- expr = cs.generateConstraints (expr, closure, /* isInputExpression=*/ false );
942
+ expr = cs.generateConstraints (expr, closure,
943
+ /* isInputExpression=*/ false );
939
944
if (!expr) {
940
945
hadError = true ;
941
946
return ;
942
947
}
943
948
944
949
cs.addConstraint (
945
- ConstraintKind::Conversion, cs.getType (expr), contextualTy ,
950
+ ConstraintKind::Conversion, cs.getType (expr), resultType ,
946
951
cs.getConstraintLocator (
947
952
closure, LocatorPathElt::ClosureBody (
948
953
/* hasReturn=*/ !returnStmt->isImplicit ())));
@@ -957,26 +962,30 @@ class ClosureConstraintGenerator
957
962
} else {
958
963
// If this is simplify `return`, let's create an empty tuple
959
964
// which is also useful if contextual turns out to be e.g. `Void?`.
960
- resultExpr = getVoidExpr (closure-> getASTContext ());
965
+ resultExpr = getVoidExpr (cs. getASTContext ());
961
966
}
962
967
963
- SolutionApplicationTarget target (resultExpr, closure, CTP_ReturnStmt ,
964
- contextualTy ,
968
+ SolutionApplicationTarget target (resultExpr, context. getAsDeclContext () ,
969
+ CTP_ReturnStmt, resultType ,
965
970
/* isDiscarded=*/ false );
966
971
967
972
if (cs.generateConstraints (target, FreeTypeVariableBinding::Disallow)) {
968
973
hadError = true ;
969
974
return ;
970
975
}
971
976
972
- cs.setContextualType (target.getAsExpr (), TypeLoc::withoutLoc (contextualTy ),
977
+ cs.setContextualType (target.getAsExpr (), TypeLoc::withoutLoc (resultType ),
973
978
CTP_ReturnStmt);
974
979
cs.setSolutionApplicationTarget (returnStmt, target);
975
980
}
976
981
977
982
bool isSupportedMultiStatementClosure () const {
978
- return !closure->hasSingleExpressionBody () &&
979
- cs.participatesInInference (closure);
983
+ if (auto *closure =
984
+ dyn_cast_or_null<ClosureExpr>(context.getAbstractClosureExpr ())) {
985
+ return !closure->hasSingleExpressionBody () &&
986
+ cs.participatesInInference (closure);
987
+ }
988
+ return true ;
980
989
}
981
990
982
991
#define UNSUPPORTED_STMT (STMT ) void visit##STMT##Stmt(STMT##Stmt *) { \
@@ -1008,8 +1017,10 @@ bool ConstraintSystem::generateConstraints(ClosureExpr *closure) {
1008
1017
auto &ctx = closure->getASTContext ();
1009
1018
1010
1019
if (participatesInInference (closure)) {
1011
- ClosureConstraintGenerator generator (*this , closure,
1012
- getConstraintLocator (closure));
1020
+ SyntacticElementConstraintGenerator generator (
1021
+ *this , closure, getClosureType (closure)->getResult (),
1022
+ getConstraintLocator (closure));
1023
+
1013
1024
generator.visit (closure->getBody ());
1014
1025
1015
1026
if (closure->hasSingleExpressionBody ())
@@ -1079,8 +1090,9 @@ ConstraintSystem::simplifySyntacticElementConstraint(
1079
1090
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
1080
1091
auto *closure = castToExpr<ClosureExpr>(locator.getAnchor ());
1081
1092
1082
- ClosureConstraintGenerator generator (*this , closure,
1083
- getConstraintLocator (locator));
1093
+ SyntacticElementConstraintGenerator generator (
1094
+ *this , closure, getClosureType (closure)->getResult (),
1095
+ getConstraintLocator (locator));
1084
1096
1085
1097
if (auto *expr = element.dyn_cast <Expr *>()) {
1086
1098
SolutionApplicationTarget target (expr, closure, context.purpose ,
0 commit comments