Skip to content

Commit a9c9f26

Browse files
authored
Merge pull request swiftlang#27545 from apple/transposing-attr
[AutoDiff upstream] Add `@transpose` attribute.
2 parents 899cc20 + fa31c75 commit a9c9f26

19 files changed

+823
-85
lines changed

include/swift/AST/Attr.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,11 @@ DECL_ATTR(_implicitly_synthesizes_nested_requirement, ImplicitlySynthesizesNeste
540540
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
541541
98)
542542

543+
DECL_ATTR(transpose, Transpose,
544+
OnFunc | LongAttribute | AllowMultipleAttributes |
545+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
546+
99)
547+
543548
#undef TYPE_ATTR
544549
#undef DECL_ATTR_ALIAS
545550
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/Attr.h

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,19 +1767,19 @@ class DifferentiableAttr final
17671767
}
17681768
};
17691769

1770-
/// The `@derivative` attribute registers a function as a derivative of another
1771-
/// function-like declaration: a 'func', 'init', 'subscript', or 'var' computed
1772-
/// property declaration.
1770+
/// The `@derivative(of:)` attribute registers a function as a derivative of
1771+
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
1772+
/// computed property declaration.
17731773
///
1774-
/// The `@derivative` attribute also has an optional `wrt:` clause specifying
1775-
/// the parameters that are differentiated "with respect to", i.e. the
1776-
/// differentiation parameters. The differentiation parameters must conform to
1777-
/// the `Differentiable` protocol.
1774+
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
1775+
/// specifying the parameters that are differentiated "with respect to", i.e.
1776+
/// the differentiation parameters. The differentiation parameters must conform
1777+
/// to the `Differentiable` protocol.
17781778
///
17791779
/// If the `wrt:` clause is unspecified, the differentiation parameters are
17801780
/// inferred to be all parameters that conform to `Differentiable`.
17811781
///
1782-
/// `@derivative` attribute type-checking verifies that the type of the
1782+
/// `@derivative(of:)` attribute type-checking verifies that the type of the
17831783
/// derivative function declaration is consistent with the type of the
17841784
/// referenced original declaration and the differentiation parameters.
17851785
///
@@ -1860,6 +1860,89 @@ class DerivativeAttr final
18601860
}
18611861
};
18621862

1863+
/// The `@transpose(of:)` attribute registers a function as a transpose of
1864+
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
1865+
/// computed property declaration.
1866+
///
1867+
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
1868+
/// parameters that are transposed "with respect to", i.e. the transposed
1869+
/// parameters.
1870+
///
1871+
/// Examples:
1872+
/// @transpose(of: foo)
1873+
/// @transpose(of: +, wrt: (0, 1))
1874+
class TransposeAttr final
1875+
: public DeclAttribute,
1876+
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
1877+
friend TrailingObjects;
1878+
1879+
/// The base type of the original function.
1880+
/// This is non-null only when the original function is not top-level (i.e. it
1881+
/// is an instance/static method).
1882+
TypeRepr *BaseTypeRepr;
1883+
/// The original function name.
1884+
DeclNameRefWithLoc OriginalFunctionName;
1885+
/// The original function declaration, resolved by the type checker.
1886+
AbstractFunctionDecl *OriginalFunction = nullptr;
1887+
/// The number of parsed parameters specified in 'wrt:'.
1888+
unsigned NumParsedParameters = 0;
1889+
/// The transposed parameters' indices, resolved by the type checker.
1890+
IndexSubset *ParameterIndices = nullptr;
1891+
1892+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1893+
TypeRepr *baseType, DeclNameRefWithLoc original,
1894+
ArrayRef<ParsedAutoDiffParameter> params);
1895+
1896+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1897+
TypeRepr *baseType, DeclNameRefWithLoc original,
1898+
IndexSubset *indices);
1899+
1900+
public:
1901+
static TransposeAttr *create(ASTContext &context, bool implicit,
1902+
SourceLoc atLoc, SourceRange baseRange,
1903+
TypeRepr *baseType, DeclNameRefWithLoc original,
1904+
ArrayRef<ParsedAutoDiffParameter> params);
1905+
1906+
static TransposeAttr *create(ASTContext &context, bool implicit,
1907+
SourceLoc atLoc, SourceRange baseRange,
1908+
TypeRepr *baseType, DeclNameRefWithLoc original,
1909+
IndexSubset *indices);
1910+
1911+
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
1912+
DeclNameRefWithLoc getOriginalFunctionName() const {
1913+
return OriginalFunctionName;
1914+
}
1915+
AbstractFunctionDecl *getOriginalFunction() const {
1916+
return OriginalFunction;
1917+
}
1918+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1919+
OriginalFunction = decl;
1920+
}
1921+
1922+
/// The parsed transposed parameters, i.e. the list of parameters specified in
1923+
/// 'wrt:'.
1924+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1925+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1926+
}
1927+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1928+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1929+
}
1930+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1931+
return NumParsedParameters;
1932+
}
1933+
1934+
IndexSubset *getParameterIndices() const {
1935+
return ParameterIndices;
1936+
}
1937+
void setParameterIndices(IndexSubset *parameterIndices) {
1938+
ParameterIndices = parameterIndices;
1939+
}
1940+
1941+
static bool classof(const DeclAttribute *DA) {
1942+
return DA->getKind() == DAK_Transpose;
1943+
}
1944+
};
1945+
18631946
/// Attributes that may be applied to declarations.
18641947
class DeclAttributes {
18651948
/// Linked list of declaration attributes.

include/swift/AST/DiagnosticsParse.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,9 +1560,12 @@ ERROR(expected_colon_after_label,PointsToFirstBadToken,
15601560
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15611561
"expected a parameter, which can be a function parameter name, "
15621562
"parameter index, or 'self'", ())
1563+
ERROR(diff_params_clause_expected_parameter_unnamed,PointsToFirstBadToken,
1564+
"expected a parameter, which can be a function parameter index or 'self'",
1565+
())
15631566

1564-
// derivative
1565-
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
1567+
// Automatic differentiation attributes
1568+
ERROR(autodiff_attr_expected_original_decl_name,PointsToFirstBadToken,
15661569
"expected an original function name", ())
15671570

15681571
//------------------------------------------------------------------------------

include/swift/Parse/Parser.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -994,14 +994,22 @@ class Parser {
994994
Optional<DeclNameRefWithLoc> &vjpSpec,
995995
TrailingWhereClause *&whereClause);
996996

997-
/// Parse a differentiation parameters clause.
997+
/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
998+
/// `@differentiable` and `@derivative` attributes.
999+
/// If `allowNamedParameters` is false, allow only index parameters and
1000+
/// 'self'.
9981001
bool parseDifferentiationParametersClause(
999-
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1002+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
1003+
bool allowNamedParameters = true);
10001004

10011005
/// Parse the @derivative attribute.
10021006
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
10031007
SourceLoc Loc);
10041008

1009+
/// Parse the @transpose attribute.
1010+
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
1011+
SourceLoc Loc);
1012+
10051013
/// Parse a specific attribute.
10061014
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);
10071015

@@ -1143,7 +1151,19 @@ class Parser {
11431151
SourceLoc &LAngleLoc,
11441152
SourceLoc &RAngleLoc);
11451153

1146-
ParserResult<TypeRepr> parseTypeIdentifier();
1154+
/// Parses a type identifier (e.g. 'Foo' or 'Foo.Bar.Baz').
1155+
///
1156+
/// When `isParsingQualifiedDeclBaseType` is true:
1157+
/// - Parses and returns the base type for a qualified declaration name,
1158+
/// positioning the parser at the '.' before the final declaration name.
1159+
// This position is important for parsing final declaration names like
1160+
// '.init' via `parseUnqualifiedDeclName`.
1161+
/// - For example, 'Foo.Bar.f' parses as 'Foo.Bar' and the parser is
1162+
/// positioned at '.f'.
1163+
/// - If there is no base type qualifier (e.g. when parsing just 'f'), returns
1164+
/// an empty parser error.
1165+
ParserResult<TypeRepr> parseTypeIdentifier(
1166+
bool isParsingQualifiedDeclBaseType = false);
11471167
ParserResult<TypeRepr> parseOldStyleProtocolComposition();
11481168
ParserResult<TypeRepr> parseAnyType();
11491169
ParserResult<TypeRepr> parseSILBoxType(GenericParamList *generics,
@@ -1357,6 +1377,14 @@ class Parser {
13571377
bool canParseAsGenericArgumentList();
13581378

13591379
bool canParseType();
1380+
1381+
/// Returns true if a simple type identifier can be parsed.
1382+
///
1383+
/// \verbatim
1384+
/// simple-type-identifier: identifier generic-argument-list?
1385+
/// \endverbatim
1386+
bool canParseSimpleTypeIdentifier();
1387+
13601388
bool canParseTypeIdentifier();
13611389
bool canParseTypeIdentifierOrTypeComposition();
13621390
bool canParseOldStyleProtocolComposition();
@@ -1366,6 +1394,13 @@ class Parser {
13661394

13671395
bool canParseTypedPattern();
13681396

1397+
/// Returns true if a qualified declaration name base type can be parsed.
1398+
///
1399+
/// \verbatim
1400+
/// qualified-decl-name-base-type: simple-type-identifier '.'
1401+
/// \endverbatim
1402+
bool canParseBaseTypeForQualifiedDeclName();
1403+
13691404
//===--------------------------------------------------------------------===//
13701405
// Expression Parsing
13711406
ParserResult<Expr> parseExpr(Diag<> ID) {

lib/AST/Attr.cpp

Lines changed: 90 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,25 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
366366
Printer.printNewline();
367367
}
368368

369-
// Returns the differentiation parameters clause string for the given function,
370-
// parameter indices, and parsed parameters.
369+
/// Printing style for a differentiation parameter in a `wrt:` differentiation
370+
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
371+
/// `@transpose` attributes.
372+
enum class DifferentiationParameterPrintingStyle {
373+
/// Print parameter by name.
374+
/// Used for `@differentiable` and `@derivative` attribute.
375+
Name,
376+
/// Print parameter by index.
377+
/// Used for `@transpose` attribute.
378+
Index
379+
};
380+
381+
/// Returns the differentiation parameters clause string for the given function,
382+
/// parameter indices, parsed parameters, . Use the parameter indices if
383+
/// specified; otherwise, use the parsed parameters.
371384
static std::string getDifferentiationParametersClauseString(
372385
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
373-
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
386+
ArrayRef<ParsedAutoDiffParameter> parsedParams,
387+
DifferentiationParameterPrintingStyle style) {
374388
assert(function);
375389
bool isInstanceMethod = function->isInstanceMember();
376390
std::string result;
@@ -392,7 +406,14 @@ static std::string getDifferentiationParametersClauseString(
392406
}
393407
// Print remaining differentiation parameters.
394408
interleave(parameters.set_bits(), [&](unsigned index) {
395-
printer << function->getParameters()->get(index)->getName().str();
409+
switch (style) {
410+
case DifferentiationParameterPrintingStyle::Name:
411+
printer << function->getParameters()->get(index)->getName().str();
412+
break;
413+
case DifferentiationParameterPrintingStyle::Index:
414+
printer << index;
415+
break;
416+
}
396417
}, [&] { printer << ", "; });
397418
if (parameterCount > 1)
398419
printer << ')';
@@ -425,11 +446,11 @@ static std::string getDifferentiationParametersClauseString(
425446
return printer.str();
426447
}
427448

428-
// Print the arguments of the given `@differentiable` attribute.
429-
// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
430-
// parameters clause.
431-
// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
432-
// functions.
449+
/// Print the arguments of the given `@differentiable` attribute.
450+
/// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
451+
/// parameters clause.
452+
/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
453+
/// functions.
433454
static void printDifferentiableAttrArguments(
434455
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
435456
const Decl *D, bool omitWrtClause = false,
@@ -465,7 +486,8 @@ static void printDifferentiableAttrArguments(
465486
// Print differentiation parameters clause, unless it is to be omitted.
466487
if (!omitWrtClause) {
467488
auto diffParamsString = getDifferentiationParametersClauseString(
468-
original, attr->getParameterIndices(), attr->getParsedParameters());
489+
original, attr->getParameterIndices(), attr->getParsedParameters(),
490+
DifferentiationParameterPrintingStyle::Name);
469491
// Check whether differentiation parameter clause is empty.
470492
// Handles edge case where resolved parameter indices are unset and
471493
// parsed parameters are empty. This case should never trigger for
@@ -904,13 +926,29 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
904926
Printer << attr->getOriginalFunctionName().Name;
905927
auto *derivative = cast<AbstractFunctionDecl>(D);
906928
auto diffParamsString = getDifferentiationParametersClauseString(
907-
derivative, attr->getParameterIndices(), attr->getParsedParameters());
929+
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
930+
DifferentiationParameterPrintingStyle::Name);
908931
if (!diffParamsString.empty())
909932
Printer << ", " << diffParamsString;
910933
Printer << ')';
911934
break;
912935
}
913936

937+
case DAK_Transpose: {
938+
Printer.printAttrName("@transpose");
939+
Printer << "(of: ";
940+
auto *attr = cast<TransposeAttr>(this);
941+
Printer << attr->getOriginalFunctionName().Name;
942+
auto *transpose = cast<AbstractFunctionDecl>(D);
943+
auto transParamsString = getDifferentiationParametersClauseString(
944+
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
945+
DifferentiationParameterPrintingStyle::Index);
946+
if (!transParamsString.empty())
947+
Printer << ", " << transParamsString;
948+
Printer << ')';
949+
break;
950+
}
951+
914952
case DAK_ImplicitlySynthesizesNestedRequirement:
915953
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
916954
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";
@@ -1054,6 +1092,8 @@ StringRef DeclAttribute::getAttrName() const {
10541092
return "differentiable";
10551093
case DAK_Derivative:
10561094
return "derivative";
1095+
case DAK_Transpose:
1096+
return "transpose";
10571097
}
10581098
llvm_unreachable("bad DeclAttrKind");
10591099
}
@@ -1515,6 +1555,45 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
15151555
std::move(originalName), indices);
15161556
}
15171557

1558+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1559+
SourceRange baseRange, TypeRepr *baseTypeRepr,
1560+
DeclNameRefWithLoc originalName,
1561+
ArrayRef<ParsedAutoDiffParameter> params)
1562+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1563+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1564+
NumParsedParameters(params.size()) {
1565+
std::uninitialized_copy(params.begin(), params.end(),
1566+
getTrailingObjects<ParsedAutoDiffParameter>());
1567+
}
1568+
1569+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1570+
SourceRange baseRange, TypeRepr *baseTypeRepr,
1571+
DeclNameRefWithLoc originalName, IndexSubset *indices)
1572+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1573+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1574+
ParameterIndices(indices) {}
1575+
1576+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1577+
SourceLoc atLoc, SourceRange baseRange,
1578+
TypeRepr *baseType,
1579+
DeclNameRefWithLoc originalName,
1580+
ArrayRef<ParsedAutoDiffParameter> params) {
1581+
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1582+
void *mem = context.Allocate(size, alignof(TransposeAttr));
1583+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1584+
std::move(originalName), params);
1585+
}
1586+
1587+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1588+
SourceLoc atLoc, SourceRange baseRange,
1589+
TypeRepr *baseType,
1590+
DeclNameRefWithLoc originalName,
1591+
IndexSubset *indices) {
1592+
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
1593+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1594+
std::move(originalName), indices);
1595+
}
1596+
15181597
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
15191598
TypeLoc ProtocolType,
15201599
DeclName MemberName,

0 commit comments

Comments
 (0)