Skip to content

Commit 603db8c

Browse files
committed
[AutoDiff upstream] Add @differentiable function IRGen.
Lower `@differentiable` and `@differentiable(linear)` functions as structs of function pointers.
1 parent 11551e1 commit 603db8c

File tree

6 files changed

+442
-0
lines changed

6 files changed

+442
-0
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,41 @@ struct AutoDiffDerivativeFunctionKind {
7575
}
7676
};
7777

78+
/// A component of a SIL `@differentiable` function-typed value.
79+
struct NormalDifferentiableFunctionTypeComponent {
80+
enum innerty : unsigned { Original = 0, JVP = 1, VJP = 2 } rawValue;
81+
82+
NormalDifferentiableFunctionTypeComponent() = default;
83+
NormalDifferentiableFunctionTypeComponent(innerty rawValue)
84+
: rawValue(rawValue) {}
85+
NormalDifferentiableFunctionTypeComponent(
86+
AutoDiffDerivativeFunctionKind kind);
87+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
88+
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
89+
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
90+
operator innerty() const { return rawValue; }
91+
92+
/// Returns the derivative function kind, if the component is a derivative
93+
/// function.
94+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
95+
};
96+
97+
/// A component of a SIL `@differentiable(linear)` function-typed value.
98+
struct LinearDifferentiableFunctionTypeComponent {
99+
enum innerty : unsigned {
100+
Original = 0,
101+
Transpose = 1,
102+
} rawValue;
103+
104+
LinearDifferentiableFunctionTypeComponent() = default;
105+
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
106+
: rawValue(rawValue) {}
107+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
108+
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
109+
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
110+
operator innerty() const { return rawValue; }
111+
};
112+
78113
/// A derivative function configuration, uniqued in `ASTContext`.
79114
/// Identifies a specific derivative function given an original function.
80115
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {

lib/AST/AutoDiff.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,50 @@ AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind(
2828
rawValue = *result;
2929
}
3030

31+
NormalDifferentiableFunctionTypeComponent::
32+
NormalDifferentiableFunctionTypeComponent(
33+
AutoDiffDerivativeFunctionKind kind) {
34+
switch (kind) {
35+
case AutoDiffDerivativeFunctionKind::JVP:
36+
rawValue = JVP;
37+
return;
38+
case AutoDiffDerivativeFunctionKind::VJP:
39+
rawValue = VJP;
40+
return;
41+
}
42+
}
43+
44+
NormalDifferentiableFunctionTypeComponent::
45+
NormalDifferentiableFunctionTypeComponent(StringRef string) {
46+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
47+
.Case("original", Original)
48+
.Case("jvp", JVP)
49+
.Case("vjp", VJP);
50+
assert(result && "Invalid string");
51+
rawValue = *result;
52+
}
53+
54+
Optional<AutoDiffDerivativeFunctionKind>
55+
NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const {
56+
switch (rawValue) {
57+
case Original:
58+
return None;
59+
case JVP:
60+
return {AutoDiffDerivativeFunctionKind::JVP};
61+
case VJP:
62+
return {AutoDiffDerivativeFunctionKind::VJP};
63+
}
64+
}
65+
66+
LinearDifferentiableFunctionTypeComponent::
67+
LinearDifferentiableFunctionTypeComponent(StringRef string) {
68+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
69+
.Case("original", Original)
70+
.Case("transpose", Transpose);
71+
assert(result && "Invalid string");
72+
rawValue = *result;
73+
}
74+
3175
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
3276
StringRef string) {
3377
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)

lib/IRGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_swift_host_library(swiftIRGen STATIC
1616
GenControl.cpp
1717
GenCoverage.cpp
1818
GenDecl.cpp
19+
GenDiffFunc.cpp
1920
GenDiffWitness.cpp
2021
GenEnum.cpp
2122
GenExistential.cpp

0 commit comments

Comments
 (0)