Skip to content

Commit 3676568

Browse files
authored
Merge pull request swiftlang#26844 from CodaFi/extension-intervention
Requestify Extension Type Validation
2 parents 3304457 + 672cc84 commit 3676568

File tree

17 files changed

+88
-50
lines changed

17 files changed

+88
-50
lines changed

include/swift/AST/Decl.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,7 +1671,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
16711671
SourceRange Braces;
16721672

16731673
/// The type being extended.
1674-
TypeLoc ExtendedType;
1674+
TypeRepr *ExtendedTypeRepr;
16751675

16761676
/// The nominal type being extended.
16771677
NominalTypeDecl *ExtendedNominal = nullptr;
@@ -1694,7 +1694,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
16941694
friend class ConformanceLookupTable;
16951695
friend class IterableDeclContext;
16961696

1697-
ExtensionDecl(SourceLoc extensionLoc, TypeLoc extendedType,
1697+
ExtensionDecl(SourceLoc extensionLoc, TypeRepr *extendedType,
16981698
MutableArrayRef<TypeLoc> inherited,
16991699
DeclContext *parent,
17001700
TrailingWhereClause *trailingWhereClause);
@@ -1718,7 +1718,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
17181718

17191719
/// Create a new extension declaration.
17201720
static ExtensionDecl *create(ASTContext &ctx, SourceLoc extensionLoc,
1721-
TypeLoc extendedType,
1721+
TypeRepr *extendedType,
17221722
MutableArrayRef<TypeLoc> inherited,
17231723
DeclContext *parent,
17241724
TrailingWhereClause *trailingWhereClause,
@@ -1738,7 +1738,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
17381738
/// Only use this entry point when the complete type, as spelled in the source,
17391739
/// is required. For most clients, \c getExtendedNominal(), which provides
17401740
/// only the \c NominalTypeDecl, will suffice.
1741-
Type getExtendedType() const { return ExtendedType.getType(); }
1741+
Type getExtendedType() const;
17421742

17431743
/// Retrieve the nominal type declaration that is being extended.
17441744
NominalTypeDecl *getExtendedNominal() const;
@@ -1747,12 +1747,9 @@ class ExtensionDecl final : public GenericContext, public Decl,
17471747
/// type declaration.
17481748
bool alreadyBoundToNominal() const { return NextExtension.getInt(); }
17491749

1750-
/// Retrieve the extended type location.
1751-
TypeLoc &getExtendedTypeLoc() { return ExtendedType; }
1752-
1753-
/// Retrieve the extended type location.
1754-
const TypeLoc &getExtendedTypeLoc() const { return ExtendedType; }
1755-
1750+
/// Retrieve the extended type definition as written in the source, if it exists.
1751+
TypeRepr *getExtendedTypeRepr() const { return ExtendedTypeRepr; }
1752+
17561753
/// Retrieve the set of protocols that this type inherits (i.e,
17571754
/// explicitly conforms to).
17581755
MutableArrayRef<TypeLoc> getInherited() { return Inherited; }

include/swift/AST/TypeCheckRequests.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,24 @@ class AbstractGenericSignatureRequest :
11041104
}
11051105
};
11061106

1107+
class ExtendedTypeRequest
1108+
: public SimpleRequest<ExtendedTypeRequest,
1109+
Type(ExtensionDecl *),
1110+
CacheKind::Cached> {
1111+
public:
1112+
using SimpleRequest::SimpleRequest;
1113+
1114+
private:
1115+
friend SimpleRequest;
1116+
1117+
// Evaluation.
1118+
llvm::Expected<Type> evaluate(Evaluator &eval, ExtensionDecl *) const;
1119+
1120+
public:
1121+
// Caching.
1122+
bool isCached() const { return true; }
1123+
};
1124+
11071125
// Allow AnyValue to compare two Type values, even though Type doesn't
11081126
// support ==.
11091127
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,4 @@ SWIFT_TYPEID(EmittedMembersRequest)
5858
SWIFT_TYPEID(IsImplicitlyUnwrappedOptionalRequest)
5959
SWIFT_TYPEID(ClassAncestryFlagsRequest)
6060
SWIFT_TYPEID(AbstractGenericSignatureRequest)
61+
SWIFT_TYPEID(ExtendedTypeRequest)

lib/AST/ASTPrinter.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,9 +2107,15 @@ void PrintAST::printExtension(ExtensionDecl *decl) {
21072107
recordDeclLoc(decl, [&]{
21082108
// We cannot extend sugared types.
21092109
Type extendedType = decl->getExtendedType();
2110-
if (!extendedType || !extendedType->getAnyNominal()) {
2110+
if (!extendedType) {
21112111
// Fallback to TypeRepr.
2112-
printTypeLoc(decl->getExtendedTypeLoc());
2112+
printTypeLoc(decl->getExtendedTypeRepr());
2113+
return;
2114+
}
2115+
if (!extendedType->getAnyNominal()) {
2116+
// Fallback to the type. This usually means we're trying to print an
2117+
// UnboundGenericType.
2118+
printTypeLoc(TypeLoc::withoutLoc(extendedType));
21132119
return;
21142120
}
21152121
printExtendedTypeName(extendedType, Printer, Options);

lib/AST/ASTWalker.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
131131
}
132132

133133
bool visitExtensionDecl(ExtensionDecl *ED) {
134-
if (doIt(ED->getExtendedTypeLoc()))
135-
return true;
134+
if (auto *typeRepr = ED->getExtendedTypeRepr())
135+
if (doIt(typeRepr))
136+
return true;
136137
for (auto &Inherit : ED->getInherited()) {
137138
if (doIt(Inherit))
138139
return true;

lib/AST/Decl.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,15 +1053,15 @@ NominalTypeDecl::takeConformanceLoaderSlow() {
10531053
}
10541054

10551055
ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc,
1056-
TypeLoc extendedType,
1056+
TypeRepr *extendedType,
10571057
MutableArrayRef<TypeLoc> inherited,
10581058
DeclContext *parent,
10591059
TrailingWhereClause *trailingWhereClause)
10601060
: GenericContext(DeclContextKind::ExtensionDecl, parent),
10611061
Decl(DeclKind::Extension, parent),
10621062
IterableDeclContext(IterableDeclContextKind::ExtensionDecl),
10631063
ExtensionLoc(extensionLoc),
1064-
ExtendedType(extendedType),
1064+
ExtendedTypeRepr(extendedType),
10651065
Inherited(inherited)
10661066
{
10671067
Bits.ExtensionDecl.DefaultAndMaxAccessLevel = 0;
@@ -1070,7 +1070,7 @@ ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc,
10701070
}
10711071

10721072
ExtensionDecl *ExtensionDecl::create(ASTContext &ctx, SourceLoc extensionLoc,
1073-
TypeLoc extendedType,
1073+
TypeRepr *extendedType,
10741074
MutableArrayRef<TypeLoc> inherited,
10751075
DeclContext *parent,
10761076
TrailingWhereClause *trailingWhereClause,
@@ -1151,6 +1151,13 @@ AccessLevel ExtensionDecl::getMaxAccessLevel() const {
11511151
DefaultAndMaxAccessLevelRequest{const_cast<ExtensionDecl *>(this)},
11521152
{AccessLevel::Private, AccessLevel::Private}).second;
11531153
}
1154+
1155+
Type ExtensionDecl::getExtendedType() const {
1156+
ASTContext &ctx = getASTContext();
1157+
return evaluateOrDefault(ctx.evaluator,
1158+
ExtendedTypeRequest{const_cast<ExtensionDecl *>(this)},
1159+
ErrorType::get(ctx));
1160+
}
11541161

11551162
/// Clone the given generic parameters in the given list. We don't need any
11561163
/// of the requirements, because they will be inferred.
@@ -7622,7 +7629,7 @@ void swift::simple_display(llvm::raw_ostream &out, const Decl *decl) {
76227629
simple_display(out, value);
76237630
} else if (auto ext = dyn_cast<ExtensionDecl>(decl)) {
76247631
out << "extension of ";
7625-
if (auto typeRepr = ext->getExtendedTypeLoc().getTypeRepr())
7632+
if (auto typeRepr = ext->getExtendedTypeRepr())
76267633
typeRepr->print(out);
76277634
else
76287635
ext->getSelfNominalTypeDecl()->dumpRef(out);

lib/AST/NameLookup.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,10 +2128,9 @@ ExtendedNominalRequest::evaluate(Evaluator &evaluator,
21282128
ASTContext &ctx = ext->getASTContext();
21292129

21302130
// Prefer syntactic information when we have it.
2131-
TypeLoc &typeLoc = ext->getExtendedTypeLoc();
2132-
if (auto typeRepr = typeLoc.getTypeRepr()) {
2131+
if (auto typeRepr = ext->getExtendedTypeRepr()) {
21332132
referenced = directReferencesForTypeRepr(evaluator, ctx, typeRepr, ext);
2134-
} else if (auto type = typeLoc.getType()) {
2133+
} else if (auto type = ext->getExtendedType()) {
21352134
// Fall back to semantic types.
21362135
// FIXME: In the long run, we shouldn't need this. Non-syntactic results
21372136
// should be cached.

lib/ClangImporter/ImportDecl.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4545,9 +4545,13 @@ namespace {
45454545
auto loc = Impl.importSourceLoc(decl->getBeginLoc());
45464546
auto result = ExtensionDecl::create(
45474547
Impl.SwiftContext, loc,
4548-
TypeLoc::withoutLoc(objcClass->getDeclaredType()),
4548+
nullptr,
45494549
{ }, dc, nullptr, decl);
4550-
4550+
Impl.SwiftContext
4551+
.evaluator
4552+
.cacheOutput(ExtendedTypeRequest{result},
4553+
objcClass->getDeclaredType());
4554+
45514555
// Determine the type and generic args of the extension.
45524556
if (objcClass->getGenericParams()) {
45534557
result->createGenericParamsIfMissing(objcClass);
@@ -8143,9 +8147,10 @@ ClangImporter::Implementation::importDeclContextOf(
81438147
return knownExtension->second;
81448148

81458149
// Create a new extension for this nominal type/Clang submodule pair.
8146-
auto swiftTyLoc = TypeLoc::withoutLoc(nominal->getDeclaredType());
8147-
auto ext = ExtensionDecl::create(SwiftContext, SourceLoc(), swiftTyLoc, {},
8150+
auto ext = ExtensionDecl::create(SwiftContext, SourceLoc(), nullptr, {},
81488151
getClangModuleForDecl(decl), nullptr);
8152+
SwiftContext.evaluator.cacheOutput(ExtendedTypeRequest{ext},
8153+
nominal->getDeclaredType());
81498154
ext->setValidationToChecked();
81508155
ext->setMemberLoader(this, reinterpret_cast<uintptr_t>(declSubmodule));
81518156

lib/IDE/SourceEntityWalker.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ bool SemaAnnotator::walkToDeclPre(Decl *D) {
137137
return false;
138138
}
139139
} else if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
140-
SourceRange SR = ED->getExtendedTypeLoc().getSourceRange();
140+
SourceRange SR = SourceRange();
141+
if (auto *repr = ED->getExtendedTypeRepr())
142+
SR = repr->getSourceRange();
141143
Loc = SR.Start;
142144
if (Loc.isValid())
143145
NameLen = ED->getASTContext().SourceMgr.getByteDistance(SR.Start, SR.End);
@@ -645,7 +647,9 @@ passReference(ValueDecl *D, Type Ty, SourceLoc BaseNameLoc, SourceRange Range,
645647
}
646648

647649
if (!ExtDecls.empty() && BaseNameLoc.isValid()) {
648-
auto ExtTyLoc = ExtDecls.back()->getExtendedTypeLoc().getLoc();
650+
SourceLoc ExtTyLoc = SourceLoc();
651+
if (auto *repr = ExtDecls.back()->getExtendedTypeRepr())
652+
ExtTyLoc = repr->getLoc();
649653
if (ExtTyLoc.isValid() && ExtTyLoc == BaseNameLoc) {
650654
ExtDecl = ExtDecls.back();
651655
}

lib/IDE/SyntaxModel.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,9 @@ bool ModelASTWalker::walkToDeclPre(Decl *D) {
781781
SN.Kind = SyntaxStructureKind::Extension;
782782
SN.Range = charSourceRangeFromSourceRange(SM, ED->getSourceRange());
783783
SN.BodyRange = innerCharSourceRangeFromSourceRange(SM, ED->getBraces());
784-
SourceRange NSR = ED->getExtendedTypeLoc().getSourceRange();
784+
SourceRange NSR = SourceRange();
785+
if (auto *repr = ED->getExtendedTypeRepr())
786+
NSR = repr->getSourceRange();
785787
SN.NameRange = charSourceRangeFromSourceRange(SM, NSR);
786788

787789
for (const TypeLoc &TL : ED->getInherited()) {

0 commit comments

Comments
 (0)