Skip to content

Commit 18fe723

Browse files
authored
Merge pull request swiftlang#35811 from rxwei/69980056-differentiable-reverse
[AutoDiff] Add '@differentiable(reverse)'.
2 parents d3db32d + af8942d commit 18fe723

File tree

149 files changed

+1673
-1535
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+1673
-1535
lines changed

docs/ABI/Mangling.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,8 @@ Types
549549
FUNCTION-KIND ::= 'E' // function type (noescape)
550550
FUNCTION-KIND ::= 'F' // @differentiable function type
551551
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
552-
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
553-
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
552+
FUNCTION-KIND ::= 'H' // @differentiable(_linear) function type
553+
FUNCTION-KIND ::= 'I' // @differentiable(_linear) function type (escaping)
554554

555555
C-TYPE is mangled according to the Itanium ABI, and prefixed with the length.
556556
Non-ASCII identifiers are preserved as-is; we do not use Punycode.
@@ -632,9 +632,10 @@ mangled in to disambiguate.
632632

633633
CALLEE-ESCAPE ::= 'e' // @escaping (inverse of SIL @noescape)
634634

635-
DIFFERENTIABILITY-KIND ::= DIFFERENTIABLE | LINEAR
636-
DIFFERENTIABLE ::= 'd' // @differentiable
637-
LINEAR ::= 'l' // @differentiable(linear)
635+
DIFFERENTIABILITY-KIND ::= 'd' // @differentiable
636+
DIFFERENTIABILITY-KIND ::= 'l' // @differentiable(_linear)
637+
DIFFERENTIABILITY-KIND ::= 'f' // @differentiable(_forward)
638+
DIFFERENTIABILITY-KIND ::= 'r' // @differentiable(reverse)
638639

639640
CALLEE-CONVENTION ::= 'y' // @callee_unowned
640641
CALLEE-CONVENTION ::= 'g' // @callee_guaranteed

docs/SIL.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7002,7 +7002,7 @@ linear_function
70027002
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T
70037003

70047004
Bundles a function with its transpose function into a
7005-
``@differentiable(linear)`` function.
7005+
``@differentiable(_linear)`` function.
70067006

70077007
``[parameters ...]`` specifies parameter indices that the original function is
70087008
linear with respect to.
@@ -7051,11 +7051,11 @@ linear_function_extract
70517051

70527052
sil-linear-function-extractee ::= 'original' | 'transpose'
70537053

7054-
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
7055-
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
7054+
linear_function_extract [original] %0 : $@differentiable(_linear) (T) -> T
7055+
linear_function_extract [transpose] %0 : $@differentiable(_linear) (T) -> T
70567056

70577057
Extracts the original function or a transpose function from the given
7058-
``@differentiable(linear)`` function. The extractee is one of the following:
7058+
``@differentiable(_linear)`` function. The extractee is one of the following:
70597059
``[original]`` or ``[transpose]``.
70607060

70617061

include/swift/ABI/MetadataValues.h

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -756,11 +756,13 @@ enum class FunctionMetadataConvention: uint8_t {
756756
};
757757

758758
/// Differentiability kind for function type metadata.
759-
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
759+
/// Duplicates `DifferentiabilityKind` in AST/AutoDiff.h.
760760
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
761-
NonDifferentiable = 0b00,
762-
Normal = 0b01,
763-
Linear = 0b11
761+
NonDifferentiable = 0b00000,
762+
Forward = 0b00001,
763+
Reverse = 0b00010,
764+
Normal = 0b00011,
765+
Linear = 0b10000,
764766
};
765767

766768
/// Flags in a function type metadata record.
@@ -770,16 +772,16 @@ class TargetFunctionTypeFlags {
770772
// one of the flag bits could be used to identify that the rest of
771773
// the flags is going to be stored somewhere else in the metadata.
772774
enum : int_type {
773-
NumParametersMask = 0x0000FFFFU,
774-
ConventionMask = 0x00FF0000U,
775-
ConventionShift = 16U,
776-
ThrowsMask = 0x01000000U,
777-
ParamFlagsMask = 0x02000000U,
778-
EscapingMask = 0x04000000U,
779-
DifferentiableMask = 0x08000000U,
780-
LinearMask = 0x10000000U,
781-
AsyncMask = 0x20000000U,
782-
ConcurrentMask = 0x40000000U,
775+
NumParametersMask = 0x0000FFFFU,
776+
ConventionMask = 0x00FF0000U,
777+
ConventionShift = 16U,
778+
ThrowsMask = 0x01000000U,
779+
ParamFlagsMask = 0x02000000U,
780+
EscapingMask = 0x04000000U,
781+
DifferentiabilityMask = 0x98000000U,
782+
DifferentiabilityShift = 27U,
783+
AsyncMask = 0x20000000U,
784+
ConcurrentMask = 0x40000000U,
783785
};
784786
int_type Data;
785787

@@ -811,13 +813,9 @@ class TargetFunctionTypeFlags {
811813
}
812814

813815
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
814-
FunctionMetadataDifferentiabilityKind differentiability) const {
815-
return TargetFunctionTypeFlags<int_type>(
816-
(Data & ~DifferentiableMask & ~LinearMask) |
817-
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
818-
? DifferentiableMask : 0) |
819-
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
820-
? LinearMask : 0));
816+
FunctionMetadataDifferentiabilityKind differentiabilityKind) const {
817+
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
818+
| (int_type(differentiabilityKind) << DifferentiabilityShift));
821819
}
822820

823821
constexpr TargetFunctionTypeFlags<int_type>
@@ -860,16 +858,13 @@ class TargetFunctionTypeFlags {
860858
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
861859

862860
bool isDifferentiable() const {
863-
return getDifferentiabilityKind() >=
864-
FunctionMetadataDifferentiabilityKind::Normal;
861+
return getDifferentiabilityKind() !=
862+
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
865863
}
866864

867865
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
868-
if (bool(Data & DifferentiableMask))
869-
return FunctionMetadataDifferentiabilityKind::Normal;
870-
if (bool(Data & LinearMask))
871-
return FunctionMetadataDifferentiabilityKind::Linear;
872-
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
866+
return FunctionMetadataDifferentiabilityKind(
867+
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
873868
}
874869

875870
int_type getIntValue() const {

include/swift/AST/Attr.h

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ class TypeAttributes {
8585

8686
// Indicates whether the type's '@differentiable' attribute has a 'linear'
8787
// argument.
88-
bool linear = false;
88+
DifferentiabilityKind differentiabilityKind =
89+
DifferentiabilityKind::NonDifferentiable;
8990

9091
// For an opened existential type, the known ID.
9192
Optional<UUID> OpenedID;
@@ -102,14 +103,6 @@ class TypeAttributes {
102103

103104
bool isValid() const { return AtLoc.isValid(); }
104105

105-
bool isLinear() const {
106-
assert(
107-
!linear ||
108-
(linear && has(TAK_differentiable)) &&
109-
"Linear shouldn't have been true if there's no `@differentiable`");
110-
return linear;
111-
}
112-
113106
void clearAttribute(TypeAttrKind A) {
114107
AttrLocs[A] = SourceLoc();
115108
}
@@ -1790,8 +1783,9 @@ class OriginallyDefinedInAttr: public DeclAttribute {
17901783
/// Attribute that marks a function as differentiable.
17911784
///
17921785
/// Examples:
1793-
/// @differentiable(where T : FloatingPoint)
1794-
/// @differentiable(wrt: (self, x, y))
1786+
/// @differentiable(reverse)
1787+
/// @differentiable(reverse, wrt: (self, x, y))
1788+
/// @differentiable(reverse, wrt: (self, x, y) where T : FloatingPoint)
17951789
class DifferentiableAttr final
17961790
: public DeclAttribute,
17971791
private llvm::TrailingObjects<DifferentiableAttr,
@@ -1803,8 +1797,8 @@ class DifferentiableAttr final
18031797
/// May not be a valid declaration for `@differentiable` attributes.
18041798
/// Resolved during parsing and deserialization.
18051799
Decl *OriginalDeclaration = nullptr;
1806-
/// Whether this function is linear (optional).
1807-
bool Linear;
1800+
/// The differentiability kind.
1801+
DifferentiabilityKind DifferentiabilityKind;
18081802
/// The number of parsed differentiability parameters specified in 'wrt:'.
18091803
unsigned NumParsedParameters = 0;
18101804
/// The differentiability parameter indices, resolved by the type checker.
@@ -1830,25 +1824,28 @@ class DifferentiableAttr final
18301824
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
18311825

18321826
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
1833-
SourceRange baseRange, bool linear,
1827+
SourceRange baseRange,
1828+
enum DifferentiabilityKind diffKind,
18341829
ArrayRef<ParsedAutoDiffParameter> parameters,
18351830
TrailingWhereClause *clause);
18361831

18371832
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1838-
SourceRange baseRange, bool linear,
1833+
SourceRange baseRange,
1834+
enum DifferentiabilityKind diffKind,
18391835
IndexSubset *parameterIndices,
18401836
GenericSignature derivativeGenericSignature);
18411837

18421838
public:
18431839
static DifferentiableAttr *create(ASTContext &context, bool implicit,
18441840
SourceLoc atLoc, SourceRange baseRange,
1845-
bool linear,
1841+
enum DifferentiabilityKind diffKind,
18461842
ArrayRef<ParsedAutoDiffParameter> params,
18471843
TrailingWhereClause *clause);
18481844

18491845
static DifferentiableAttr *create(AbstractFunctionDecl *original,
18501846
bool implicit, SourceLoc atLoc,
1851-
SourceRange baseRange, bool linear,
1847+
SourceRange baseRange,
1848+
enum DifferentiabilityKind diffKind,
18521849
IndexSubset *parameterIndices,
18531850
GenericSignature derivativeGenSig);
18541851

@@ -1879,7 +1876,25 @@ class DifferentiableAttr final
18791876
return NumParsedParameters;
18801877
}
18811878

1882-
bool isLinear() const { return Linear; }
1879+
enum DifferentiabilityKind getDifferentiabilityKind() const {
1880+
return DifferentiabilityKind;
1881+
}
1882+
1883+
bool isNormalDifferentiability() const {
1884+
return DifferentiabilityKind == DifferentiabilityKind::Normal;
1885+
}
1886+
1887+
bool isLinearDifferentiability() const {
1888+
return DifferentiabilityKind == DifferentiabilityKind::Linear;
1889+
}
1890+
1891+
bool isForwardDifferentiability() const {
1892+
return DifferentiabilityKind == DifferentiabilityKind::Forward;
1893+
}
1894+
1895+
bool isReverseDifferentiability() const {
1896+
return DifferentiabilityKind == DifferentiabilityKind::Reverse;
1897+
}
18831898

18841899
TrailingWhereClause *getWhereClause() const { return WhereClause; }
18851900

include/swift/AST/AutoDiff.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,14 @@ class VarDecl;
4141
/// A function type differentiability kind.
4242
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0,
44-
Normal = 1,
45-
Linear = 2
44+
// '@differentiable(_forward)', rejected by parser.
45+
Forward = 1,
46+
// '@differentiable(reverse)', supported.
47+
Reverse = 2,
48+
// '@differentiable', unsupported.
49+
Normal = 3,
50+
// '@differentiable(_linear)', unsupported.
51+
Linear = 4,
4652
};
4753

4854
/// The kind of an linear map.
@@ -74,9 +80,15 @@ struct AutoDiffDerivativeFunctionKind {
7480
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
7581
explicit AutoDiffDerivativeFunctionKind(StringRef string);
7682
operator innerty() const { return rawValue; }
77-
AutoDiffLinearMapKind getLinearMapKind() {
83+
AutoDiffLinearMapKind getLinearMapKind() const {
7884
return (AutoDiffLinearMapKind::innerty)rawValue;
7985
}
86+
DifferentiabilityKind getMinimalDifferentiabilityKind() const {
87+
switch (rawValue) {
88+
case JVP: return DifferentiabilityKind::Forward;
89+
case VJP: return DifferentiabilityKind::Reverse;
90+
}
91+
}
8092
};
8193

8294
/// A component of a SIL `@differentiable` function-typed value.
@@ -98,7 +110,7 @@ struct NormalDifferentiableFunctionTypeComponent {
98110
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
99111
};
100112

101-
/// A component of a SIL `@differentiable(linear)` function-typed value.
113+
/// A component of a SIL `@differentiable(_linear)` function-typed value.
102114
struct LinearDifferentiableFunctionTypeComponent {
103115
enum innerty : unsigned {
104116
Original = 0,

include/swift/AST/DiagnosticsParse.def

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,14 +1605,19 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
16051605
"expected a member name as second parameter in '_implements' attribute", ())
16061606

16071607
// differentiable
1608+
WARNING(attr_differentiable_expected_reverse,PointsToFirstBadToken,
1609+
"'@differentiable' has been renamed to '@differentiable(reverse)' and "
1610+
"will be removed in the next release", ())
1611+
ERROR(attr_differentiable_kind_not_supported,PointsToFirstBadToken,
1612+
"unsupported differentiability kind '%0'; only 'reverse' is supported", (StringRef))
1613+
ERROR(attr_differentiable_unknown_kind,PointsToFirstBadToken,
1614+
"unknown differentiability kind '%0'; only 'reverse' is supported", (StringRef))
16081615
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
16091616
"expected a list of parameters to differentiate with respect to", ())
16101617
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
16111618
"use 'wrt:' to specify parameters to differentiate with respect to", ())
16121619
ERROR(attr_differentiable_expected_label,none,
16131620
"expected 'wrt:' or 'where' in '@differentiable' attribute", ())
1614-
ERROR(attr_differentiable_unexpected_argument,none,
1615-
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
16161621

16171622
// differentiation `wrt` parameters clause
16181623
ERROR(expected_colon_after_label,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSIL.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ ERROR(autodiff_differentiation_module_not_imported,none,
435435
"Automatic differentiation requires the '_Differentiation' module to be "
436436
"imported", ())
437437
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
438-
"conversion to '@differentiable(linear)' function type is not yet "
438+
"conversion to '@differentiable(_linear)' function type is not yet "
439439
"supported", ())
440440
ERROR(autodiff_function_not_differentiable_error,none,
441441
"function is not differentiable", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,14 +1282,11 @@ ERROR(c_function_pointer_from_method,none,
12821282
ERROR(c_function_pointer_from_generic_function,none,
12831283
"a C function pointer cannot be formed from a reference to a generic "
12841284
"function", ())
1285-
ERROR(unsupported_linear_to_differentiable_conversion,none,
1286-
"conversion from '@differentiable(linear)' to '@differentiable' is not "
1287-
"yet supported", ())
12881285
ERROR(invalid_autoclosure_forwarding,none,
12891286
"add () to forward @autoclosure parameter", ())
12901287
ERROR(invalid_differentiable_function_conversion_expr,none,
1291-
"a '@differentiable%select{|(linear)}0' function can only be formed from "
1292-
"a reference to a 'func' or 'init' or a literal closure", (bool))
1288+
"a '@differentiable' function can only be formed from "
1289+
"a reference to a 'func' or 'init' or a literal closure", ())
12931290
NOTE(invalid_differentiable_function_conversion_parameter,none,
12941291
"did you mean to take a '%0' closure?", (StringRef))
12951292
ERROR(invalid_autoclosure_pointer_conversion,none,
@@ -3109,9 +3106,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
31093106
(Identifier, Identifier))
31103107

31113108
// @differentiable
3112-
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
3113-
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
3114-
"attribute for transpose registration instead", ())
31153109
ERROR(differentiable_attr_overload_not_found,none,
31163110
"%0 does not have expected type %1", (DeclNameRef, Type))
31173111
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
@@ -4551,14 +4545,14 @@ ERROR(attr_only_on_parameters_of_differentiable,none,
45514545
ERROR(differentiable_function_type_invalid_parameter,none,
45524546
"parameter type '%0' does not conform to 'Differentiable'"
45534547
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4554-
"function type is '@differentiable%select{|(linear)}1'"
4548+
"function type is '@differentiable%select{|(_linear)}1'"
45554549
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
45564550
(StringRef, /*isLinear*/ bool,
45574551
/*hasValidDifferentiabilityParameter*/ bool))
45584552
ERROR(differentiable_function_type_invalid_result,none,
45594553
"result type '%0' does not conform to 'Differentiable'"
45604554
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4561-
"function type is '@differentiable%select{|(linear)}1'",
4555+
"function type is '@differentiable%select{|(_linear)}1'",
45624556
(StringRef, bool))
45634557
ERROR(differentiable_function_type_no_differentiability_parameters,
45644558
none,

include/swift/AST/ExtInfo.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class ASTExtInfoBuilder {
293293
// and NumMaskBits must be updated, and they must match.
294294
//
295295
// |representation|noEscape|concurrent|async|throws|differentiability|
296-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
296+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
297297
//
298298
enum : unsigned {
299299
RepresentationMask = 0xF << 0,
@@ -302,8 +302,8 @@ class ASTExtInfoBuilder {
302302
AsyncMask = 1 << 6,
303303
ThrowsMask = 1 << 7,
304304
DifferentiabilityMaskOffset = 8,
305-
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
306-
NumMaskBits = 10
305+
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
306+
NumMaskBits = 11
307307
};
308308

309309
unsigned bits; // Naturally sized for speed.
@@ -615,8 +615,8 @@ class SILExtInfoBuilder {
615615
// If bits are added or removed, then TypeBase::SILFunctionTypeBits
616616
// and NumMaskBits must be updated, and they must match.
617617

618-
// |representation|pseudogeneric| noescape | concurrent | async | differentiability|
619-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
618+
// |representation|pseudogeneric| noescape | concurrent | async |differentiability|
619+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
620620
//
621621
enum : unsigned {
622622
RepresentationMask = 0xF << 0,
@@ -625,8 +625,8 @@ class SILExtInfoBuilder {
625625
ConcurrentMask = 1 << 6,
626626
AsyncMask = 1 << 7,
627627
DifferentiabilityMaskOffset = 8,
628-
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
629-
NumMaskBits = 10
628+
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
629+
NumMaskBits = 11
630630
};
631631

632632
unsigned bits; // Naturally sized for speed.

0 commit comments

Comments
 (0)