Skip to content

Commit c9f539e

Browse files
committed
[NFC] Extract autodiff parsing code to Decl.cpp
The `@differentiable` and `@derivative` attributes need a parent pointer. Move the code to populate it from Parser to AST so it can be more easily shared between the parsers. Done in preparation for similar code to be added for `@abi`.
1 parent 73377fe commit c9f539e

File tree

6 files changed

+57
-55
lines changed

6 files changed

+57
-55
lines changed

include/swift/AST/Attr.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2875,6 +2875,12 @@ template <typename ATTR, bool AllowInvalid> struct ToAttributeKind {
28752875
return cast<ATTR>(Attr);
28762876
return std::nullopt;
28772877
}
2878+
2879+
std::optional<ATTR *> operator()(DeclAttribute *Attr) {
2880+
if (isa<ATTR>(Attr) && (Attr->isValid() || AllowInvalid))
2881+
return cast<ATTR>(Attr);
2882+
return std::nullopt;
2883+
}
28782884
};
28792885

28802886
/// The @_allowFeatureSuppression(Foo, Bar) attribute. The feature
@@ -3047,17 +3053,25 @@ class DeclAttributes {
30473053
}
30483054

30493055
public:
3050-
template <typename ATTR, bool AllowInvalid>
3056+
template <typename ATTR, typename Iterator, bool AllowInvalid>
30513057
using AttributeKindRange =
3052-
OptionalTransformRange<iterator_range<const_iterator>,
3058+
OptionalTransformRange<iterator_range<Iterator>,
30533059
ToAttributeKind<ATTR, AllowInvalid>,
3054-
const_iterator>;
3060+
Iterator>;
3061+
3062+
/// Return a range with all attributes in DeclAttributes with AttrKind
3063+
/// ATTR.
3064+
template <typename ATTR, bool AllowInvalid = false>
3065+
AttributeKindRange<ATTR, const_iterator, AllowInvalid> getAttributes() const {
3066+
return AttributeKindRange<ATTR, const_iterator, AllowInvalid>(
3067+
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
3068+
}
30553069

30563070
/// Return a range with all attributes in DeclAttributes with AttrKind
30573071
/// ATTR.
30583072
template <typename ATTR, bool AllowInvalid = false>
3059-
AttributeKindRange<ATTR, AllowInvalid> getAttributes() const {
3060-
return AttributeKindRange<ATTR, AllowInvalid>(
3073+
AttributeKindRange<ATTR, iterator, AllowInvalid> getAttributes() {
3074+
return AttributeKindRange<ATTR, iterator, AllowInvalid>(
30613075
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
30623076
}
30633077

include/swift/AST/Decl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,11 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl>, public Swi
10221022
/// attribute macro expansion.
10231023
DeclAttributes getSemanticAttrs() const;
10241024

1025+
/// Set this declaration's attributes to the specified attribute list,
1026+
/// applying any post-processing logic appropriate for attributes parsed
1027+
/// from source code.
1028+
void attachParsedAttrs(DeclAttributes attrs);
1029+
10251030
/// True if this declaration provides an implementation for an imported
10261031
/// Objective-C declaration. This implies various restrictions and special
10271032
/// behaviors for it and, if it's an extension, its members.

lib/AST/Decl.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,17 @@ DeclAttributes Decl::getSemanticAttrs() const {
407407
return getAttrs();
408408
}
409409

410+
void Decl::attachParsedAttrs(DeclAttributes attrs) {
411+
ASSERT(getAttrs().isEmpty() && "attaching when there are already attrs?");
412+
413+
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
414+
attr->setOriginalDeclaration(this);
415+
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
416+
attr->setOriginalDeclaration(this);
417+
418+
getAttrs() = attrs;
419+
}
420+
410421
void Decl::visitAuxiliaryDecls(
411422
AuxiliaryDeclCallback callback,
412423
bool visitFreestandingExpanded

lib/Parse/ParseDecl.cpp

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6030,20 +6030,6 @@ void Parser::consumeDecl(ParserPosition BeginParserPosition, bool IsTopLevel) {
60306030
}
60316031
}
60326032

6033-
/// Set the original declaration in `@differentiable` attributes.
6034-
///
6035-
/// Necessary because `Parser::parseNewDeclAttribute` (which calls
6036-
/// `Parser::parseDifferentiableAttribute`) does not have access to the
6037-
/// parent declaration of parsed attributes.
6038-
static void
6039-
setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
6040-
Decl *D) {
6041-
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
6042-
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
6043-
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
6044-
const_cast<DerivativeAttr *>(attr)->setOriginalDeclaration(D);
6045-
}
6046-
60476033
/// Determine the declaration parsing options to use when parsing the decl in
60486034
/// the given context.
60496035
static Parser::ParseDeclOptions getParseDeclOptions(DeclContext *DC) {
@@ -6469,7 +6455,6 @@ ParserStatus Parser::parseDecl(bool IsAtStartOfLineOrPreviousHadSemi,
64696455
Decl *D = DeclResult.get();
64706456
if (!HandlerAlreadyCalled)
64716457
Handler(D);
6472-
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
64736458
}
64746459

64756460
if (!DeclResult.isParseErrorOrHasCompletion()) {
@@ -6665,7 +6650,7 @@ ParserResult<ImportDecl> Parser::parseDeclImport(ParseDeclOptions Flags,
66656650

66666651
auto *ID = ImportDecl::create(Context, CurDeclContext, ImportLoc, Kind,
66676652
KindLoc, importPath.get());
6668-
ID->getAttrs() = Attributes;
6653+
ID->attachParsedAttrs(Attributes);
66696654
return DCC.fixupParserResult(ID);
66706655
}
66716656

@@ -7109,7 +7094,7 @@ Parser::parseDeclExtension(ParseDeclOptions Flags, DeclAttributes &Attributes) {
71097094
Context.AllocateCopy(Inherited),
71107095
CurDeclContext,
71117096
trailingWhereClause);
7112-
ext->getAttrs() = Attributes;
7097+
ext->attachParsedAttrs(Attributes);
71137098
if (trailingWhereHadCodeCompletion && CodeCompletionCallbacks)
71147099
CodeCompletionCallbacks->setParsedDecl(ext);
71157100

@@ -7444,7 +7429,7 @@ parseDeclTypeAlias(Parser::ParseDeclOptions Flags, DeclAttributes &Attributes) {
74447429
}
74457430

74467431
TAD->setUnderlyingTypeRepr(UnderlyingTy.getPtrOrNull());
7447-
TAD->getAttrs() = Attributes;
7432+
TAD->attachParsedAttrs(Attributes);
74487433

74497434
// Parse a 'where' clause if present.
74507435
if (Tok.is(tok::kw_where)) {
@@ -7561,7 +7546,7 @@ ParserResult<TypeDecl> Parser::parseDeclAssociatedType(Parser::ParseDeclOptions
75617546
auto assocType = AssociatedTypeDecl::createParsed(
75627547
Context, CurDeclContext, AssociatedTypeLoc, Id, IdLoc,
75637548
UnderlyingTy.getPtrOrNull(), TrailingWhere);
7564-
assocType->getAttrs() = Attributes;
7549+
assocType->attachParsedAttrs(Attributes);
75657550
if (!Inherited.empty())
75667551
assocType->setInherited(Context.AllocateCopy(Inherited));
75677552
return makeParserResult(Status, assocType);
@@ -7916,7 +7901,7 @@ bool Parser::parseAccessorAfterIntroducer(
79167901
auto *accessor = AccessorDecl::createParsed(
79177902
Context, Kind, storage, /*declLoc*/ Loc, /*accessorKeywordLoc*/ Loc,
79187903
param, asyncLoc, throwsLoc, thrownTy, CurDeclContext);
7919-
accessor->getAttrs() = Attributes;
7904+
accessor->attachParsedAttrs(Attributes);
79207905

79217906
// Collect this accessor and detect conflicts.
79227907
if (auto existingAccessor = accessors.add(accessor)) {
@@ -8195,7 +8180,7 @@ void Parser::parseExpandedAttributeList(SmallVectorImpl<ASTNode> &items,
81958180
// macro will attach the attribute list to.
81968181
MissingDecl *missing =
81978182
MissingDecl::create(Context, CurDeclContext, Tok.getLoc());
8198-
missing->getAttrs() = attributes;
8183+
missing->attachParsedAttrs(attributes);
81998184

82008185
items.push_back(ASTNode(missing));
82018186
return;
@@ -8325,11 +8310,6 @@ Parser::parseDeclVarGetSet(PatternBindingEntry &entry, ParseDeclOptions Flags,
83258310

83268311
accessors.record(*this, PrimaryVar, Invalid);
83278312

8328-
// Set original declaration in `@differentiable` attributes.
8329-
for (auto *accessor : accessors.Accessors)
8330-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
8331-
accessor);
8332-
83338313
return makeParserResult(AccessorStatus, PrimaryVar);
83348314
}
83358315

@@ -8619,12 +8599,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
86198599
// Configure all vars with attributes, 'static' and parent pattern.
86208600
pattern->forEachVariable([&](VarDecl *VD) {
86218601
VD->setStatic(StaticLoc.isValid());
8622-
VD->getAttrs() = Attributes;
8602+
VD->attachParsedAttrs(Attributes);
86238603
VD->setTopLevelGlobal(topLevelDecl);
86248604

8625-
// Set original declaration in `@differentiable` attributes.
8626-
setOriginalDeclarationForDifferentiableAttributes(Attributes, VD);
8627-
86288605
Decls.push_back(VD);
86298606
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment
86308607
&& !InFreestandingMacroArgument) {
@@ -8944,7 +8921,7 @@ ParserResult<FuncDecl> Parser::parseDeclFunc(SourceLoc StaticLoc,
89448921
diagnoseOperatorFixityAttributes(*this, Attributes, FD);
89458922
// Add the attributes here so if we need them while parsing the body
89468923
// they are available.
8947-
FD->getAttrs() = Attributes;
8924+
FD->attachParsedAttrs(Attributes);
89488925

89498926
// Pass the function signature to code completion.
89508927
if (Status.hasCodeCompletion()) {
@@ -9135,7 +9112,7 @@ ParserResult<EnumDecl> Parser::parseDeclEnum(ParseDeclOptions Flags,
91359112

91369113
EnumDecl *ED = new (Context) EnumDecl(EnumLoc, EnumName, EnumNameLoc,
91379114
{ }, GenericParams, CurDeclContext);
9138-
ED->getAttrs() = Attributes;
9115+
ED->attachParsedAttrs(Attributes);
91399116

91409117
ContextChange CC(*this, ED);
91419118

@@ -9330,7 +9307,7 @@ Parser::parseDeclEnumCase(ParseDeclOptions Flags,
93309307
result->setImplicit(); // Parse error
93319308
}
93329309

9333-
result->getAttrs() = Attributes;
9310+
result->attachParsedAttrs(Attributes);
93349311
Elements.push_back(result);
93359312

93369313
// Continue through the comma-separated list.
@@ -9397,7 +9374,7 @@ ParserResult<StructDecl> Parser::parseDeclStruct(ParseDeclOptions Flags,
93979374
{ },
93989375
GenericParams,
93999376
CurDeclContext);
9400-
SD->getAttrs() = Attributes;
9377+
SD->attachParsedAttrs(Attributes);
94019378

94029379
ContextChange CC(*this, SD);
94039380

@@ -9486,7 +9463,7 @@ ParserResult<ClassDecl> Parser::parseDeclClass(ParseDeclOptions Flags,
94869463
ClassDecl *CD = new (Context) ClassDecl(ClassLoc, ClassName, ClassNameLoc,
94879464
{ }, GenericParams, CurDeclContext,
94889465
isExplicitActorDecl);
9489-
CD->getAttrs() = Attributes;
9466+
CD->attachParsedAttrs(Attributes);
94909467

94919468
// Parsed classes never have missing vtable entries.
94929469
CD->setHasMissingVTableEntries(false);
@@ -9665,7 +9642,7 @@ parseDeclProtocol(ParseDeclOptions Flags, DeclAttributes &Attributes) {
96659642
Context.AllocateCopy(PrimaryAssociatedTypeNames),
96669643
Context.AllocateCopy(InheritedProtocols), TrailingWhere);
96679644

9668-
Proto->getAttrs() = Attributes;
9645+
Proto->attachParsedAttrs(Attributes);
96699646
if (whereClauseHadCodeCompletion && CodeCompletionCallbacks)
96709647
CodeCompletionCallbacks->setParsedDecl(Proto);
96719648

@@ -9792,7 +9769,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
97929769
auto *const Subscript = SubscriptDecl::createParsed(
97939770
Context, StaticLoc, StaticSpelling, SubscriptLoc, Indices.get(), ArrowLoc,
97949771
ElementTy.get(), CurDeclContext, GenericParams);
9795-
Subscript->getAttrs() = Attributes;
9772+
Subscript->attachParsedAttrs(Attributes);
97969773

97979774
// Let the source file track the opaque return type mapping, if any.
97989775
if (ElementTy.get() && ElementTy.get()->hasOpaque() &&
@@ -9853,11 +9830,6 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
98539830
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess() ||
98549831
Status.hasCodeCompletion()));
98559832

9856-
// Set original declaration in `@differentiable` attributes.
9857-
for (auto *accessor : accessors.Accessors)
9858-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
9859-
accessor);
9860-
98619833
// No need to setLocalDiscriminator because subscripts cannot
98629834
// validly appear outside of type decls.
98639835
return makeParserResult(Status, Subscript);
@@ -9967,7 +9939,7 @@ Parser::parseDeclInit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
99679939
thrownTy, BodyParams, GenericParams,
99689940
CurDeclContext, FuncRetTy);
99699941
CD->setImplicitlyUnwrappedOptional(IUO);
9970-
CD->getAttrs() = Attributes;
9942+
CD->attachParsedAttrs(Attributes);
99719943

99729944
// Parse a 'where' clause if present.
99739945
if (Tok.is(tok::kw_where)) {
@@ -10056,7 +10028,7 @@ parseDeclDeinit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
1005610028
auto *DD = new (Context) DestructorDecl(DestructorLoc, CurDeclContext);
1005710029
parseAbstractFunctionBody(DD);
1005810030

10059-
DD->getAttrs() = Attributes;
10031+
DD->attachParsedAttrs(Attributes);
1006010032

1006110033
// Reject 'destructor' functions outside of structs, enums, classes, or
1006210034
// extensions that provide objc implementations.
@@ -10263,7 +10235,7 @@ Parser::parseDeclOperatorImpl(SourceLoc OperatorLoc, Identifier Name,
1026310235

1026410236
diagnoseOperatorFixityAttributes(*this, Attributes, res);
1026510237

10266-
res->getAttrs() = Attributes;
10238+
res->attachParsedAttrs(Attributes);
1026710239
return makeParserResult(res);
1026810240
}
1026910241

@@ -10320,7 +10292,7 @@ Parser::parseDeclPrecedenceGroup(ParseDeclOptions flags,
1032010292
higherThanKeywordLoc, higherThan,
1032110293
lowerThanKeywordLoc, lowerThan,
1032210294
rbraceLoc);
10323-
result->getAttrs() = attributes;
10295+
result->attachParsedAttrs(attributes);
1032410296
return result;
1032510297
};
1032610298
auto createInvalid = [&](bool hasCodeCompletion) {
@@ -10575,7 +10547,7 @@ ParserResult<MacroDecl> Parser::parseDeclMacro(DeclAttributes &attributes) {
1057510547
auto *macro = new (Context) MacroDecl(
1057610548
macroLoc, macroFullName, macroNameLoc, genericParams, parameterList,
1057710549
arrowLoc, resultType, definition, CurDeclContext);
10578-
macro->getAttrs() = attributes;
10550+
macro->attachParsedAttrs(attributes);
1057910551

1058010552
// Parse a 'where' clause if present.
1058110553
if (Tok.is(tok::kw_where)) {
@@ -10611,7 +10583,7 @@ Parser::parseDeclMacroExpansion(ParseDeclOptions flags,
1061110583
auto *med = MacroExpansionDecl::create(
1061210584
CurDeclContext, poundLoc, macroNameRef, macroNameLoc, leftAngleLoc,
1061310585
Context.AllocateCopy(genericArgs), rightAngleLoc, argList);
10614-
med->getAttrs() = attributes;
10586+
med->attachParsedAttrs(attributes);
1061510587

1061610588
return makeParserResult(status, med);
1061710589
}

lib/Parse/ParseGeneric.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Parser::parseGenericParametersBeforeWhere(SourceLoc LAngleLoc,
148148
GenericParams.push_back(Param);
149149

150150
// Attach attributes.
151-
Param->getAttrs() = attributes;
151+
Param->attachParsedAttrs(attributes);
152152

153153
// Parse the comma, if the list continues.
154154
HasComma = consumeIf(tok::comma);

lib/Parse/ParsePattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ mapParsedParameters(Parser &parser,
571571
auto param = ParamDecl::createParsed(
572572
ctx, paramInfo.SpecifierLoc, argNameLoc, argName, paramNameLoc,
573573
paramName, paramInfo.DefaultArg, parser.CurDeclContext);
574-
param->getAttrs() = paramInfo.Attrs;
574+
param->attachParsedAttrs(paramInfo.Attrs);
575575

576576
bool parsingEnumElt
577577
= (paramContext == Parser::ParameterContextKind::EnumElement);

0 commit comments

Comments
 (0)