@@ -757,20 +757,20 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
757
757
Stmt *visitBraceStmt (BraceStmt *BS);
758
758
759
759
Stmt *visitReturnStmt (ReturnStmt *RS) {
760
- auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
760
+ // First, let's do a pre-check, and bail if the return is completely
761
+ // invalid.
762
+ auto &eval = getASTContext ().evaluator ;
763
+ auto *S =
764
+ evaluateOrDefault (eval, PreCheckReturnStmtRequest{RS, DC}, nullptr );
765
+
766
+ // We do a cast here as it may have been turned into a FailStmt. We should
767
+ // return that without doing anything else.
768
+ RS = dyn_cast_or_null<ReturnStmt>(S);
769
+ if (!RS)
770
+ return S;
761
771
762
- if (!TheFunc.has_value ()) {
763
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
764
- diag::return_invalid_outside_func);
765
- return nullptr ;
766
- }
767
-
768
- // If the return is in a defer, then it isn't valid either.
769
- if (isInDefer ()) {
770
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
771
- diag::jump_out_of_defer, " return" );
772
- return nullptr ;
773
- }
772
+ auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
773
+ assert (TheFunc && " Should have bailed from pre-check if this is None" );
774
774
775
775
Type ResultTy = TheFunc->getBodyResultType ();
776
776
if (!ResultTy || ResultTy->hasError ())
@@ -808,40 +808,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
808
808
}
809
809
810
810
Expr *E = RS->getResult ();
811
-
812
- // In an initializer, the only expression allowed is "nil", which indicates
813
- // failure from a failable initializer.
814
- if (auto ctor = dyn_cast_or_null<ConstructorDecl>(
815
- TheFunc->getAbstractFunctionDecl ())) {
816
- // The only valid return expression in an initializer is the literal
817
- // 'nil'.
818
- auto nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
819
- if (!nilExpr) {
820
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
821
- diag::return_init_non_nil)
822
- .highlight (E->getSourceRange ());
823
- RS->setResult (nullptr );
824
- return RS;
825
- }
826
-
827
- // "return nil" is only permitted in a failable initializer.
828
- if (!ctor->isFailable ()) {
829
- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
830
- diag::return_non_failable_init)
831
- .highlight (E->getSourceRange ());
832
- getASTContext ().Diags .diagnose (ctor->getLoc (), diag::make_init_failable,
833
- ctor->getName ())
834
- .fixItInsertAfter (ctor->getLoc (), " ?" );
835
- RS->setResult (nullptr );
836
- return RS;
837
- }
838
-
839
- // Replace the "return nil" with a new 'fail' statement.
840
- return new (getASTContext ()) FailStmt (RS->getReturnLoc (),
841
- nilExpr->getLoc (),
842
- RS->isImplicit ());
843
- }
844
-
845
811
TypeCheckExprOptions options = {};
846
812
847
813
if (LeaveBraceStmtBodyUnchecked) {
@@ -1294,6 +1260,62 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
1294
1260
};
1295
1261
} // end anonymous namespace
1296
1262
1263
+ Stmt *PreCheckReturnStmtRequest::evaluate (Evaluator &evaluator, ReturnStmt *RS,
1264
+ DeclContext *DC) const {
1265
+ auto &ctx = DC->getASTContext ();
1266
+ auto fn = AnyFunctionRef::fromDeclContext (DC);
1267
+
1268
+ // Not valid outside of a function.
1269
+ if (!fn) {
1270
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_invalid_outside_func);
1271
+ return nullptr ;
1272
+ }
1273
+
1274
+ // If the return is in a defer, then it isn't valid either.
1275
+ if (isDefer (DC)) {
1276
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::jump_out_of_defer, " return" );
1277
+ return nullptr ;
1278
+ }
1279
+
1280
+ // The rest of the checks only concern return statements with results.
1281
+ if (!RS->hasResult ())
1282
+ return RS;
1283
+
1284
+ auto *E = RS->getResult ();
1285
+
1286
+ // In an initializer, the only expression allowed is "nil", which indicates
1287
+ // failure from a failable initializer.
1288
+ if (auto *ctor =
1289
+ dyn_cast_or_null<ConstructorDecl>(fn->getAbstractFunctionDecl ())) {
1290
+
1291
+ // The only valid return expression in an initializer is the literal
1292
+ // 'nil'.
1293
+ auto *nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
1294
+ if (!nilExpr) {
1295
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_init_non_nil)
1296
+ .highlight (E->getSourceRange ());
1297
+ RS->setResult (nullptr );
1298
+ return RS;
1299
+ }
1300
+
1301
+ // "return nil" is only permitted in a failable initializer.
1302
+ if (!ctor->isFailable ()) {
1303
+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_non_failable_init)
1304
+ .highlight (E->getSourceRange ());
1305
+ ctx.Diags
1306
+ .diagnose (ctor->getLoc (), diag::make_init_failable, ctor->getName ())
1307
+ .fixItInsertAfter (ctor->getLoc (), " ?" );
1308
+ RS->setResult (nullptr );
1309
+ return RS;
1310
+ }
1311
+
1312
+ // Replace the "return nil" with a new 'fail' statement.
1313
+ return new (ctx)
1314
+ FailStmt (RS->getReturnLoc (), nilExpr->getLoc (), RS->isImplicit ());
1315
+ }
1316
+ return RS;
1317
+ }
1318
+
1297
1319
static bool isDiscardableType (Type type) {
1298
1320
return (type->hasError () ||
1299
1321
type->isUninhabited () ||
0 commit comments