Skip to content

Commit b0b459a

Browse files
committed
[NFC] Extract autodiff parsing code to Decl.cpp
The `@differentiable` and `@derivative` attributes need a parent pointer. Move the code to populate it from Parser to AST so it can be more easily shared between the parsers. Done in preparation for similar code to be added for `@abi`.
1 parent db0dc20 commit b0b459a

File tree

6 files changed

+57
-55
lines changed

6 files changed

+57
-55
lines changed

include/swift/AST/Attr.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,6 +2871,12 @@ template <typename ATTR, bool AllowInvalid> struct ToAttributeKind {
28712871
return cast<ATTR>(Attr);
28722872
return std::nullopt;
28732873
}
2874+
2875+
std::optional<ATTR *> operator()(DeclAttribute *Attr) {
2876+
if (isa<ATTR>(Attr) && (Attr->isValid() || AllowInvalid))
2877+
return cast<ATTR>(Attr);
2878+
return std::nullopt;
2879+
}
28742880
};
28752881

28762882
/// The @_allowFeatureSuppression(Foo, Bar) attribute. The feature
@@ -3043,17 +3049,25 @@ class DeclAttributes {
30433049
}
30443050

30453051
public:
3046-
template <typename ATTR, bool AllowInvalid>
3052+
template <typename ATTR, typename Iterator, bool AllowInvalid>
30473053
using AttributeKindRange =
3048-
OptionalTransformRange<iterator_range<const_iterator>,
3054+
OptionalTransformRange<iterator_range<Iterator>,
30493055
ToAttributeKind<ATTR, AllowInvalid>,
3050-
const_iterator>;
3056+
Iterator>;
3057+
3058+
/// Return a range with all attributes in DeclAttributes with AttrKind
3059+
/// ATTR.
3060+
template <typename ATTR, bool AllowInvalid = false>
3061+
AttributeKindRange<ATTR, const_iterator, AllowInvalid> getAttributes() const {
3062+
return AttributeKindRange<ATTR, const_iterator, AllowInvalid>(
3063+
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
3064+
}
30513065

30523066
/// Return a range with all attributes in DeclAttributes with AttrKind
30533067
/// ATTR.
30543068
template <typename ATTR, bool AllowInvalid = false>
3055-
AttributeKindRange<ATTR, AllowInvalid> getAttributes() const {
3056-
return AttributeKindRange<ATTR, AllowInvalid>(
3069+
AttributeKindRange<ATTR, iterator, AllowInvalid> getAttributes() {
3070+
return AttributeKindRange<ATTR, iterator, AllowInvalid>(
30573071
make_range(begin(), end()), ToAttributeKind<ATTR, AllowInvalid>());
30583072
}
30593073

include/swift/AST/Decl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,11 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl>, public Swi
10221022
/// attribute macro expansion.
10231023
DeclAttributes getSemanticAttrs() const;
10241024

1025+
/// Set this declaration's attributes to the specified attribute list,
1026+
/// applying any post-processing logic appropriate for attributes parsed
1027+
/// from source code.
1028+
void attachParsedAttrs(DeclAttributes attrs);
1029+
10251030
/// True if this declaration provides an implementation for an imported
10261031
/// Objective-C declaration. This implies various restrictions and special
10271032
/// behaviors for it and, if it's an extension, its members.

lib/AST/Decl.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,17 @@ DeclAttributes Decl::getSemanticAttrs() const {
407407
return getAttrs();
408408
}
409409

410+
void Decl::attachParsedAttrs(DeclAttributes attrs) {
411+
ASSERT(getAttrs().isEmpty() && "attaching when there are already attrs?");
412+
413+
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
414+
attr->setOriginalDeclaration(this);
415+
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
416+
attr->setOriginalDeclaration(this);
417+
418+
getAttrs() = attrs;
419+
}
420+
410421
void Decl::visitAuxiliaryDecls(
411422
AuxiliaryDeclCallback callback,
412423
bool visitFreestandingExpanded

lib/Parse/ParseDecl.cpp

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5968,20 +5968,6 @@ void Parser::consumeDecl(ParserPosition BeginParserPosition, bool IsTopLevel) {
59685968
}
59695969
}
59705970

5971-
/// Set the original declaration in `@differentiable` attributes.
5972-
///
5973-
/// Necessary because `Parser::parseNewDeclAttribute` (which calls
5974-
/// `Parser::parseDifferentiableAttribute`) does not have access to the
5975-
/// parent declaration of parsed attributes.
5976-
static void
5977-
setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
5978-
Decl *D) {
5979-
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
5980-
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
5981-
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
5982-
const_cast<DerivativeAttr *>(attr)->setOriginalDeclaration(D);
5983-
}
5984-
59855971
/// Determine the declaration parsing options to use when parsing the decl in
59865972
/// the given context.
59875973
static Parser::ParseDeclOptions getParseDeclOptions(DeclContext *DC) {
@@ -6407,7 +6393,6 @@ ParserStatus Parser::parseDecl(bool IsAtStartOfLineOrPreviousHadSemi,
64076393
Decl *D = DeclResult.get();
64086394
if (!HandlerAlreadyCalled)
64096395
Handler(D);
6410-
setOriginalDeclarationForDifferentiableAttributes(D->getAttrs(), D);
64116396
}
64126397

64136398
if (!DeclResult.isParseErrorOrHasCompletion()) {
@@ -6603,7 +6588,7 @@ ParserResult<ImportDecl> Parser::parseDeclImport(ParseDeclOptions Flags,
66036588

66046589
auto *ID = ImportDecl::create(Context, CurDeclContext, ImportLoc, Kind,
66056590
KindLoc, importPath.get());
6606-
ID->getAttrs() = Attributes;
6591+
ID->attachParsedAttrs(Attributes);
66076592
return DCC.fixupParserResult(ID);
66086593
}
66096594

@@ -7047,7 +7032,7 @@ Parser::parseDeclExtension(ParseDeclOptions Flags, DeclAttributes &Attributes) {
70477032
Context.AllocateCopy(Inherited),
70487033
CurDeclContext,
70497034
trailingWhereClause);
7050-
ext->getAttrs() = Attributes;
7035+
ext->attachParsedAttrs(Attributes);
70517036
if (trailingWhereHadCodeCompletion && CodeCompletionCallbacks)
70527037
CodeCompletionCallbacks->setParsedDecl(ext);
70537038

@@ -7382,7 +7367,7 @@ parseDeclTypeAlias(Parser::ParseDeclOptions Flags, DeclAttributes &Attributes) {
73827367
}
73837368

73847369
TAD->setUnderlyingTypeRepr(UnderlyingTy.getPtrOrNull());
7385-
TAD->getAttrs() = Attributes;
7370+
TAD->attachParsedAttrs(Attributes);
73867371

73877372
// Parse a 'where' clause if present.
73887373
if (Tok.is(tok::kw_where)) {
@@ -7499,7 +7484,7 @@ ParserResult<TypeDecl> Parser::parseDeclAssociatedType(Parser::ParseDeclOptions
74997484
auto assocType = AssociatedTypeDecl::createParsed(
75007485
Context, CurDeclContext, AssociatedTypeLoc, Id, IdLoc,
75017486
UnderlyingTy.getPtrOrNull(), TrailingWhere);
7502-
assocType->getAttrs() = Attributes;
7487+
assocType->attachParsedAttrs(Attributes);
75037488
if (!Inherited.empty())
75047489
assocType->setInherited(Context.AllocateCopy(Inherited));
75057490
return makeParserResult(Status, assocType);
@@ -7854,7 +7839,7 @@ bool Parser::parseAccessorAfterIntroducer(
78547839
auto *accessor = AccessorDecl::createParsed(
78557840
Context, Kind, storage, /*declLoc*/ Loc, /*accessorKeywordLoc*/ Loc,
78567841
param, asyncLoc, throwsLoc, thrownTy, CurDeclContext);
7857-
accessor->getAttrs() = Attributes;
7842+
accessor->attachParsedAttrs(Attributes);
78587843

78597844
// Collect this accessor and detect conflicts.
78607845
if (auto existingAccessor = accessors.add(accessor)) {
@@ -8133,7 +8118,7 @@ void Parser::parseExpandedAttributeList(SmallVectorImpl<ASTNode> &items,
81338118
// macro will attach the attribute list to.
81348119
MissingDecl *missing =
81358120
MissingDecl::create(Context, CurDeclContext, Tok.getLoc());
8136-
missing->getAttrs() = attributes;
8121+
missing->attachParsedAttrs(attributes);
81378122

81388123
items.push_back(ASTNode(missing));
81398124
return;
@@ -8263,11 +8248,6 @@ Parser::parseDeclVarGetSet(PatternBindingEntry &entry, ParseDeclOptions Flags,
82638248

82648249
accessors.record(*this, PrimaryVar, Invalid);
82658250

8266-
// Set original declaration in `@differentiable` attributes.
8267-
for (auto *accessor : accessors.Accessors)
8268-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
8269-
accessor);
8270-
82718251
return makeParserResult(AccessorStatus, PrimaryVar);
82728252
}
82738253

@@ -8557,12 +8537,9 @@ Parser::parseDeclVar(ParseDeclOptions Flags,
85578537
// Configure all vars with attributes, 'static' and parent pattern.
85588538
pattern->forEachVariable([&](VarDecl *VD) {
85598539
VD->setStatic(StaticLoc.isValid());
8560-
VD->getAttrs() = Attributes;
8540+
VD->attachParsedAttrs(Attributes);
85618541
VD->setTopLevelGlobal(topLevelDecl);
85628542

8563-
// Set original declaration in `@differentiable` attributes.
8564-
setOriginalDeclarationForDifferentiableAttributes(Attributes, VD);
8565-
85668543
Decls.push_back(VD);
85678544
if (hasOpaqueReturnTy && sf && !InInactiveClauseEnvironment
85688545
&& !InFreestandingMacroArgument) {
@@ -8882,7 +8859,7 @@ ParserResult<FuncDecl> Parser::parseDeclFunc(SourceLoc StaticLoc,
88828859
diagnoseOperatorFixityAttributes(*this, Attributes, FD);
88838860
// Add the attributes here so if we need them while parsing the body
88848861
// they are available.
8885-
FD->getAttrs() = Attributes;
8862+
FD->attachParsedAttrs(Attributes);
88868863

88878864
// Pass the function signature to code completion.
88888865
if (Status.hasCodeCompletion()) {
@@ -9073,7 +9050,7 @@ ParserResult<EnumDecl> Parser::parseDeclEnum(ParseDeclOptions Flags,
90739050

90749051
EnumDecl *ED = new (Context) EnumDecl(EnumLoc, EnumName, EnumNameLoc,
90759052
{ }, GenericParams, CurDeclContext);
9076-
ED->getAttrs() = Attributes;
9053+
ED->attachParsedAttrs(Attributes);
90779054

90789055
ContextChange CC(*this, ED);
90799056

@@ -9268,7 +9245,7 @@ Parser::parseDeclEnumCase(ParseDeclOptions Flags,
92689245
result->setImplicit(); // Parse error
92699246
}
92709247

9271-
result->getAttrs() = Attributes;
9248+
result->attachParsedAttrs(Attributes);
92729249
Elements.push_back(result);
92739250

92749251
// Continue through the comma-separated list.
@@ -9335,7 +9312,7 @@ ParserResult<StructDecl> Parser::parseDeclStruct(ParseDeclOptions Flags,
93359312
{ },
93369313
GenericParams,
93379314
CurDeclContext);
9338-
SD->getAttrs() = Attributes;
9315+
SD->attachParsedAttrs(Attributes);
93399316

93409317
ContextChange CC(*this, SD);
93419318

@@ -9424,7 +9401,7 @@ ParserResult<ClassDecl> Parser::parseDeclClass(ParseDeclOptions Flags,
94249401
ClassDecl *CD = new (Context) ClassDecl(ClassLoc, ClassName, ClassNameLoc,
94259402
{ }, GenericParams, CurDeclContext,
94269403
isExplicitActorDecl);
9427-
CD->getAttrs() = Attributes;
9404+
CD->attachParsedAttrs(Attributes);
94289405

94299406
// Parsed classes never have missing vtable entries.
94309407
CD->setHasMissingVTableEntries(false);
@@ -9603,7 +9580,7 @@ parseDeclProtocol(ParseDeclOptions Flags, DeclAttributes &Attributes) {
96039580
Context.AllocateCopy(PrimaryAssociatedTypeNames),
96049581
Context.AllocateCopy(InheritedProtocols), TrailingWhere);
96059582

9606-
Proto->getAttrs() = Attributes;
9583+
Proto->attachParsedAttrs(Attributes);
96079584
if (whereClauseHadCodeCompletion && CodeCompletionCallbacks)
96089585
CodeCompletionCallbacks->setParsedDecl(Proto);
96099586

@@ -9730,7 +9707,7 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
97309707
auto *const Subscript = SubscriptDecl::createParsed(
97319708
Context, StaticLoc, StaticSpelling, SubscriptLoc, Indices.get(), ArrowLoc,
97329709
ElementTy.get(), CurDeclContext, GenericParams);
9733-
Subscript->getAttrs() = Attributes;
9710+
Subscript->attachParsedAttrs(Attributes);
97349711

97359712
// Let the source file track the opaque return type mapping, if any.
97369713
if (ElementTy.get() && ElementTy.get()->hasOpaque() &&
@@ -9791,11 +9768,6 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc,
97919768
accessors.record(*this, Subscript, (Invalid || !Status.isSuccess() ||
97929769
Status.hasCodeCompletion()));
97939770

9794-
// Set original declaration in `@differentiable` attributes.
9795-
for (auto *accessor : accessors.Accessors)
9796-
setOriginalDeclarationForDifferentiableAttributes(accessor->getAttrs(),
9797-
accessor);
9798-
97999771
// No need to setLocalDiscriminator because subscripts cannot
98009772
// validly appear outside of type decls.
98019773
return makeParserResult(Status, Subscript);
@@ -9905,7 +9877,7 @@ Parser::parseDeclInit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
99059877
thrownTy, BodyParams, GenericParams,
99069878
CurDeclContext, FuncRetTy);
99079879
CD->setImplicitlyUnwrappedOptional(IUO);
9908-
CD->getAttrs() = Attributes;
9880+
CD->attachParsedAttrs(Attributes);
99099881

99109882
// Parse a 'where' clause if present.
99119883
if (Tok.is(tok::kw_where)) {
@@ -9994,7 +9966,7 @@ parseDeclDeinit(ParseDeclOptions Flags, DeclAttributes &Attributes) {
99949966
auto *DD = new (Context) DestructorDecl(DestructorLoc, CurDeclContext);
99959967
parseAbstractFunctionBody(DD);
99969968

9997-
DD->getAttrs() = Attributes;
9969+
DD->attachParsedAttrs(Attributes);
99989970

99999971
// Reject 'destructor' functions outside of structs, enums, classes, or
100009972
// extensions that provide objc implementations.
@@ -10201,7 +10173,7 @@ Parser::parseDeclOperatorImpl(SourceLoc OperatorLoc, Identifier Name,
1020110173

1020210174
diagnoseOperatorFixityAttributes(*this, Attributes, res);
1020310175

10204-
res->getAttrs() = Attributes;
10176+
res->attachParsedAttrs(Attributes);
1020510177
return makeParserResult(res);
1020610178
}
1020710179

@@ -10258,7 +10230,7 @@ Parser::parseDeclPrecedenceGroup(ParseDeclOptions flags,
1025810230
higherThanKeywordLoc, higherThan,
1025910231
lowerThanKeywordLoc, lowerThan,
1026010232
rbraceLoc);
10261-
result->getAttrs() = attributes;
10233+
result->attachParsedAttrs(attributes);
1026210234
return result;
1026310235
};
1026410236
auto createInvalid = [&](bool hasCodeCompletion) {
@@ -10513,7 +10485,7 @@ ParserResult<MacroDecl> Parser::parseDeclMacro(DeclAttributes &attributes) {
1051310485
auto *macro = new (Context) MacroDecl(
1051410486
macroLoc, macroFullName, macroNameLoc, genericParams, parameterList,
1051510487
arrowLoc, resultType, definition, CurDeclContext);
10516-
macro->getAttrs() = attributes;
10488+
macro->attachParsedAttrs(attributes);
1051710489

1051810490
// Parse a 'where' clause if present.
1051910491
if (Tok.is(tok::kw_where)) {
@@ -10549,7 +10521,7 @@ Parser::parseDeclMacroExpansion(ParseDeclOptions flags,
1054910521
auto *med = MacroExpansionDecl::create(
1055010522
CurDeclContext, poundLoc, macroNameRef, macroNameLoc, leftAngleLoc,
1055110523
Context.AllocateCopy(genericArgs), rightAngleLoc, argList);
10552-
med->getAttrs() = attributes;
10524+
med->attachParsedAttrs(attributes);
1055310525

1055410526
return makeParserResult(status, med);
1055510527
}

lib/Parse/ParseGeneric.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Parser::parseGenericParametersBeforeWhere(SourceLoc LAngleLoc,
148148
GenericParams.push_back(Param);
149149

150150
// Attach attributes.
151-
Param->getAttrs() = attributes;
151+
Param->attachParsedAttrs(attributes);
152152

153153
// Parse the comma, if the list continues.
154154
HasComma = consumeIf(tok::comma);

lib/Parse/ParsePattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ mapParsedParameters(Parser &parser,
571571
auto param = ParamDecl::createParsed(
572572
ctx, paramInfo.SpecifierLoc, argNameLoc, argName, paramNameLoc,
573573
paramName, paramInfo.DefaultArg, parser.CurDeclContext);
574-
param->getAttrs() = paramInfo.Attrs;
574+
param->attachParsedAttrs(paramInfo.Attrs);
575575

576576
bool parsingEnumElt
577577
= (paramContext == Parser::ParameterContextKind::EnumElement);

0 commit comments

Comments
 (0)