Skip to content

Commit 2a2cf91

Browse files
committed
Add support for marking a _specialize attribute as SPI
``` @_specialize(exported: true, spi: SPIGroupName, where T == Int) public func myFunc() { } ``` The specialized entry point is only visible for modules that import using `_spi(SPIGroupName) import ModuleDefiningMyFunc `. rdar://64993425
1 parent b994bf3 commit 2a2cf91

25 files changed

+508
-80
lines changed

include/swift/AST/Attr.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,8 +1408,11 @@ class SynthesizedProtocolAttr : public DeclAttribute {
14081408

14091409
/// The @_specialize attribute, which forces specialization on the specified
14101410
/// type list.
1411-
class SpecializeAttr : public DeclAttribute {
1411+
class SpecializeAttr final
1412+
: public DeclAttribute,
1413+
private llvm::TrailingObjects<SpecializeAttr, Identifier> {
14121414
friend class SpecializeAttrTargetDeclRequest;
1415+
friend TrailingObjects;
14131416

14141417
public:
14151418
// NOTE: When adding new kinds, you must update the inline bitfield macro.
@@ -1425,32 +1428,45 @@ class SpecializeAttr : public DeclAttribute {
14251428
DeclNameRef targetFunctionName;
14261429
LazyMemberLoader *resolver = nullptr;
14271430
uint64_t resolverContextData;
1431+
size_t numSPIGroups;
14281432

14291433
SpecializeAttr(SourceLoc atLoc, SourceRange Range,
14301434
TrailingWhereClause *clause, bool exported,
1431-
SpecializationKind kind,
1432-
GenericSignature specializedSignature,
1433-
DeclNameRef targetFunctionName);
1435+
SpecializationKind kind, GenericSignature specializedSignature,
1436+
DeclNameRef targetFunctionName,
1437+
ArrayRef<Identifier> spiGroups);
14341438

14351439
public:
14361440
static SpecializeAttr *create(ASTContext &Ctx, SourceLoc atLoc,
14371441
SourceRange Range, TrailingWhereClause *clause,
14381442
bool exported, SpecializationKind kind,
14391443
DeclNameRef targetFunctionName,
1444+
ArrayRef<Identifier> spiGroups,
14401445
GenericSignature specializedSignature
14411446
= nullptr);
14421447

14431448
static SpecializeAttr *create(ASTContext &ctx, bool exported,
14441449
SpecializationKind kind,
1450+
ArrayRef<Identifier> spiGroups,
14451451
GenericSignature specializedSignature,
14461452
DeclNameRef replacedFunction);
14471453

14481454
static SpecializeAttr *create(ASTContext &ctx, bool exported,
14491455
SpecializationKind kind,
1456+
ArrayRef<Identifier> spiGroups,
14501457
GenericSignature specializedSignature,
14511458
DeclNameRef replacedFunction,
14521459
LazyMemberLoader *resolver, uint64_t data);
14531460

1461+
/// Name of SPIs declared by the attribute.
1462+
///
1463+
/// Note: A single SPI name per attribute is currently supported but this
1464+
/// may change with the syntax change.
1465+
ArrayRef<Identifier> getSPIGroups() const {
1466+
return { this->template getTrailingObjects<Identifier>(),
1467+
numSPIGroups };
1468+
}
1469+
14541470
TrailingWhereClause *getTrailingWhereClause() const;
14551471

14561472
GenericSignature getSpecializedSignature() const {

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,8 @@ ERROR(attr_specialize_expected_partial_or_full,none,
15831583
"expected 'partial' or 'full' as values of the 'kind' parameter in '_specialize' attribute", ())
15841584
ERROR(attr_specialize_expected_function,none,
15851585
"expected a function name as the value of the 'target' parameter in '_specialize' attribute", ())
1586+
ERROR(attr_specialize_expected_spi_name,none,
1587+
"expected an SPI identifier as the value of the 'spi' parameter in '_specialize' attribute", ())
15861588

15871589
// _implements
15881590
ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,

include/swift/AST/Module.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,13 @@ class ModuleDecl : public DeclContext, public TypeDecl {
578578
const ModuleDecl *importedModule,
579579
llvm::SmallSetVector<Identifier, 4> &spiGroups) const;
580580

581+
// Is \p attr accessible as an explictly imported SPI from this module?
582+
bool isImportedAsSPI(const SpecializeAttr *attr,
583+
const ValueDecl *targetDecl) const;
584+
585+
// Is \p spiGroup accessible as an explictly imported SPI from this module?
586+
bool isImportedAsSPI(Identifier spiGroup, const ModuleDecl *fromModule) const;
587+
581588
/// \sa getImportedModules
582589
enum class ImportFilterKind {
583590
/// Include imports declared with `@_exported`.

include/swift/Parse/Parser.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,16 +1017,19 @@ class Parser {
10171017
bool parseSpecializeAttribute(
10181018
swift::tok ClosingBrace, SourceLoc AtLoc, SourceLoc Loc,
10191019
SpecializeAttr *&Attr,
1020-
llvm::function_ref<bool(Parser &)> parseSILTargetName = [](Parser &) {
1021-
return false;
1022-
});
1020+
llvm::function_ref<bool(Parser &)> parseSILTargetName =
1021+
[](Parser &) { return false; },
1022+
llvm::function_ref<bool(Parser &)> parseSILSIPModule =
1023+
[](Parser &) { return false; });
10231024

10241025
/// Parse the arguments inside the @_specialize attribute
10251026
bool parseSpecializeAttributeArguments(
10261027
swift::tok ClosingBrace, bool &DiscardAttribute, Optional<bool> &Exported,
10271028
Optional<SpecializeAttr::SpecializationKind> &Kind,
10281029
TrailingWhereClause *&TrailingWhereClause, DeclNameRef &targetFunction,
1029-
llvm::function_ref<bool(Parser &)> parseSILTargetName);
1030+
SmallVectorImpl<Identifier> &spiGroups,
1031+
llvm::function_ref<bool(Parser &)> parseSILTargetName,
1032+
llvm::function_ref<bool(Parser &)> parseSILSIPModule);
10301033

10311034
/// Parse the @_implements attribute.
10321035
/// \p Attr is where to store the parsed attribute

include/swift/SIL/SILFunction.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class SILSpecializeAttr final {
7272
static SILSpecializeAttr *create(SILModule &M,
7373
GenericSignature specializedSignature,
7474
bool exported, SpecializationKind kind,
75-
SILFunction *target);
75+
SILFunction *target, Identifier spiGroup,
76+
const ModuleDecl *spiModule);
7677

7778
bool isExported() const {
7879
return exported;
@@ -102,17 +103,28 @@ class SILSpecializeAttr final {
102103
return targetFunction;
103104
}
104105

106+
Identifier getSPIGroup() const {
107+
return spiGroup;
108+
}
109+
110+
const ModuleDecl *getSPIModule() const {
111+
return spiModule;
112+
}
113+
105114
void print(llvm::raw_ostream &OS) const;
106115

107116
private:
108117
SpecializationKind kind;
109118
bool exported;
110119
GenericSignature specializedSignature;
120+
Identifier spiGroup;
121+
const ModuleDecl *spiModule = nullptr;
111122
SILFunction *F = nullptr;
112123
SILFunction *targetFunction = nullptr;
113124

114125
SILSpecializeAttr(bool exported, SpecializationKind kind,
115-
GenericSignature specializedSignature, SILFunction *target);
126+
GenericSignature specializedSignature, SILFunction *target,
127+
Identifier spiGroup, const ModuleDecl *spiModule);
116128
};
117129

118130
/// SILFunction - A function body that has been lowered to SIL. This consists of

lib/AST/Attr.cpp

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -898,12 +898,20 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
898898
}
899899

900900
case DAK_Specialize: {
901-
Printer << "@" << getAttrName() << "(";
902901
auto *attr = cast<SpecializeAttr>(this);
902+
// Don't print the _specialize attribute if it is marked spi and we are
903+
// asked to skip SPI.
904+
if (!Options.PrintSPIs && !attr->getSPIGroups().empty())
905+
return false;
906+
907+
Printer << "@" << getAttrName() << "(";
903908
auto exported = attr->isExported() ? "true" : "false";
904909
auto kind = attr->isPartialSpecialization() ? "partial" : "full";
905910
auto target = attr->getTargetFunctionName();
906911
Printer << "exported: "<< exported << ", ";
912+
for (auto id : attr->getSPIGroups()) {
913+
Printer << "spi: " << id << ", ";
914+
}
907915
Printer << "kind: " << kind << ", ";
908916
if (target)
909917
Printer << "target: " << target << ", ";
@@ -1558,16 +1566,17 @@ const AvailableAttr *AvailableAttr::isUnavailable(const Decl *D) {
15581566
}
15591567

15601568
SpecializeAttr::SpecializeAttr(SourceLoc atLoc, SourceRange range,
1561-
TrailingWhereClause *clause,
1562-
bool exported,
1569+
TrailingWhereClause *clause, bool exported,
15631570
SpecializationKind kind,
15641571
GenericSignature specializedSignature,
1565-
DeclNameRef targetFunctionName)
1572+
DeclNameRef targetFunctionName,
1573+
ArrayRef<Identifier> spiGroups)
15661574
: DeclAttribute(DAK_Specialize, atLoc, range,
15671575
/*Implicit=*/clause == nullptr),
1568-
trailingWhereClause(clause),
1569-
specializedSignature(specializedSignature),
1570-
targetFunctionName(targetFunctionName) {
1576+
trailingWhereClause(clause), specializedSignature(specializedSignature),
1577+
targetFunctionName(targetFunctionName), numSPIGroups(spiGroups.size()) {
1578+
std::uninitialized_copy(spiGroups.begin(), spiGroups.end(),
1579+
getTrailingObjects<Identifier>());
15711580
Bits.SpecializeAttr.exported = exported;
15721581
Bits.SpecializeAttr.kind = unsigned(kind);
15731582
}
@@ -1579,32 +1588,38 @@ TrailingWhereClause *SpecializeAttr::getTrailingWhereClause() const {
15791588
SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
15801589
SourceRange range,
15811590
TrailingWhereClause *clause,
1582-
bool exported,
1583-
SpecializationKind kind,
1591+
bool exported, SpecializationKind kind,
15841592
DeclNameRef targetFunctionName,
1593+
ArrayRef<Identifier> spiGroups,
15851594
GenericSignature specializedSignature) {
1586-
return new (Ctx) SpecializeAttr(atLoc, range, clause, exported, kind,
1587-
specializedSignature, targetFunctionName);
1595+
unsigned size = totalSizeToAlloc<Identifier>(spiGroups.size());
1596+
void *mem = Ctx.Allocate(size, alignof(SpecializeAttr));
1597+
return new (mem)
1598+
SpecializeAttr(atLoc, range, clause, exported, kind, specializedSignature,
1599+
targetFunctionName, spiGroups);
15881600
}
15891601

15901602
SpecializeAttr *SpecializeAttr::create(ASTContext &ctx, bool exported,
1591-
SpecializationKind kind,
1592-
GenericSignature specializedSignature,
1593-
DeclNameRef targetFunctionName) {
1594-
return new (ctx)
1603+
SpecializationKind kind,
1604+
ArrayRef<Identifier> spiGroups,
1605+
GenericSignature specializedSignature,
1606+
DeclNameRef targetFunctionName) {
1607+
unsigned size = totalSizeToAlloc<Identifier>(spiGroups.size());
1608+
void *mem = ctx.Allocate(size, alignof(SpecializeAttr));
1609+
return new (mem)
15951610
SpecializeAttr(SourceLoc(), SourceRange(), nullptr, exported, kind,
1596-
specializedSignature, targetFunctionName);
1611+
specializedSignature, targetFunctionName, spiGroups);
15971612
}
15981613

1599-
SpecializeAttr *SpecializeAttr::create(ASTContext &ctx, bool exported,
1600-
SpecializationKind kind,
1601-
GenericSignature specializedSignature,
1602-
DeclNameRef targetFunctionName,
1603-
LazyMemberLoader *resolver,
1604-
uint64_t data) {
1605-
auto *attr = new (ctx)
1614+
SpecializeAttr *SpecializeAttr::create(
1615+
ASTContext &ctx, bool exported, SpecializationKind kind,
1616+
ArrayRef<Identifier> spiGroups, GenericSignature specializedSignature,
1617+
DeclNameRef targetFunctionName, LazyMemberLoader *resolver, uint64_t data) {
1618+
unsigned size = totalSizeToAlloc<Identifier>(spiGroups.size());
1619+
void *mem = ctx.Allocate(size, alignof(SpecializeAttr));
1620+
auto *attr = new (mem)
16061621
SpecializeAttr(SourceLoc(), SourceRange(), nullptr, exported, kind,
1607-
specializedSignature, targetFunctionName);
1622+
specializedSignature, targetFunctionName, spiGroups);
16081623
attr->resolver = resolver;
16091624
attr->resolverContextData = data;
16101625
return attr;

lib/AST/Module.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,6 +2034,31 @@ bool SourceFile::isImportedAsSPI(const ValueDecl *targetDecl) const {
20342034
return false;
20352035
}
20362036

2037+
bool ModuleDecl::isImportedAsSPI(const SpecializeAttr *attr,
2038+
const ValueDecl *targetDecl) const {
2039+
auto targetModule = targetDecl->getModuleContext();
2040+
llvm::SmallSetVector<Identifier, 4> importedSPIGroups;
2041+
lookupImportedSPIGroups(targetModule, importedSPIGroups);
2042+
if (importedSPIGroups.empty()) return false;
2043+
2044+
auto declSPIGroups = attr->getSPIGroups();
2045+
2046+
for (auto declSPI : declSPIGroups)
2047+
if (importedSPIGroups.count(declSPI))
2048+
return true;
2049+
2050+
return false;
2051+
}
2052+
2053+
bool ModuleDecl::isImportedAsSPI(Identifier spiGroup,
2054+
const ModuleDecl *fromModule) const {
2055+
llvm::SmallSetVector<Identifier, 4> importedSPIGroups;
2056+
lookupImportedSPIGroups(fromModule, importedSPIGroups);
2057+
if (importedSPIGroups.empty())
2058+
return false;
2059+
return importedSPIGroups.count(spiGroup);
2060+
}
2061+
20372062
bool Decl::isSPI() const {
20382063
return !getSPIGroups().empty();
20392064
}

lib/Parse/ParseDecl.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,9 @@ bool Parser::parseSpecializeAttributeArguments(
582582
swift::tok ClosingBrace, bool &DiscardAttribute, Optional<bool> &Exported,
583583
Optional<SpecializeAttr::SpecializationKind> &Kind,
584584
swift::TrailingWhereClause *&TrailingWhereClause,
585-
DeclNameRef &targetFunction,
586-
llvm::function_ref<bool(Parser &)> parseSILTargetName) {
585+
DeclNameRef &targetFunction, SmallVectorImpl<Identifier> &spiGroups,
586+
llvm::function_ref<bool(Parser &)> parseSILTargetName,
587+
llvm::function_ref<bool(Parser &)> parseSILSIPModule) {
587588
SyntaxParsingContext ContentContext(SyntaxContext,
588589
SyntaxKind::SpecializeAttributeSpecList);
589590
// Parse optional "exported" and "kind" labeled parameters.
@@ -595,7 +596,8 @@ bool Parser::parseSpecializeAttributeArguments(
595596
? SyntaxKind::TargetFunctionEntry
596597
: SyntaxKind::LabeledSpecializeEntry);
597598
if (ParamLabel != "exported" && ParamLabel != "kind" &&
598-
ParamLabel != "target") {
599+
ParamLabel != "target" && ParamLabel != "spi" &&
600+
ParamLabel != "spiModule") {
599601
diagnose(Tok.getLoc(), diag::attr_specialize_unknown_parameter_name,
600602
ParamLabel);
601603
}
@@ -616,7 +618,8 @@ bool Parser::parseSpecializeAttributeArguments(
616618
return false;
617619
}
618620
if ((ParamLabel == "exported" && Exported.hasValue()) ||
619-
(ParamLabel == "kind" && Kind.hasValue())) {
621+
(ParamLabel == "kind" && Kind.hasValue()) ||
622+
(ParamLabel == "spi" && !spiGroups.empty())) {
620623
diagnose(Tok.getLoc(), diag::attr_specialize_parameter_already_defined,
621624
ParamLabel);
622625
}
@@ -673,6 +676,23 @@ bool Parser::parseSpecializeAttributeArguments(
673676
DeclNameFlag::AllowOperators);
674677
}
675678
}
679+
if (ParamLabel == "spiModule") {
680+
if (!parseSILSIPModule(*this)) {
681+
diagnose(Tok.getLoc(), diag::attr_specialize_unknown_parameter_name,
682+
ParamLabel);
683+
return false;
684+
}
685+
}
686+
if (ParamLabel == "spi") {
687+
if (!Tok.is(tok::identifier)) {
688+
diagnose(Tok.getLoc(), diag::attr_specialize_expected_spi_name);
689+
consumeToken();
690+
return false;
691+
}
692+
auto text = Tok.getText();
693+
spiGroups.push_back(Context.getIdentifier(text));
694+
consumeToken();
695+
}
676696
if (!consumeIf(tok::comma)) {
677697
diagnose(Tok.getLoc(), diag::attr_specialize_missing_comma);
678698
skipUntil(tok::comma, tok::kw_where);
@@ -712,7 +732,8 @@ bool Parser::parseSpecializeAttributeArguments(
712732
bool Parser::parseSpecializeAttribute(
713733
swift::tok ClosingBrace, SourceLoc AtLoc, SourceLoc Loc,
714734
SpecializeAttr *&Attr,
715-
llvm::function_ref<bool(Parser &)> parseSILTargetName) {
735+
llvm::function_ref<bool(Parser &)> parseSILTargetName,
736+
llvm::function_ref<bool(Parser &)> parseSILSIPModule) {
716737
assert(ClosingBrace == tok::r_paren || ClosingBrace == tok::r_square);
717738

718739
SourceLoc lParenLoc = consumeToken();
@@ -725,9 +746,10 @@ bool Parser::parseSpecializeAttribute(
725746
TrailingWhereClause *trailingWhereClause = nullptr;
726747

727748
DeclNameRef targetFunction;
728-
if (!parseSpecializeAttributeArguments(ClosingBrace, DiscardAttribute,
729-
exported, kind, trailingWhereClause,
730-
targetFunction, parseSILTargetName)) {
749+
SmallVector<Identifier, 4> spiGroups;
750+
if (!parseSpecializeAttributeArguments(
751+
ClosingBrace, DiscardAttribute, exported, kind, trailingWhereClause,
752+
targetFunction, spiGroups, parseSILTargetName, parseSILSIPModule)) {
731753
return false;
732754
}
733755

@@ -756,7 +778,7 @@ bool Parser::parseSpecializeAttribute(
756778
// Store the attribute.
757779
Attr = SpecializeAttr::create(Context, AtLoc, SourceRange(Loc, rParenLoc),
758780
trailingWhereClause, exported.getValue(),
759-
kind.getValue(), targetFunction);
781+
kind.getValue(), targetFunction, spiGroups);
760782
return true;
761783
}
762784

0 commit comments

Comments
 (0)