Skip to content

Commit f667c43

Browse files
committed
[NFC] Extract autodiff parsing code to Decl.cpp
The `@differentiable` and `@derivative` attributes need a parent pointer. Move the code to populate it from Parser to AST so it can be more easily shared between the parsers. Done in preparation for similar code to be added for `@abi`.
1 parent 52693c4 commit f667c43

File tree

6 files changed

+57
-55
lines changed

6 files changed

+57
-55
lines changed

include/swift/AST/Attr.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,12 @@ template <typename ATTR, bool AllowInvalid> struct ToAttributeKind {
28702870
return cast<ATTR>(Attr);
28712871
return std::nullopt;
28722872
}
2873+
2874+
std::optional<ATTR *> operator()(DeclAttribute *Attr) {
2875+
if (isa<ATTR>(Attr) && (Attr->isValid() || AllowInvalid))
2876+
return cast<ATTR>(Attr);
2877+
return std::nullopt;
2878+
}
28732879
};
28742880

28752881
/// The @_allowFeatureSuppression(Foo, Bar) attribute. The feature
@@ -3042,17 +3048,25 @@ class DeclAttributes {
30423048
}
30433049

30443050
public:
3045-
template <typename ATTR, bool AllowInvalid>
3051+
template <typename ATTR, typename Iterator, bool AllowInvalid>
30463052
using AttributeKindRange =
3047-
OptionalTransformRange<iterator_range<const_iterator>,
3053+
OptionalTransformRange<iterator_range<Iterator>,
30483054
ToAttributeKind<ATTR, AllowInvalid>,
3049-
const_iterator>;
3055+
Iterator>;
3056+
3057+
/// Return a range with all attributes in DeclAttributes with AttrKind
3058+
/// ATTR.
3059+
template <typename ATTR, bool AllowInvalid = false>
3060+
AttributeKindRange<ATTR, const_iterator, AllowInvalid> getAttributes() const {
3061+
return AttributeKindRange<ATTR, const_iterator, AllowInvalid>(
3062+
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
3063+
}
30503064

30513065
/// Return a range with all attributes in DeclAttributes with AttrKind
30523066
/// ATTR.
30533067
template <typename ATTR, bool AllowInvalid = false>
3054-
AttributeKindRange<ATTR, AllowInvalid> getAttributes() const {
3055-
return AttributeKindRange<ATTR, AllowInvalid>(
3068+
AttributeKindRange<ATTR, iterator, AllowInvalid> getAttributes() {
3069+
return AttributeKindRange<ATTR, iterator, AllowInvalid>(
30563070
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
30573071
}
30583072

include/swift/AST/Decl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,11 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl>, public Swi
10211021
/// attribute macro expansion.
10221022
DeclAttributes getSemanticAttrs() const;
10231023

1024+
/// Set this declaration's attributes to the specified attribute list,
1025+
/// applying any post-processing logic appropriate for attributes parsed
1026+
/// from source code.
1027+
void attachParsedAttrs(DeclAttributes attrs);
1028+
10241029
/// True if this declaration provides an implementation for an imported
10251030
/// Objective-C declaration. This implies various restrictions and special
10261031
/// behaviors for it and, if it's an extension, its members.

lib/AST/Decl.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,17 @@ DeclAttributes Decl::getSemanticAttrs() const {
407407
return getAttrs();
408408
}
409409

410+
void Decl::attachParsedAttrs(DeclAttributes attrs) {
411+
ASSERT(getAttrs().isEmpty() && "attaching when there are already attrs?");
412+
413+
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
414+
attr->setOriginalDeclaration(this);
415+
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
416+
attr->setOriginalDeclaration(this);
417+
418+
getAttrs() = attrs;
419+
}
420+
410421
void Decl::visitAuxiliaryDecls(
411422
AuxiliaryDeclCallback callback,
412423
bool visitFreestandingExpanded

lib/Parse/ParseDecl.cpp

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6000,20 +6000,6 @@ void Parser::consumeDecl(ParserPosition BeginParserPosition, bool IsTopLevel) {
60006000
}
60016001
}
60026002

6003-
/// Set the original declaration in `@differentiable` attributes.
6004-
///
6005-
/// Necessary because `Parser::parseNewDeclAttribute` (which calls
6006-
/// `Parser::parseDifferentiableAttribute`) does not have access to the
6007-
/// parent declaration of parsed attributes.
6008-
static void
6009-
setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
6010-
Decl *D) {
6011-
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
6012-
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
6013-
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
6014-
const_cast<DerivativeAttr *>(attr)->setOriginalDeclaration(D);
6015-
}
6016-
60176003
/// Determine the declaration parsing options to use when parsing the decl in
60186004
/// the given context.
60196005
static Parser::ParseDeclOptions getParseDeclOptions(DeclContext *DC) {
@@ -6439,7 +6425,6 @@ ParserStatus Parser::parseDecl(bool IsAtStartOfLineOrPreviousHadSemi,
64396425
Decl *D = DeclResult.get();
64406426
if (!HandlerAlreadyCalled)
64416427
Handler(D);
6442-
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
64436428
}
64446429

64456430
if (!DeclResult.isParseErrorOrHasCompletion()) {
@@ -6635,7 +6620,7 @@ ParserResult<ImportDecl> Parser::parseDeclImport(ParseDeclOptions Flags,
66356620

66366621
auto *ID = ImportDecl::create(Context, CurDeclContext, ImportLoc, Kind,
66376622
KindLoc, importPath.get());
6638-
ID->getAttrs() = Attributes;
6623+
ID->attachParsedAttrs(Attributes);
66396624
return DCC.fixupParserResult(ID);
66406625
}
66416626

@@ -7079,7 +7064,7 @@ Parser::parseDeclExtension(ParseDeclOptions Flags, DeclAttributes &Attributes) {
70797064
Context.AllocateCopy(Inherited),
70807065
CurDeclContext,
70817066
trailingWhereClause);
7082-
ext->getAttrs() = Attributes;
7067+
ext->attachParsedAttrs(Attributes);
70837068
if (trailingWhereHadCodeCompletion && CodeCompletionCallbacks)
70847069
CodeCompletionCallbacks->setParsedDecl(ext);
70857070

@@ -7414,7 +7399,7 @@ parseDeclTypeAlias(Parser::ParseDeclOptions Flags, DeclAttributes &Attributes) {
74147399
}
74157400

74167401
TAD->setUnderlyingTypeRepr(UnderlyingTy.getPtrOrNull());
7417-
TAD->getAttrs() = Attributes;
7402+
TAD->attachParsedAttrs(Attributes);
74187403

74197404
// Parse a 'where' clause if present.
74207405
if (Tok.is(tok::kw_where)) {
@@ -7531,7 +7516,7 @@ ParserResult<TypeDecl> Parser::parseDeclAssociatedType(Parser::ParseDeclOptions
75317516
auto assocType = AssociatedTypeDecl::createParsed(
75327517
Context, CurDeclContext, AssociatedTypeLoc, Id, IdLoc,
75337518
UnderlyingTy.getPtrOrNull(), TrailingWhere);
7534-
assocType->getAttrs() = Attributes;
7519+
assocType->attachParsedAttrs(Attributes);
75357520
if (!Inherited.empty())
75367521
assocType->setInherited(Context.AllocateCopy(Inherited));
75377522
return makeParserResult(Status, assocType);
@@ -7886,7 +7871,7 @@ bool Parser::parseAccessorAfterIntroducer(
78867871
auto *accessor = AccessorDecl::createParsed(
78877872
Context, Kind, storage, /*declLoc*/ Loc, /*accessorKeywordLoc*/ Loc,
78887873
param, asyncLoc, throwsLoc, thrownTy, CurDeclContext);
7889-
accessor->getAttrs() = Attributes;
7874+
accessor->attachParsedAttrs(Attributes);
78907875

78917876
// Collect this accessor and detect conflicts.
78927877
if (auto existingAccessor = accessors.add(accessor)) {
@@ -8165,7 +8150,7 @@ void Parser::parseExpandedAttributeList(SmallVectorImpl<ASTNode> &items,
81658150
// macro will attach the attribute list to.
81668151
MissingDecl *missing =
81678152
MissingDecl::create(Context, CurDeclContext, Tok.getLoc());
8168-
missing->getAttrs() = attributes;
8153+
missing->attachParsedAttrs(attributes);
81698154

81708155
items.push_back(ASTNode(missing));
81718156
return;
@@ -8295,11 +8280,6 @@ Parser::parseDeclVarGetSet(PatternBindingEntry &entry, ParseDeclOptions Flags,
82958280

82968281
accessors.record(*this, PrimaryVar, Invalid);
82978282

8298-
// Set original declaration in `@differentiable` attributes.
8299-
for (auto *accessor : accessors.Accessors)
8300-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
8301-
accessor);
8302-
83038283
return makeParserResult(AccessorStatus, PrimaryVar);
83048284
}
83058285

@@ -8589,12 +8569,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
85898569
// Configure all vars with attributes, 'static' and parent pattern.
85908570
pattern->forEachVariable([&](VarDecl *VD) {
85918571
VD->setStatic(StaticLoc.isValid());
8592-
VD->getAttrs() = Attributes;
8572+
VD->attachParsedAttrs(Attributes);
85938573
VD->setTopLevelGlobal(topLevelDecl);
85948574

8595-
// Set original declaration in `@differentiable` attributes.
8596-
setOriginalDeclarationForDifferentiableAttributes(Attributes, VD);
8597-
85988575
Decls.push_back(VD);
85998576
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment
86008577
&& !InFreestandingMacroArgument) {
@@ -8921,7 +8898,7 @@ ParserResult<FuncDecl> Parser::parseDeclFunc(SourceLoc StaticLoc,
89218898
diagnoseOperatorFixityAttributes(*this, Attributes, FD);
89228899
// Add the attributes here so if we need them while parsing the body
89238900
// they are available.
8924-
FD->getAttrs() = Attributes;
8901+
FD->attachParsedAttrs(Attributes);
89258902

89268903
// Pass the function signature to code completion.
89278904
if (Status.hasCodeCompletion()) {
@@ -9112,7 +9089,7 @@ ParserResult<EnumDecl> Parser::parseDeclEnum(ParseDeclOptions Flags,
91129089

91139090
EnumDecl *ED = new (Context) EnumDecl(EnumLoc, EnumName, EnumNameLoc,
91149091
{ }, GenericParams, CurDeclContext);
9115-
ED->getAttrs() = Attributes;
9092+
ED->attachParsedAttrs(Attributes);
91169093

91179094
ContextChange CC(*this, ED);
91189095

@@ -9307,7 +9284,7 @@ Parser::parseDeclEnumCase(ParseDeclOptions Flags,
93079284
result->setImplicit(); // Parse error
93089285
}
93099286

9310-
result->getAttrs() = Attributes;
9287+
result->attachParsedAttrs(Attributes);
93119288
Elements.push_back(result);
93129289

93139290
// Continue through the comma-separated list.
@@ -9374,7 +9351,7 @@ ParserResult<StructDecl> Parser::parseDeclStruct(ParseDeclOptions Flags,
93749351
{ },
93759352
GenericParams,
93769353
CurDeclContext);
9377-
SD->getAttrs() = Attributes;
9354+
SD->attachParsedAttrs(Attributes);
93789355

93799356
ContextChange CC(*this, SD);
93809357

@@ -9463,7 +9440,7 @@ ParserResult<ClassDecl> Parser::parseDeclClass(ParseDeclOptions Flags,
94639440
ClassDecl *CD = new (Context) ClassDecl(ClassLoc, ClassName, ClassNameLoc,
94649441
{ }, GenericParams, CurDeclContext,
94659442
isExplicitActorDecl);
9466-
CD->getAttrs() = Attributes;
9443+
CD->attachParsedAttrs(Attributes);
94679444

94689445
// Parsed classes never have missing vtable entries.
94699446
CD->setHasMissingVTableEntries(false);
@@ -9642,7 +9619,7 @@ parseDeclProtocol(ParseDeclOptions Flags, DeclAttributes &Attributes) {
96429619
Context.AllocateCopy(PrimaryAssociatedTypeNames),
96439620
Context.AllocateCopy(InheritedProtocols), TrailingWhere);
96449621

9645-
Proto->getAttrs() = Attributes;
9622+
Proto->attachParsedAttrs(Attributes);
96469623
if (whereClauseHadCodeCompletion && CodeCompletionCallbacks)
96479624
CodeCompletionCallbacks->setParsedDecl(Proto);
96489625

@@ -9769,7 +9746,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
97699746
auto *const Subscript = SubscriptDecl::createParsed(
97709747
Context, StaticLoc, StaticSpelling, SubscriptLoc, Indices.get(), ArrowLoc,
97719748
ElementTy.get(), CurDeclContext, GenericParams);
9772-
Subscript->getAttrs() = Attributes;
9749+
Subscript->attachParsedAttrs(Attributes);
97739750

97749751
// Let the source file track the opaque return type mapping, if any.
97759752
if (ElementTy.get() && ElementTy.get()->hasOpaque() &&
@@ -9830,11 +9807,6 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
98309807
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess() ||
98319808
Status.hasCodeCompletion()));
98329809

9833-
// Set original declaration in `@differentiable` attributes.
9834-
for (auto *accessor : accessors.Accessors)
9835-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
9836-
accessor);
9837-
98389810
// No need to setLocalDiscriminator because subscripts cannot
98399811
// validly appear outside of type decls.
98409812
return makeParserResult(Status, Subscript);
@@ -9944,7 +9916,7 @@ Parser::parseDeclInit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
99449916
thrownTy, BodyParams, GenericParams,
99459917
CurDeclContext, FuncRetTy);
99469918
CD->setImplicitlyUnwrappedOptional(IUO);
9947-
CD->getAttrs() = Attributes;
9919+
CD->attachParsedAttrs(Attributes);
99489920

99499921
// Parse a 'where' clause if present.
99509922
if (Tok.is(tok::kw_where)) {
@@ -10033,7 +10005,7 @@ parseDeclDeinit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
1003310005
auto *DD = new (Context) DestructorDecl(DestructorLoc, CurDeclContext);
1003410006
parseAbstractFunctionBody(DD);
1003510007

10036-
DD->getAttrs() = Attributes;
10008+
DD->attachParsedAttrs(Attributes);
1003710009

1003810010
// Reject 'destructor' functions outside of structs, enums, classes, or
1003910011
// extensions that provide objc implementations.
@@ -10240,7 +10212,7 @@ Parser::parseDeclOperatorImpl(SourceLoc OperatorLoc, Identifier Name,
1024010212

1024110213
diagnoseOperatorFixityAttributes(*this, Attributes, res);
1024210214

10243-
res->getAttrs() = Attributes;
10215+
res->attachParsedAttrs(Attributes);
1024410216
return makeParserResult(res);
1024510217
}
1024610218

@@ -10297,7 +10269,7 @@ Parser::parseDeclPrecedenceGroup(ParseDeclOptions flags,
1029710269
higherThanKeywordLoc, higherThan,
1029810270
lowerThanKeywordLoc, lowerThan,
1029910271
rbraceLoc);
10300-
result->getAttrs() = attributes;
10272+
result->attachParsedAttrs(attributes);
1030110273
return result;
1030210274
};
1030310275
auto createInvalid = [&](bool hasCodeCompletion) {
@@ -10552,7 +10524,7 @@ ParserResult<MacroDecl> Parser::parseDeclMacro(DeclAttributes &attributes) {
1055210524
auto *macro = new (Context) MacroDecl(
1055310525
macroLoc, macroFullName, macroNameLoc, genericParams, parameterList,
1055410526
arrowLoc, resultType, definition, CurDeclContext);
10555-
macro->getAttrs() = attributes;
10527+
macro->attachParsedAttrs(attributes);
1055610528

1055710529
// Parse a 'where' clause if present.
1055810530
if (Tok.is(tok::kw_where)) {
@@ -10588,7 +10560,7 @@ Parser::parseDeclMacroExpansion(ParseDeclOptions flags,
1058810560
auto *med = MacroExpansionDecl::create(
1058910561
CurDeclContext, poundLoc, macroNameRef, macroNameLoc, leftAngleLoc,
1059010562
Context.AllocateCopy(genericArgs), rightAngleLoc, argList);
10591-
med->getAttrs() = attributes;
10563+
med->attachParsedAttrs(attributes);
1059210564

1059310565
return makeParserResult(status, med);
1059410566
}

lib/Parse/ParseGeneric.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Parser::parseGenericParametersBeforeWhere(SourceLoc LAngleLoc,
148148
GenericParams.push_back(Param);
149149

150150
// Attach attributes.
151-
Param->getAttrs() = attributes;
151+
Param->attachParsedAttrs(attributes);
152152

153153
// Parse the comma, if the list continues.
154154
HasComma = consumeIf(tok::comma);

lib/Parse/ParsePattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ mapParsedParameters(Parser &parser,
571571
auto param = ParamDecl::createParsed(
572572
ctx, paramInfo.SpecifierLoc, argNameLoc, argName, paramNameLoc,
573573
paramName, paramInfo.DefaultArg, parser.CurDeclContext);
574-
param->getAttrs() = paramInfo.Attrs;
574+
param->attachParsedAttrs(paramInfo.Attrs);
575575

576576
bool parsingEnumElt
577577
= (paramContext == Parser::ParameterContextKind::EnumElement);

0 commit comments

Comments
 (0)