Skip to content

Commit ab5ab28

Browse files
authored
Merge pull request #70454 from DougGregor/full-typed-throws-inference
[Typed throws] Implement thrown type inference for do..catch within closures
2 parents 979334e + ae5f66a commit ab5ab28

22 files changed

+471
-81
lines changed

include/swift/AST/CatchNode.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ class CatchNode: public llvm::PointerUnion<
3636
///
3737
/// Returns the thrown error type for a throwing context, or \c llvm::None
3838
/// if this is a non-throwing context.
39-
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;
39+
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
40+
41+
/// Determines the explicitly-specified type error that will be caught by
42+
/// this catch node.
43+
///
44+
/// Returns the explicitly-caught type, or a NULL type if the caught type
45+
/// needs to be inferred.
46+
Type getExplicitCaughtType(ASTContext &ctx) const;
4047

4148
friend llvm::hash_code hash_value(CatchNode catchNode) {
4249
using llvm::hash_value;
@@ -46,6 +53,8 @@ class CatchNode: public llvm::PointerUnion<
4653

4754
void simple_display(llvm::raw_ostream &out, CatchNode catchNode);
4855

56+
SourceLoc extractNearestSourceLoc(CatchNode catchNode);
57+
4958
} // end namespace swift
5059

5160
#endif // SWIFT_AST_CATCHNODE_H

include/swift/AST/Stmt.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,8 @@ class DoCatchStmt final
14291429
friend TrailingObjects;
14301430
friend class ExplicitCaughtTypeRequest;
14311431

1432+
DeclContext *DC;
1433+
14321434
SourceLoc DoLoc;
14331435

14341436
/// Location of the 'throws' token.
@@ -1440,12 +1442,13 @@ class DoCatchStmt final
14401442
Stmt *Body;
14411443
ThrownErrorDestination RethrowDest;
14421444

1443-
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc,
1445+
DoCatchStmt(DeclContext *dc,LabeledStmtInfo labelInfo, SourceLoc doLoc,
14441446
SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body,
14451447
ArrayRef<CaseStmt *> catches, llvm::Optional<bool> implicit)
14461448
: LabeledStmt(StmtKind::DoCatch, getDefaultImplicitFlag(implicit, doLoc),
14471449
labelInfo),
1448-
DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType), Body(body) {
1450+
DC(dc), DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType),
1451+
Body(body) {
14491452
Bits.DoCatchStmt.NumCatches = catches.size();
14501453
std::uninitialized_copy(catches.begin(), catches.end(),
14511454
getTrailingObjects<CaseStmt *>());
@@ -1454,13 +1457,16 @@ class DoCatchStmt final
14541457
}
14551458

14561459
public:
1457-
static DoCatchStmt *create(ASTContext &ctx, LabeledStmtInfo labelInfo,
1458-
SourceLoc doLoc,
1460+
static DoCatchStmt *create(DeclContext *dc,
1461+
LabeledStmtInfo labelInfo,
1462+
SourceLoc doLoc,
14591463
SourceLoc throwsLoc, TypeLoc thrownType,
14601464
Stmt *body,
14611465
ArrayRef<CaseStmt *> catches,
14621466
llvm::Optional<bool> implicit = llvm::None);
14631467

1468+
DeclContext *getDeclContext() const { return DC; }
1469+
14641470
SourceLoc getDoLoc() const { return DoLoc; }
14651471

14661472
/// Retrieve the location of the 'throws' keyword, if present.
@@ -1475,7 +1481,7 @@ class DoCatchStmt final
14751481
}
14761482

14771483
// Get the explicitly-specified caught error type.
1478-
Type getExplicitCaughtType(DeclContext *dc) const;
1484+
Type getExplicitCaughtType() const;
14791485

14801486
Stmt *getBody() const { return Body; }
14811487
void setBody(Stmt *s) { Body = s; }
@@ -1501,7 +1507,7 @@ class DoCatchStmt final
15011507
// and caught by the various 'catch' clauses. If this the catch clauses
15021508
// aren't exhausive, this is also the type of the error that is implicitly
15031509
// rethrown.
1504-
Type getCaughtErrorType(DeclContext *dc) const;
1510+
Type getCaughtErrorType() const;
15051511

15061512
/// Retrieves the rethrown error and its conversion to the error type
15071513
/// expected by the enclosing context.

include/swift/AST/TypeCheckRequests.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,7 @@ class ParamSpecifierRequest
22942294
/// requires type inference.
22952295
class ExplicitCaughtTypeRequest
22962296
: public SimpleRequest<ExplicitCaughtTypeRequest,
2297-
Type(DeclContext *, CatchNode),
2297+
Type(ASTContext *, CatchNode),
22982298
RequestFlags::SeparatelyCached> {
22992299
public:
23002300
using SimpleRequest::SimpleRequest;
@@ -2303,7 +2303,7 @@ class ExplicitCaughtTypeRequest
23032303
friend SimpleRequest;
23042304

23052305
// Evaluation.
2306-
Type evaluate(Evaluator &evaluator, DeclContext *dc, CatchNode catchNode) const;
2306+
Type evaluate(Evaluator &evaluator, ASTContext *, CatchNode catchNode) const;
23072307

23082308
public:
23092309
// Separate caching.

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ SWIFT_REQUEST(TypeChecker, NeedsNewVTableEntryRequest,
360360
SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
361361
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
362362
SWIFT_REQUEST(TypeChecker, ExplicitCaughtTypeRequest,
363-
Type(DeclContext *, CatchNode), SeparatelyCached, NoLocationInfo)
363+
Type(ASTContext *, CatchNode), SeparatelyCached, NoLocationInfo)
364364
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
365365
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
366366
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,

include/swift/Sema/ConstraintSystem.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,34 @@ struct MatchCallArgumentResult {
14401440
}
14411441
};
14421442

1443+
/// Describes a potential throw site in the constraint system.
1444+
///
1445+
/// For example, given `try f() + a[b] + x.y`, each of `f()`, `a[b]`, `x`, and
1446+
/// `x.y` is a potential throw site.
1447+
struct PotentialThrowSite {
1448+
enum Kind {
1449+
/// The application of a function or subscript.
1450+
Application,
1451+
1452+
/// An explicit 'throw'.
1453+
ExplicitThrow,
1454+
1455+
/// A non-exhaustive do...catch, which rethrows whatever is thrown from
1456+
/// inside it's `do` block.
1457+
NonExhaustiveDoCatch,
1458+
1459+
/// A property access that can throw an error.
1460+
PropertyAccess,
1461+
} kind;
1462+
1463+
/// The type that describes the potential throw site, such as the type of the
1464+
/// function being called or type being thrown.
1465+
Type type;
1466+
1467+
/// The locator that specifies where the throwing operation occurs.
1468+
ConstraintLocator *locator;
1469+
};
1470+
14431471
/// A complete solution to a constraint system.
14441472
///
14451473
/// A solution to a constraint system consists of type variable bindings to
@@ -1547,6 +1575,13 @@ class Solution {
15471575
llvm::MapVector<const CaseLabelItem *, CaseLabelItemInfo>
15481576
caseLabelItems;
15491577

1578+
/// Maps catch nodes to the set of potential throw sites that will be caught
1579+
/// at that location.
1580+
1581+
/// The set of opened types for a given locator.
1582+
std::vector<std::pair<CatchNode, PotentialThrowSite>>
1583+
potentialThrowSites;
1584+
15501585
/// A map of expressions to the ExprPatterns that they are being solved as
15511586
/// a part of.
15521587
llvm::MapVector<Expr *, ExprPattern *> exprPatterns;
@@ -2038,6 +2073,9 @@ struct DeclReferenceType {
20382073
/// (e.g.) applying the base of a member access. This is the type of the
20392074
/// expression used to form the declaration reference.
20402075
Type adjustedReferenceType;
2076+
2077+
/// The type that could be thrown by accessing this declaration.
2078+
Type thrownErrorTypeOnAccess;
20412079
};
20422080

20432081
/// Describes a system of constraints on type variables, the
@@ -2210,6 +2248,11 @@ class ConstraintSystem {
22102248
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
22112249
caseLabelItems;
22122250

2251+
/// Keep track of all of the potential throw sites.
2252+
/// FIXME: This data structure should be replaced with something that
2253+
/// is, in effect, a multimap-vector.
2254+
std::vector<std::pair<CatchNode, PotentialThrowSite>> potentialThrowSites;
2255+
22132256
/// A map of expressions to the ExprPatterns that they are being solved as
22142257
/// a part of.
22152258
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
@@ -2821,6 +2864,9 @@ class ConstraintSystem {
28212864
/// The length of \c caseLabelItems.
28222865
unsigned numCaseLabelItems;
28232866

2867+
/// The length of \c potentialThrowSites.
2868+
unsigned numPotentialThrowSites;
2869+
28242870
/// The length of \c exprPatterns.
28252871
unsigned numExprPatterns;
28262872

@@ -3293,6 +3339,14 @@ class ConstraintSystem {
32933339
return known->second;
32943340
}
32953341

3342+
/// Note that there is a potential throw site at the given location.
3343+
void recordPotentialThrowSite(
3344+
PotentialThrowSite::Kind kind, Type type,
3345+
ConstraintLocatorBuilder locator);
3346+
3347+
/// Determine the caught error type for the given catch node.
3348+
Type getCaughtErrorType(CatchNode node);
3349+
32963350
/// Retrieve the constraint locator for the given anchor and
32973351
/// path, uniqued.
32983352
ConstraintLocator *

lib/AST/ASTVerifier.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,9 +1046,8 @@ class Verifier : public ASTWalker {
10461046
SourceLoc loc = S->getThrowLoc();
10471047
if (loc.isValid()) {
10481048
auto catchNode = ASTScope::lookupCatchNode(getModuleContext(), loc);
1049-
DeclContext *dc = getInnermostDC();
1050-
if (catchNode && dc) {
1051-
if (auto thrown = catchNode.getThrownErrorTypeInContext(dc)) {
1049+
if (catchNode) {
1050+
if (auto thrown = catchNode.getThrownErrorTypeInContext(Ctx)) {
10521051
thrownError = *thrown;
10531052
} else {
10541053
thrownError = Ctx.getNeverType();

lib/AST/Decl.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -983,10 +983,7 @@ Type AbstractFunctionDecl::getThrownInterfaceType() const {
983983
return ThrownType.getType();
984984

985985
auto mutableThis = const_cast<AbstractFunctionDecl *>(this);
986-
return evaluateOrDefault(
987-
getASTContext().evaluator,
988-
ExplicitCaughtTypeRequest{mutableThis, mutableThis},
989-
Type());
986+
return CatchNode(mutableThis).getExplicitCaughtType(getASTContext());
990987
}
991988

992989
llvm::Optional<Type>
@@ -11728,7 +11725,7 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1172811725
}
1172911726

1173011727
llvm::Optional<Type>
11731-
CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
11728+
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
1173211729
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
1173311730
if (auto thrownError = func->getEffectiveThrownErrorType())
1173411731
return func->mapTypeIntoContext(*thrownError);
@@ -11751,15 +11748,15 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1175111748
}
1175211749

1175311750
if (auto doCatch = dyn_cast<DoCatchStmt *>()) {
11754-
if (auto thrownError = doCatch->getCaughtErrorType(dc)) {
11751+
if (auto thrownError = doCatch->getCaughtErrorType()) {
1175511752
if (thrownError->isNever())
1175611753
return llvm::None;
1175711754

1175811755
return thrownError;
1175911756
}
1176011757

1176111758
// If we haven't computed the error type yet, return 'any Error'.
11762-
return dc->getASTContext().getErrorExistentialType();
11759+
return ctx.getErrorExistentialType();
1176311760
}
1176411761

1176511762
auto tryExpr = get<AnyTryExpr *>();
@@ -11768,24 +11765,41 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1176811765
return thrownError;
1176911766

1177011767
// If we haven't computed the error type yet, return 'any Error'.
11771-
return dc->getASTContext().getErrorExistentialType();
11768+
return ctx.getErrorExistentialType();
1177211769
}
1177311770

1177411771
if (auto optTry = llvm::dyn_cast<OptionalTryExpr>(tryExpr)) {
1177511772
if (auto thrownError = optTry->getThrownError())
1177611773
return thrownError;
1177711774

1177811775
// If we haven't computed the error type yet, return 'any Error'.
11779-
return dc->getASTContext().getErrorExistentialType();
11776+
return ctx.getErrorExistentialType();
1178011777
}
1178111778

1178211779
llvm_unreachable("Unhandled catch node kind");
1178311780
}
1178411781

11782+
Type CatchNode::getExplicitCaughtType(ASTContext &ctx) const {
11783+
return evaluateOrDefault(
11784+
ctx.evaluator, ExplicitCaughtTypeRequest{&ctx, *this}, Type());
11785+
}
11786+
1178511787
void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) {
1178611788
out << "catch node";
1178711789
}
1178811790

11791+
SourceLoc swift::extractNearestSourceLoc(CatchNode catchNode) {
11792+
if (auto func = catchNode.dyn_cast<AbstractFunctionDecl *>())
11793+
return func->getLoc();
11794+
if (auto closure = catchNode.dyn_cast<ClosureExpr *>())
11795+
return closure->getLoc();
11796+
if (auto doCatch = catchNode.dyn_cast<DoCatchStmt *>())
11797+
return doCatch->getDoLoc();
11798+
if (auto tryExpr = catchNode.dyn_cast<AnyTryExpr *>())
11799+
return tryExpr->getTryLoc();
11800+
llvm_unreachable("Unhandled catch node");
11801+
}
11802+
1178911803
//----------------------------------------------------------------------------//
1179011804
// ExplicitCaughtTypeRequest computation.
1179111805
//----------------------------------------------------------------------------//

lib/AST/Expr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2054,7 +2054,7 @@ Type ClosureExpr::getExplicitThrownType() const {
20542054

20552055
ASTContext &ctx = getASTContext();
20562056
auto mutableThis = const_cast<ClosureExpr *>(this);
2057-
ExplicitCaughtTypeRequest request{mutableThis, mutableThis};
2057+
ExplicitCaughtTypeRequest request{&ctx, mutableThis};
20582058
return evaluateOrDefault(ctx.evaluator, request, Type());
20592059
}
20602060

lib/AST/Stmt.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,16 +452,18 @@ Expr *ForEachStmt::getTypeCheckedSequence() const {
452452
return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
453453
}
454454

455-
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
455+
DoCatchStmt *DoCatchStmt::create(DeclContext *dc,
456+
LabeledStmtInfo labelInfo,
456457
SourceLoc doLoc,
457458
SourceLoc throwsLoc, TypeLoc thrownType,
458459
Stmt *body,
459460
ArrayRef<CaseStmt *> catches,
460461
llvm::Optional<bool> implicit) {
462+
ASTContext &ctx = dc->getASTContext();
461463
void *mem = ctx.Allocate(totalSizeToAlloc<CaseStmt *>(catches.size()),
462464
alignof(DoCatchStmt));
463-
return ::new (mem) DoCatchStmt(labelInfo, doLoc, throwsLoc, thrownType, body,
464-
catches, implicit);
465+
return ::new (mem) DoCatchStmt(dc, labelInfo, doLoc, throwsLoc, thrownType,
466+
body, catches, implicit);
465467
}
466468

467469
bool CaseLabelItem::isSyntacticallyExhaustive() const {
@@ -478,15 +480,14 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
478480
return false;
479481
}
480482

481-
Type DoCatchStmt::getExplicitCaughtType(DeclContext *dc) const {
482-
ASTContext &ctx = dc->getASTContext();
483-
ExplicitCaughtTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
484-
return evaluateOrDefault(ctx.evaluator, request, Type());
483+
Type DoCatchStmt::getExplicitCaughtType() const {
484+
ASTContext &ctx = DC->getASTContext();
485+
return CatchNode(const_cast<DoCatchStmt *>(this)).getExplicitCaughtType(ctx);
485486
}
486487

487-
Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
488+
Type DoCatchStmt::getCaughtErrorType() const {
488489
// Check for an explicitly-specified error type.
489-
if (Type explicitError = getExplicitCaughtType(dc))
490+
if (Type explicitError = getExplicitCaughtType())
490491
return explicitError;
491492

492493
auto firstPattern = getCatches()

lib/Parse/ParseStmt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,8 +2256,8 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22562256
}
22572257

22582258
return makeParserResult(status,
2259-
DoCatchStmt::create(Context, labelInfo, doLoc, throwsLoc, thrownType,
2260-
body.get(), allClauses));
2259+
DoCatchStmt::create(CurDeclContext, labelInfo, doLoc, throwsLoc,
2260+
thrownType, body.get(), allClauses));
22612261
}
22622262

22632263
if (throwsLoc.isValid()) {

0 commit comments

Comments
 (0)