Skip to content

Commit e1be9c3

Browse files
committed
Eliminate the DeclContext from ExplicitCaughtTypeRequest
Correctly determining the DeclContext needed for an ExplicitCaughtTypeRequest is tricky for a number of callers, and mistakes here can easily lead to redundant computation of the caught type, redundant diagnostics, etc. Instead, put a `DeclContext` into `DoCatchStmt`, because that's the only catch node that needs a `DeclContext` but does not have one.
1 parent 3f518a3 commit e1be9c3

14 files changed

+78
-50
lines changed

include/swift/AST/CatchNode.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ class CatchNode: public llvm::PointerUnion<
3636
///
3737
/// Returns the thrown error type for a throwing context, or \c llvm::None
3838
/// if this is a non-throwing context.
39-
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;
39+
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
40+
41+
/// Determines the explicitly-specified type error that will be caught by
42+
/// this catch node.
43+
///
44+
/// Returns the explicitly-caught type, or a NULL type if the caught type
45+
/// needs to be inferred.
46+
Type getExplicitCaughtType(ASTContext &ctx) const;
4047

4148
friend llvm::hash_code hash_value(CatchNode catchNode) {
4249
using llvm::hash_value;
@@ -46,6 +53,8 @@ class CatchNode: public llvm::PointerUnion<
4653

4754
void simple_display(llvm::raw_ostream &out, CatchNode catchNode);
4855

56+
SourceLoc extractNearestSourceLoc(CatchNode catchNode);
57+
4958
} // end namespace swift
5059

5160
#endif // SWIFT_AST_CATCHNODE_H

include/swift/AST/Stmt.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,8 @@ class DoCatchStmt final
13891389
friend TrailingObjects;
13901390
friend class ExplicitCaughtTypeRequest;
13911391

1392+
DeclContext *DC;
1393+
13921394
SourceLoc DoLoc;
13931395

13941396
/// Location of the 'throws' token.
@@ -1400,12 +1402,13 @@ class DoCatchStmt final
14001402
Stmt *Body;
14011403
ThrownErrorDestination RethrowDest;
14021404

1403-
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc,
1405+
DoCatchStmt(DeclContext *dc,LabeledStmtInfo labelInfo, SourceLoc doLoc,
14041406
SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body,
14051407
ArrayRef<CaseStmt *> catches, llvm::Optional<bool> implicit)
14061408
: LabeledStmt(StmtKind::DoCatch, getDefaultImplicitFlag(implicit, doLoc),
14071409
labelInfo),
1408-
DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType), Body(body) {
1410+
DC(dc), DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType),
1411+
Body(body) {
14091412
Bits.DoCatchStmt.NumCatches = catches.size();
14101413
std::uninitialized_copy(catches.begin(), catches.end(),
14111414
getTrailingObjects<CaseStmt *>());
@@ -1414,13 +1417,16 @@ class DoCatchStmt final
14141417
}
14151418

14161419
public:
1417-
static DoCatchStmt *create(ASTContext &ctx, LabeledStmtInfo labelInfo,
1418-
SourceLoc doLoc,
1420+
static DoCatchStmt *create(DeclContext *dc,
1421+
LabeledStmtInfo labelInfo,
1422+
SourceLoc doLoc,
14191423
SourceLoc throwsLoc, TypeLoc thrownType,
14201424
Stmt *body,
14211425
ArrayRef<CaseStmt *> catches,
14221426
llvm::Optional<bool> implicit = llvm::None);
14231427

1428+
DeclContext *getDeclContext() const { return DC; }
1429+
14241430
SourceLoc getDoLoc() const { return DoLoc; }
14251431

14261432
/// Retrieve the location of the 'throws' keyword, if present.
@@ -1435,7 +1441,7 @@ class DoCatchStmt final
14351441
}
14361442

14371443
// Get the explicitly-specified caught error type.
1438-
Type getExplicitCaughtType(DeclContext *dc) const;
1444+
Type getExplicitCaughtType() const;
14391445

14401446
Stmt *getBody() const { return Body; }
14411447
void setBody(Stmt *s) { Body = s; }
@@ -1461,7 +1467,7 @@ class DoCatchStmt final
14611467
// and caught by the various 'catch' clauses. If this the catch clauses
14621468
// aren't exhausive, this is also the type of the error that is implicitly
14631469
// rethrown.
1464-
Type getCaughtErrorType(DeclContext *dc) const;
1470+
Type getCaughtErrorType() const;
14651471

14661472
/// Retrieves the rethrown error and its conversion to the error type
14671473
/// expected by the enclosing context.

include/swift/AST/TypeCheckRequests.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,7 @@ class ParamSpecifierRequest
22942294
/// requires type inference.
22952295
class ExplicitCaughtTypeRequest
22962296
: public SimpleRequest<ExplicitCaughtTypeRequest,
2297-
Type(DeclContext *, CatchNode),
2297+
Type(ASTContext *, CatchNode),
22982298
RequestFlags::SeparatelyCached> {
22992299
public:
23002300
using SimpleRequest::SimpleRequest;
@@ -2303,7 +2303,7 @@ class ExplicitCaughtTypeRequest
23032303
friend SimpleRequest;
23042304

23052305
// Evaluation.
2306-
Type evaluate(Evaluator &evaluator, DeclContext *dc, CatchNode catchNode) const;
2306+
Type evaluate(Evaluator &evaluator, ASTContext *, CatchNode catchNode) const;
23072307

23082308
public:
23092309
// Separate caching.

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ SWIFT_REQUEST(TypeChecker, NeedsNewVTableEntryRequest,
360360
SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
361361
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
362362
SWIFT_REQUEST(TypeChecker, ExplicitCaughtTypeRequest,
363-
Type(DeclContext *, CatchNode), SeparatelyCached, NoLocationInfo)
363+
Type(ASTContext *, CatchNode), SeparatelyCached, NoLocationInfo)
364364
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
365365
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
366366
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,

lib/AST/ASTVerifier.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,9 +1046,8 @@ class Verifier : public ASTWalker {
10461046
SourceLoc loc = S->getThrowLoc();
10471047
if (loc.isValid()) {
10481048
auto catchNode = ASTScope::lookupCatchNode(getModuleContext(), loc);
1049-
DeclContext *dc = getInnermostDC();
1050-
if (catchNode && dc) {
1051-
if (auto thrown = catchNode.getThrownErrorTypeInContext(dc)) {
1049+
if (catchNode) {
1050+
if (auto thrown = catchNode.getThrownErrorTypeInContext(Ctx)) {
10521051
thrownError = *thrown;
10531052
} else {
10541053
thrownError = Ctx.getNeverType();

lib/AST/Decl.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -983,10 +983,7 @@ Type AbstractFunctionDecl::getThrownInterfaceType() const {
983983
return ThrownType.getType();
984984

985985
auto mutableThis = const_cast<AbstractFunctionDecl *>(this);
986-
return evaluateOrDefault(
987-
getASTContext().evaluator,
988-
ExplicitCaughtTypeRequest{mutableThis, mutableThis},
989-
Type());
986+
return CatchNode(mutableThis).getExplicitCaughtType(getASTContext());
990987
}
991988

992989
llvm::Optional<Type>
@@ -11710,7 +11707,7 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1171011707
}
1171111708

1171211709
llvm::Optional<Type>
11713-
CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
11710+
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
1171411711
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
1171511712
if (auto thrownError = func->getEffectiveThrownErrorType())
1171611713
return func->mapTypeIntoContext(*thrownError);
@@ -11733,15 +11730,15 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1173311730
}
1173411731

1173511732
if (auto doCatch = dyn_cast<DoCatchStmt *>()) {
11736-
if (auto thrownError = doCatch->getCaughtErrorType(dc)) {
11733+
if (auto thrownError = doCatch->getCaughtErrorType()) {
1173711734
if (thrownError->isNever())
1173811735
return llvm::None;
1173911736

1174011737
return thrownError;
1174111738
}
1174211739

1174311740
// If we haven't computed the error type yet, return 'any Error'.
11744-
return dc->getASTContext().getErrorExistentialType();
11741+
return ctx.getErrorExistentialType();
1174511742
}
1174611743

1174711744
auto tryExpr = get<AnyTryExpr *>();
@@ -11750,24 +11747,41 @@ CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1175011747
return thrownError;
1175111748

1175211749
// If we haven't computed the error type yet, return 'any Error'.
11753-
return dc->getASTContext().getErrorExistentialType();
11750+
return ctx.getErrorExistentialType();
1175411751
}
1175511752

1175611753
if (auto optTry = llvm::dyn_cast<OptionalTryExpr>(tryExpr)) {
1175711754
if (auto thrownError = optTry->getThrownError())
1175811755
return thrownError;
1175911756

1176011757
// If we haven't computed the error type yet, return 'any Error'.
11761-
return dc->getASTContext().getErrorExistentialType();
11758+
return ctx.getErrorExistentialType();
1176211759
}
1176311760

1176411761
llvm_unreachable("Unhandled catch node kind");
1176511762
}
1176611763

11764+
Type CatchNode::getExplicitCaughtType(ASTContext &ctx) const {
11765+
return evaluateOrDefault(
11766+
ctx.evaluator, ExplicitCaughtTypeRequest{&ctx, *this}, Type());
11767+
}
11768+
1176711769
void swift::simple_display(llvm::raw_ostream &out, CatchNode catchNode) {
1176811770
out << "catch node";
1176911771
}
1177011772

11773+
SourceLoc swift::extractNearestSourceLoc(CatchNode catchNode) {
11774+
if (auto func = catchNode.dyn_cast<AbstractFunctionDecl *>())
11775+
return func->getLoc();
11776+
if (auto closure = catchNode.dyn_cast<ClosureExpr *>())
11777+
return closure->getLoc();
11778+
if (auto doCatch = catchNode.dyn_cast<DoCatchStmt *>())
11779+
return doCatch->getDoLoc();
11780+
if (auto tryExpr = catchNode.dyn_cast<AnyTryExpr *>())
11781+
return tryExpr->getTryLoc();
11782+
llvm_unreachable("Unhandled catch node");
11783+
}
11784+
1177111785
//----------------------------------------------------------------------------//
1177211786
// ExplicitCaughtTypeRequest computation.
1177311787
//----------------------------------------------------------------------------//

lib/AST/Expr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2054,7 +2054,7 @@ Type ClosureExpr::getExplicitThrownType() const {
20542054

20552055
ASTContext &ctx = getASTContext();
20562056
auto mutableThis = const_cast<ClosureExpr *>(this);
2057-
ExplicitCaughtTypeRequest request{mutableThis, mutableThis};
2057+
ExplicitCaughtTypeRequest request{&ctx, mutableThis};
20582058
return evaluateOrDefault(ctx.evaluator, request, Type());
20592059
}
20602060

lib/AST/Stmt.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,16 +452,18 @@ Expr *ForEachStmt::getTypeCheckedSequence() const {
452452
return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
453453
}
454454

455-
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
455+
DoCatchStmt *DoCatchStmt::create(DeclContext *dc,
456+
LabeledStmtInfo labelInfo,
456457
SourceLoc doLoc,
457458
SourceLoc throwsLoc, TypeLoc thrownType,
458459
Stmt *body,
459460
ArrayRef<CaseStmt *> catches,
460461
llvm::Optional<bool> implicit) {
462+
ASTContext &ctx = dc->getASTContext();
461463
void *mem = ctx.Allocate(totalSizeToAlloc<CaseStmt *>(catches.size()),
462464
alignof(DoCatchStmt));
463-
return ::new (mem) DoCatchStmt(labelInfo, doLoc, throwsLoc, thrownType, body,
464-
catches, implicit);
465+
return ::new (mem) DoCatchStmt(dc, labelInfo, doLoc, throwsLoc, thrownType,
466+
body, catches, implicit);
465467
}
466468

467469
bool CaseLabelItem::isSyntacticallyExhaustive() const {
@@ -478,15 +480,14 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
478480
return false;
479481
}
480482

481-
Type DoCatchStmt::getExplicitCaughtType(DeclContext *dc) const {
482-
ASTContext &ctx = dc->getASTContext();
483-
ExplicitCaughtTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
484-
return evaluateOrDefault(ctx.evaluator, request, Type());
483+
Type DoCatchStmt::getExplicitCaughtType() const {
484+
ASTContext &ctx = DC->getASTContext();
485+
return CatchNode(const_cast<DoCatchStmt *>(this)).getExplicitCaughtType(ctx);
485486
}
486487

487-
Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
488+
Type DoCatchStmt::getCaughtErrorType() const {
488489
// Check for an explicitly-specified error type.
489-
if (Type explicitError = getExplicitCaughtType(dc))
490+
if (Type explicitError = getExplicitCaughtType())
490491
return explicitError;
491492

492493
auto firstPattern = getCatches()

lib/Parse/ParseStmt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,8 +2255,8 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22552255
}
22562256

22572257
return makeParserResult(status,
2258-
DoCatchStmt::create(Context, labelInfo, doLoc, throwsLoc, thrownType,
2259-
body.get(), allClauses));
2258+
DoCatchStmt::create(CurDeclContext, labelInfo, doLoc, throwsLoc,
2259+
thrownType, body.get(), allClauses));
22602260
}
22612261

22622262
if (throwsLoc.isValid()) {

lib/SILGen/SILGenStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ void StmtEmitter::visitDoStmt(DoStmt *S) {
11171117
}
11181118

11191119
void StmtEmitter::visitDoCatchStmt(DoCatchStmt *S) {
1120-
Type formalExnType = S->getCaughtErrorType(SGF.FunctionDC);
1120+
Type formalExnType = S->getCaughtErrorType();
11211121
auto &exnTL = SGF.getTypeLowering(formalExnType);
11221122

11231123
SILValue exnArg;

0 commit comments

Comments
 (0)