Skip to content

Commit 5accda2

Browse files
bgogulrxwei
authored andcommitted
[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable. (swiftlang#27506)
This PR introduces `@differentiable` attribute to mark functions as differentiable. This PR only contains changes related to parsing the attribute. Type checking and other changes will be added in subsequent patches. See https://github.com/apple/swift/pull/27506/files#diff-f3216f4188fd5ed34e1007e5a9c2490f for examples and tests for the new attribute.
1 parent 7ddface commit 5accda2

File tree

16 files changed

+1196
-2
lines changed

16 files changed

+1196
-2
lines changed

include/swift/AST/Attr.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,12 @@ SIMPLE_DECL_ATTR(_nonEphemeral, NonEphemeral,
502502
ABIStableToAdd | ABIStableToRemove | APIBreakingToAdd | APIStableToRemove,
503503
90)
504504

505+
DECL_ATTR(differentiable, Differentiable,
506+
OnAccessor | OnConstructor | OnFunc | OnVar | OnSubscript | LongAttribute |
507+
AllowMultipleAttributes |
508+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
509+
91)
510+
505511
SIMPLE_DECL_ATTR(IBSegueAction, IBSegueAction,
506512
OnFunc |
507513
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,

include/swift/AST/Attr.h

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/Basic/Version.h"
2828
#include "swift/AST/Identifier.h"
2929
#include "swift/AST/AttrKind.h"
30+
#include "swift/AST/AutoDiff.h"
3031
#include "swift/AST/ConcreteDeclRef.h"
3132
#include "swift/AST/DeclNameLoc.h"
3233
#include "swift/AST/KnownProtocols.h"
@@ -1724,6 +1725,148 @@ class DeclAttributes {
17241725
SourceLoc getStartLoc(bool forModifiers = false) const;
17251726
};
17261727

1728+
/// A declaration name with location.
1729+
struct DeclNameWithLoc {
1730+
DeclName Name;
1731+
DeclNameLoc Loc;
1732+
};
1733+
1734+
/// Attribute that marks a function as differentiable and optionally specifies
1735+
/// custom associated derivative functions: 'jvp' and 'vjp'.
1736+
///
1737+
/// Examples:
1738+
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1739+
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1740+
class DifferentiableAttr final
1741+
: public DeclAttribute,
1742+
private llvm::TrailingObjects<DifferentiableAttr,
1743+
ParsedAutoDiffParameter> {
1744+
friend TrailingObjects;
1745+
1746+
/// Whether this function is linear (optional).
1747+
bool linear;
1748+
/// The number of parsed parameters specified in 'wrt:'.
1749+
unsigned NumParsedParameters = 0;
1750+
/// The JVP function.
1751+
Optional<DeclNameWithLoc> JVP;
1752+
/// The VJP function.
1753+
Optional<DeclNameWithLoc> VJP;
1754+
/// The JVP function (optional), resolved by the type checker if JVP name is
1755+
/// specified.
1756+
FuncDecl *JVPFunction = nullptr;
1757+
/// The VJP function (optional), resolved by the type checker if VJP name is
1758+
/// specified.
1759+
FuncDecl *VJPFunction = nullptr;
1760+
/// The differentiation parameters' indices, resolved by the type checker.
1761+
IndexSubset *ParameterIndices = nullptr;
1762+
/// The trailing where clause (optional).
1763+
TrailingWhereClause *WhereClause = nullptr;
1764+
/// The generic signature for autodiff associated functions. Resolved by the
1765+
/// type checker based on the original function's generic signature and the
1766+
/// attribute's where clause requirements. This is set only if the attribute
1767+
/// has a where clause.
1768+
GenericSignature DerivativeGenericSignature;
1769+
1770+
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1771+
SourceLoc atLoc, SourceRange baseRange,
1772+
bool linear,
1773+
ArrayRef<ParsedAutoDiffParameter> parameters,
1774+
Optional<DeclNameWithLoc> jvp,
1775+
Optional<DeclNameWithLoc> vjp,
1776+
TrailingWhereClause *clause);
1777+
1778+
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1779+
SourceLoc atLoc, SourceRange baseRange,
1780+
bool linear, IndexSubset *indices,
1781+
Optional<DeclNameWithLoc> jvp,
1782+
Optional<DeclNameWithLoc> vjp,
1783+
GenericSignature derivativeGenericSignature);
1784+
1785+
public:
1786+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1787+
SourceLoc atLoc, SourceRange baseRange,
1788+
bool linear,
1789+
ArrayRef<ParsedAutoDiffParameter> params,
1790+
Optional<DeclNameWithLoc> jvp,
1791+
Optional<DeclNameWithLoc> vjp,
1792+
TrailingWhereClause *clause);
1793+
1794+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1795+
SourceLoc atLoc, SourceRange baseRange,
1796+
bool linear, IndexSubset *indices,
1797+
Optional<DeclNameWithLoc> jvp,
1798+
Optional<DeclNameWithLoc> vjp,
1799+
GenericSignature derivativeGenSig);
1800+
1801+
/// Get the optional 'jvp:' function name and location.
1802+
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1803+
/// registered JVP.
1804+
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1805+
1806+
/// Get the optional 'vjp:' function name and location.
1807+
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1808+
/// registered VJP.
1809+
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1810+
1811+
IndexSubset *getParameterIndices() const {
1812+
return ParameterIndices;
1813+
}
1814+
void setParameterIndices(IndexSubset *pi) {
1815+
ParameterIndices = pi;
1816+
}
1817+
1818+
/// The parsed differentiation parameters, i.e. the list of parameters
1819+
/// specified in 'wrt:'.
1820+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1821+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1822+
}
1823+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1824+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1825+
}
1826+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1827+
return NumParsedParameters;
1828+
}
1829+
1830+
bool isLinear() const { return linear; }
1831+
1832+
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1833+
1834+
GenericSignature getDerivativeGenericSignature() const {
1835+
return DerivativeGenericSignature;
1836+
}
1837+
void setDerivativeGenericSignature(ASTContext &context,
1838+
GenericSignature derivativeGenSig) {
1839+
DerivativeGenericSignature = derivativeGenSig;
1840+
}
1841+
1842+
FuncDecl *getJVPFunction() const { return JVPFunction; }
1843+
void setJVPFunction(FuncDecl *decl);
1844+
FuncDecl *getVJPFunction() const { return VJPFunction; }
1845+
void setVJPFunction(FuncDecl *decl);
1846+
1847+
bool parametersMatch(const DifferentiableAttr &other) const {
1848+
assert(ParameterIndices && other.ParameterIndices);
1849+
return ParameterIndices == other.ParameterIndices;
1850+
}
1851+
1852+
/// Get the derivative generic environment for the given `@differentiable`
1853+
/// attribute and original function.
1854+
GenericEnvironment *
1855+
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1856+
1857+
// Print the attribute to the given stream.
1858+
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1859+
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1860+
void print(llvm::raw_ostream &OS, const Decl *D,
1861+
bool omitWrtClause = false,
1862+
bool omitAssociatedFunctions = false) const;
1863+
1864+
static bool classof(const DeclAttribute *DA) {
1865+
return DA->getKind() == DAK_Differentiable;
1866+
}
1867+
};
1868+
1869+
17271870
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
17281871

17291872
inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {

include/swift/AST/AutoDiff.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===--- AutoDiff.h - Swift Automatic Differentiation ---------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// SWIFT_ENABLE_TENSORFLOW
14+
// This file defines AST support for automatic differentiation.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef SWIFT_AST_AUTODIFF_H
19+
#define SWIFT_AST_AUTODIFF_H
20+
21+
#include "swift/AST/IndexSubset.h"
22+
#include "swift/Basic/Range.h"
23+
24+
namespace swift {
25+
26+
class ParsedAutoDiffParameter {
27+
public:
28+
enum class Kind { Named, Ordered, Self };
29+
30+
private:
31+
SourceLoc loc;
32+
Kind kind;
33+
union Value {
34+
struct { Identifier name; } Named;
35+
struct { unsigned index; } Ordered;
36+
struct {} self;
37+
Value(Identifier name) : Named({name}) {}
38+
Value(unsigned index) : Ordered({index}) {}
39+
Value() {}
40+
} value;
41+
42+
public:
43+
ParsedAutoDiffParameter(SourceLoc loc, Kind kind, Value value)
44+
: loc(loc), kind(kind), value(value) {}
45+
46+
ParsedAutoDiffParameter(SourceLoc loc, Kind kind, unsigned index)
47+
: loc(loc), kind(kind), value(index) {}
48+
49+
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
50+
Identifier name) {
51+
return { loc, Kind::Named, name };
52+
}
53+
54+
static ParsedAutoDiffParameter getOrderedParameter(SourceLoc loc,
55+
unsigned index) {
56+
return { loc, Kind::Ordered, index };
57+
}
58+
59+
static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {
60+
return { loc, Kind::Self, {} };
61+
}
62+
63+
Identifier getName() const {
64+
assert(kind == Kind::Named);
65+
return value.Named.name;
66+
}
67+
68+
unsigned getIndex() const {
69+
return value.Ordered.index;
70+
}
71+
72+
Kind getKind() const {
73+
return kind;
74+
}
75+
76+
SourceLoc getLoc() const {
77+
return loc;
78+
}
79+
80+
bool isEqual(const ParsedAutoDiffParameter &other) const {
81+
if (getKind() != other.getKind())
82+
return false;
83+
if (getKind() == Kind::Named)
84+
return getName() == other.getName();
85+
return getKind() == Kind::Self;
86+
}
87+
};
88+
89+
} // end namespace swift
90+
91+
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/DiagnosticsParse.def

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,11 +1511,29 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
15111511
"expected a member name as second parameter in '_implements' attribute", ())
15121512

15131513
// differentiable
1514+
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
1515+
"expected a %0 function name", (StringRef))
1516+
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
1517+
"expected a list of parameters to differentiate with respect to", ())
1518+
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
1519+
"use 'wrt:' to specify parameters to differentiate with respect to", ())
1520+
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
1521+
"missing label '%0:' in '@differentiable' attribute", (StringRef))
1522+
ERROR(attr_differentiable_expected_label,none,
1523+
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
1524+
"or 'vjp:'", ())
15141525
ERROR(differentiable_attribute_expected_rparen,none,
15151526
"expected ')' in '@differentiable' attribute", ())
15161527
ERROR(unexpected_argument_differentiable,none,
15171528
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
15181529

1530+
// differentiation `wrt` parameters clause
1531+
ERROR(expected_colon_after_label,PointsToFirstBadToken,
1532+
"expected a colon ':' after '%0'", (StringRef))
1533+
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
1534+
"expected a parameter, which can be a function parameter name, "
1535+
"parameter index, or 'self'", ())
1536+
15191537
//------------------------------------------------------------------------------
15201538
// MARK: Generics parsing diagnostics
15211539
//------------------------------------------------------------------------------

include/swift/Parse/Parser.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,20 @@ class Parser {
711711
/// Check whether the current token starts with '>'.
712712
bool startsWithGreater(Token Tok) { return startsWithSymbol(Tok, '>'); }
713713

714+
/// Returns true if token is an identifier with the given value.
715+
bool isIdentifier(Token Tok, StringRef value) {
716+
return Tok.is(tok::identifier) && Tok.getText() == value;
717+
}
718+
719+
/// Returns true if token is the identifier "wrt".
720+
bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }
721+
722+
/// Returns true if token is the identifier "jvp".
723+
bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }
724+
725+
/// Returns true if token is the identifier "vjp".
726+
bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }
727+
714728
/// Consume the starting '<' of the current token, which may either
715729
/// be a complete '<' token or some kind of operator token starting with '<',
716730
/// e.g., '<>'.
@@ -796,6 +810,12 @@ class Parser {
796810
return parseAnyIdentifier(Result, L, Diagnostic(ID, Args...));
797811
}
798812

813+
/// \brief Parse an unsigned integer and returns it in \p Result. On failure
814+
/// emit the specified error diagnostic, and a note at the specified note
815+
/// location.
816+
bool parseUnsignedInteger(unsigned &Result, SourceLoc &Loc,
817+
const Diagnostic &D);
818+
799819
/// The parser expects that \p K is next token in the input. If so,
800820
/// it is consumed and false is returned.
801821
///
@@ -973,6 +993,20 @@ class Parser {
973993
ParserResult<ImplementsAttr> parseImplementsAttribute(SourceLoc AtLoc,
974994
SourceLoc Loc);
975995

996+
/// Parse the @differentiable attribute.
997+
ParserResult<DifferentiableAttr> parseDifferentiableAttribute(SourceLoc AtLoc,
998+
SourceLoc Loc);
999+
1000+
/// Parse the arguments inside the @differentiable attribute.
1001+
bool parseDifferentiableAttributeArguments(
1002+
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
1003+
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
1004+
TrailingWhereClause *&whereClause);
1005+
1006+
/// Parse a differentiation parameters clause.
1007+
bool parseDifferentiationParametersClause(
1008+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1009+
9761010
/// Parse a specific attribute.
9771011
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);
9781012

0 commit comments

Comments
 (0)