@@ -1116,17 +1116,39 @@ bool isConditionOfStmt(ConstraintLocatorBuilder locator) {
1116
1116
1117
1117
ConstraintSystem::SolutionKind
1118
1118
ConstraintSystem::simplifySyntacticElementConstraint (
1119
- ASTNode element, ContextualTypeInfo context , bool isDiscarded,
1119
+ ASTNode element, ContextualTypeInfo contextInfo , bool isDiscarded,
1120
1120
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
1121
- auto *closure = castToExpr<ClosureExpr>(locator.getAnchor ());
1122
1121
1123
- SyntacticElementConstraintGenerator generator (
1124
- *this , closure, getClosureType (closure)->getResult (),
1125
- getConstraintLocator (locator));
1122
+ DeclContext *context;
1123
+ Type resultType;
1124
+
1125
+ auto anchor = locator.getAnchor ();
1126
+
1127
+ if (auto *closure = getAsExpr<ClosureExpr>(anchor)) {
1128
+ context = closure;
1129
+ resultType = getClosureType (closure)->getResult ();
1130
+ } else if (auto *fn = getAsDecl<AbstractFunctionDecl>(anchor)) {
1131
+ context = fn;
1132
+ resultType = AnyFunctionRef (fn).getBodyResultType ();
1133
+ } else {
1134
+ return SolutionKind::Error;
1135
+ }
1136
+
1137
+ AnyFunctionRef fn = AnyFunctionRef::fromFunctionDeclContext (context);
1138
+
1139
+ // If this element belongs to a result builder, let's use its result type.
1140
+ {
1141
+ auto transform = resultBuilderTransformed.find (fn);
1142
+ if (transform != resultBuilderTransformed.end ())
1143
+ resultType = transform->second .bodyResultType ;
1144
+ }
1145
+
1146
+ SyntacticElementConstraintGenerator generator (*this , fn, resultType,
1147
+ getConstraintLocator (locator));
1126
1148
1127
1149
if (auto *expr = element.dyn_cast <Expr *>()) {
1128
- SolutionApplicationTarget target (expr, closure, context .purpose ,
1129
- context .getType (), isDiscarded);
1150
+ SolutionApplicationTarget target (expr, context, contextInfo .purpose ,
1151
+ contextInfo .getType (), isDiscarded);
1130
1152
1131
1153
if (generateConstraints (target, FreeTypeVariableBinding::Disallow))
1132
1154
return SolutionKind::Error;
@@ -1136,12 +1158,12 @@ ConstraintSystem::simplifySyntacticElementConstraint(
1136
1158
} else if (auto *stmt = element.dyn_cast <Stmt *>()) {
1137
1159
generator.visit (stmt);
1138
1160
} else if (auto *cond = element.dyn_cast <StmtConditionElement *>()) {
1139
- if (generateConstraints ({*cond}, closure ))
1161
+ if (generateConstraints ({*cond}, context ))
1140
1162
return SolutionKind::Error;
1141
1163
} else if (auto *pattern = element.dyn_cast <Pattern *>()) {
1142
- generator.visitPattern (pattern, context );
1164
+ generator.visitPattern (pattern, contextInfo );
1143
1165
} else if (auto *caseItem = element.dyn_cast <CaseLabelItem *>()) {
1144
- generator.visitCaseItem (caseItem, context );
1166
+ generator.visitCaseItem (caseItem, contextInfo );
1145
1167
} else {
1146
1168
generator.visit (element.get <Decl *>());
1147
1169
}
0 commit comments