@@ -1645,7 +1645,6 @@ class SyntacticElementSolutionApplication
1645
1645
protected:
1646
1646
Solution &solution;
1647
1647
SyntacticElementContext context;
1648
- Type resultType;
1649
1648
RewriteTargetFn rewriteTarget;
1650
1649
1651
1650
// / All `func`s declared in the body of the closure.
@@ -1658,24 +1657,39 @@ class SyntacticElementSolutionApplication
1658
1657
SyntacticElementSolutionApplication (Solution &solution,
1659
1658
SyntacticElementContext context,
1660
1659
RewriteTargetFn rewriteTarget)
1661
- : solution(solution), context(context), rewriteTarget(rewriteTarget) {
1662
- if (auto fn = AnyFunctionRef::fromDeclContext (context.getAsDeclContext ())) {
1660
+ : solution(solution), context(context), rewriteTarget(rewriteTarget) {}
1661
+
1662
+ virtual ~SyntacticElementSolutionApplication () {}
1663
+
1664
+ private:
1665
+ Type getContextualResultType () const {
1666
+ // Taps do not have a contextual result type.
1667
+ if (context.is <TapExpr *>()) {
1668
+ return Type ();
1669
+ }
1670
+
1671
+ auto fn = context.getAsAnyFunctionRef ();
1672
+
1673
+ if (context.is <SingleValueStmtExpr *>()) {
1674
+ // if/switch expressions can have `return` inside.
1675
+ fn = AnyFunctionRef::fromDeclContext (context.getAsDeclContext ());
1676
+ }
1677
+
1678
+ if (fn) {
1663
1679
if (auto transform = solution.getAppliedBuilderTransform (*fn)) {
1664
- resultType = solution.simplifyType (transform->bodyResultType );
1680
+ return solution.simplifyType (transform->bodyResultType );
1665
1681
} else if (auto *closure =
1666
1682
getAsExpr<ClosureExpr>(fn->getAbstractClosureExpr ())) {
1667
- resultType = solution.getResolvedType (closure)
1668
- ->castTo <FunctionType>()
1669
- ->getResult ();
1683
+ return solution.getResolvedType (closure)
1684
+ ->castTo <FunctionType>()
1685
+ ->getResult ();
1670
1686
} else {
1671
- resultType = fn->getBodyResultType ();
1687
+ return fn->getBodyResultType ();
1672
1688
}
1673
1689
}
1674
- }
1675
-
1676
- virtual ~SyntacticElementSolutionApplication () {}
1677
1690
1678
- private:
1691
+ return Type ();
1692
+ }
1679
1693
1680
1694
ASTNode visit (Stmt *S, bool performSyntacticDiagnostics = true ) {
1681
1695
auto rewritten = ASTVisitor::visit (S);
@@ -2031,17 +2045,18 @@ class SyntacticElementSolutionApplication
2031
2045
auto closure = context.getAsAbstractClosureExpr ();
2032
2046
if (closure && !closure.get ()->hasSingleExpressionBody () &&
2033
2047
closure.get ()->getBody () == braceStmt) {
2048
+ auto resultType = getContextualResultType ();
2034
2049
if (resultType->getOptionalObjectType () &&
2035
2050
resultType->lookThroughAllOptionalTypes ()->isVoid () &&
2036
2051
!braceStmt->getLastElement ().isStmt (StmtKind::Return)) {
2037
- return addImplicitVoidReturn (braceStmt);
2052
+ return addImplicitVoidReturn (braceStmt, resultType );
2038
2053
}
2039
2054
}
2040
2055
2041
2056
return braceStmt;
2042
2057
}
2043
2058
2044
- ASTNode addImplicitVoidReturn (BraceStmt *braceStmt) {
2059
+ ASTNode addImplicitVoidReturn (BraceStmt *braceStmt, Type contextualResultTy ) {
2045
2060
auto &cs = solution.getConstraintSystem ();
2046
2061
auto &ctx = cs.getASTContext ();
2047
2062
@@ -2056,7 +2071,7 @@ class SyntacticElementSolutionApplication
2056
2071
// number of times.
2057
2072
{
2058
2073
SyntacticElementTarget target (resultExpr, context.getAsDeclContext (),
2059
- CTP_ReturnStmt, resultType ,
2074
+ CTP_ReturnStmt, contextualResultTy ,
2060
2075
/* isDiscarded=*/ false );
2061
2076
cs.setTargetFor (returnStmt, target);
2062
2077
@@ -2077,6 +2092,8 @@ class SyntacticElementSolutionApplication
2077
2092
ASTNode visitReturnStmt (ReturnStmt *returnStmt) {
2078
2093
auto &cs = solution.getConstraintSystem ();
2079
2094
2095
+ auto resultType = getContextualResultType ();
2096
+
2080
2097
if (!returnStmt->hasResult ()) {
2081
2098
// If contextual is not optional, there is nothing to do here.
2082
2099
if (resultType->isVoid ())
0 commit comments