Skip to content

Commit 27dad91

Browse files
authored
[AutoDiff upstream] Upstream @derivative attribute serialization. (swiftlang#28781)
Upstream `@derivative` attribute serialization/deserialization. Test all original declaration kinds and various `wrt:` parameter clauses. Resolves TF-837.
1 parent 6357412 commit 27dad91

File tree

4 files changed

+153
-52
lines changed

4 files changed

+153
-52
lines changed

lib/AST/Attr.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,20 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
897897
break;
898898
}
899899

900+
case DAK_Derivative: {
901+
Printer.printAttrName("@derivative");
902+
Printer << "(of: ";
903+
auto *attr = cast<DerivativeAttr>(this);
904+
Printer << attr->getOriginalFunctionName().Name;
905+
auto *derivative = cast<AbstractFunctionDecl>(D);
906+
auto diffParamsString = getDifferentiationParametersClauseString(
907+
derivative, attr->getParameterIndices(), attr->getParsedParameters());
908+
if (!diffParamsString.empty())
909+
Printer << ", " << diffParamsString;
910+
Printer << ')';
911+
break;
912+
}
913+
900914
case DAK_ImplicitlySynthesizesNestedRequirement:
901915
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
902916
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";

lib/Serialization/Deserialization.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,6 +2124,21 @@ getActualReadWriteImplKind(unsigned rawKind) {
21242124
return None;
21252125
}
21262126

2127+
/// Translate from the serialization DifferentiabilityKind enumerators, which
2128+
/// are guaranteed to be stable, to the AST ones.
2129+
static Optional<swift::AutoDiffDerivativeFunctionKind>
2130+
getActualAutoDiffDerivativeFunctionKind(uint8_t raw) {
2131+
switch (serialization::AutoDiffDerivativeFunctionKind(raw)) {
2132+
#define CASE(ID) \
2133+
case serialization::AutoDiffDerivativeFunctionKind::ID: \
2134+
return {swift::AutoDiffDerivativeFunctionKind::ID};
2135+
CASE(JVP)
2136+
CASE(VJP)
2137+
#undef CASE
2138+
}
2139+
return None;
2140+
}
2141+
21272142
void ModuleFile::configureStorage(AbstractStorageDecl *decl,
21282143
uint8_t rawOpaqueReadOwnership,
21292144
uint8_t rawReadImplKind,
@@ -4164,6 +4179,37 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41644179
break;
41654180
}
41664181

4182+
case decls_block::Derivative_DECL_ATTR: {
4183+
bool isImplicit;
4184+
uint64_t origNameId;
4185+
DeclID origDeclId;
4186+
uint64_t rawDerivativeKind;
4187+
ArrayRef<uint64_t> parameters;
4188+
4189+
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
4190+
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
4191+
parameters);
4192+
4193+
DeclNameRefWithLoc origName{
4194+
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4195+
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
4196+
auto derivativeKind =
4197+
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
4198+
if (!derivativeKind)
4199+
MF.fatal();
4200+
llvm::SmallBitVector parametersBitVector(parameters.size());
4201+
for (unsigned i : indices(parameters))
4202+
parametersBitVector[i] = parameters[i];
4203+
auto *indices = IndexSubset::get(ctx, parametersBitVector);
4204+
4205+
auto *derivAttr = DerivativeAttr::create(
4206+
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
4207+
derivAttr->setOriginalFunction(origDecl);
4208+
derivAttr->setDerivativeKind(*derivativeKind);
4209+
Attr = derivAttr;
4210+
break;
4211+
}
4212+
41674213
case decls_block::ImplicitlySynthesizesNestedRequirement_DECL_ATTR: {
41684214
serialization::decls_block::ImplicitlySynthesizesNestedRequirementDeclAttrLayout
41694215
::readRecord(scratch);

lib/Serialization/Serialization.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,21 +2349,17 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
23492349
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
23502350
auto derivativeKind =
23512351
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
2352-
auto paramIndices = attr->getParameterIndices();
2353-
// NOTE(TF-837): `@derivative` attribute serialization is blocked by
2354-
// `@derivative` attribute type-checking (TF-829), which resolves
2355-
// parameter indices (`IndexSubset *`).
2356-
if (!paramIndices)
2357-
return;
2358-
assert(paramIndices && "Parameter indices must be resolved");
2352+
auto *parameterIndices = attr->getParameterIndices();
2353+
assert(parameterIndices && "Parameter indices must be resolved");
23592354
SmallVector<bool, 4> indices;
2360-
for (unsigned i : range(paramIndices->getCapacity()))
2361-
indices.push_back(paramIndices->contains(i));
2355+
for (unsigned i : range(parameterIndices->getCapacity()))
2356+
indices.push_back(parameterIndices->contains(i));
23622357
DerivativeDeclAttrLayout::emitRecord(
23632358
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
23642359
origDeclID, derivativeKind, indices);
23652360
return;
23662361
}
2362+
23672363
case DAK_ImplicitlySynthesizesNestedRequirement: {
23682364
auto *theAttr = cast<ImplicitlySynthesizesNestedRequirementAttr>(DA);
23692365
auto abbrCode = S.DeclTypeAbbrCodes[ImplicitlySynthesizesNestedRequirementDeclAttrLayout::Code];
Lines changed: 88 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,108 @@
11
// RUN: %empty-directory(%t)
2-
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
2+
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t
33
// RUN: llvm-bcanalyzer %t/derivative_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
4-
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
4+
// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
55

66
// BCANALYZER-NOT: UnknownCode
77

8-
// TODO(TF-837): Enable this test.
9-
// Blocked by TF-829: `@derivative` attribute type-checking.
10-
// XFAIL: *
8+
// REQUIRES: differentiable_programming
119

12-
func add(x: Float, y: Float) -> Float {
13-
return x + y
10+
import _Differentiation
11+
12+
// Dummy `Differentiable`-conforming type.
13+
struct S: Differentiable & AdditiveArithmetic {
14+
static var zero: S { S() }
15+
static func + (_: S, _: S) -> S { S() }
16+
static func - (_: S, _: S) -> S { S() }
17+
typealias TangentVector = S
1418
}
15-
// CHECK: @derivative(of: add, wrt: x)
16-
@derivative(of: add, wrt: x)
17-
func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
18-
return (x + y, { $0 })
19+
20+
// Test top-level functions.
21+
22+
func top1(_ x: S) -> S {
23+
x
1924
}
20-
// CHECK: @derivative(of: add, wrt: (x, y))
21-
@derivative(of: add)
22-
func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
23-
return (x + y, { ($0, $0) })
25+
// CHECK: @derivative(of: top1, wrt: x)
26+
@derivative(of: top1, wrt: x)
27+
func derivativeTop1(_ x: S) -> (value: S, differential: (S) -> S) {
28+
(x, { $0 })
2429
}
2530

26-
func generic<T : Numeric>(x: T) -> T {
27-
return x
31+
func top2<T, U>(_ x: T, _ i: Int, _ y: U) -> U {
32+
y
2833
}
29-
// CHECK: @derivative(of: generic, wrt: x)
30-
@derivative(of: generic)
31-
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
32-
where T : Numeric, T : Differentiable
33-
{
34-
return (x, { v in v })
34+
// CHECK: @derivative(of: top2, wrt: (x, y))
35+
@derivative(of: top2, wrt: (x, y))
36+
func derivativeTop2<T: Differentiable, U: Differentiable>(
37+
_ x: T, _ i: Int, _ y: U
38+
) -> (value: U, differential: (T.TangentVector, U.TangentVector) -> U.TangentVector) {
39+
(y, { (dx, dy) in dy })
3540
}
3641

37-
protocol InstanceMethod : Differentiable {
38-
func foo(_ x: Self) -> Self
39-
func bar<T : Differentiable>(_ x: T) -> Self
42+
// Test instance methods.
43+
44+
extension S {
45+
func instanceMethod(_ x: S) -> S {
46+
self
47+
}
48+
49+
// CHECK: @derivative(of: instanceMethod, wrt: x)
50+
@derivative(of: instanceMethod, wrt: x)
51+
func derivativeInstanceMethodWrtX(_ x: S) -> (value: S, differential: (S) -> S) {
52+
(self, { _ in .zero })
53+
}
54+
55+
// CHECK: @derivative(of: instanceMethod, wrt: self)
56+
@derivative(of: instanceMethod, wrt: self)
57+
func derivativeInstanceMethodWrtSelf(_ x: S) -> (value: S, differential: (S) -> S) {
58+
(self, { $0 })
59+
}
60+
61+
// CHECK: @derivative(of: instanceMethod, wrt: (self, x))
62+
@derivative(of: instanceMethod, wrt: (self, x))
63+
func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) {
64+
(self, { (dself, dx) in self })
65+
}
4066
}
41-
extension InstanceMethod {
42-
// CHECK: @derivative(of: foo, wrt: (self, x))
43-
@derivative(of: foo)
44-
func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
45-
return (x, { ($0, $0) })
67+
68+
// Test static methods.
69+
70+
extension S {
71+
static func staticMethod(_ x: S) -> S {
72+
x
73+
}
74+
75+
// CHECK: @derivative(of: staticMethod, wrt: x)
76+
@derivative(of: staticMethod, wrt: x)
77+
static func derivativeStaticMethod(_ x: S) -> (value: S, differential: (S) -> S) {
78+
(x, { $0 })
4679
}
80+
}
81+
82+
// Test computed properties.
83+
84+
extension S {
85+
var computedProperty: S {
86+
self
87+
}
88+
89+
// CHECK: @derivative(of: computedProperty, wrt: self)
90+
@derivative(of: computedProperty, wrt: self)
91+
func derivativeProperty() -> (value: S, differential: (S) -> S) {
92+
(self, { $0 })
93+
}
94+
}
95+
96+
// Test subscripts.
4797

48-
// CHECK: @derivative(of: bar, wrt: (self, x))
49-
@derivative(of: bar, wrt: (self, x))
50-
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector)
51-
where T == T.TangentVector
52-
{
53-
return (self, { dself, dx in dself })
98+
extension S {
99+
subscript<T: Differentiable>(x: T) -> S {
100+
self
54101
}
55102

56-
// CHECK: @derivative(of: bar, wrt: (self, x))
57-
@derivative(of: bar, wrt: (self, x))
58-
func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T))
59-
where T == T.TangentVector
60-
{
61-
return (self, { v in (v, .zero) })
103+
// CHECK: @derivative(of: subscript, wrt: self)
104+
@derivative(of: subscript(_:), wrt: self)
105+
func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) {
106+
(self, { $0 })
62107
}
63108
}

0 commit comments

Comments
 (0)