Skip to content

Commit a864a57

Browse files
authored
[AutoDiff upstream] Add Differentiable.withDerivative(_:). (swiftlang#30945)
Add `Differentiable.withDerivative(_:)`, a "derivative surgery" API. `Differentiable.withDerivative(_:)` is an identity function returning `self`. It takes a closure and applies it to the derivative of the return value, in contexts where the return value is differentiated with respect to.
1 parent cabb657 commit a864a57

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

stdlib/public/Differentiation/DifferentiationUtilities.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
import Swift
1919

20+
//===----------------------------------------------------------------------===//
21+
// Differentiable function creation
22+
//===----------------------------------------------------------------------===//
23+
2024
/// Create a differentiable function from a vector-Jacobian products function.
2125
@inlinable
2226
public func differentiableFunction<T : Differentiable, R : Differentiable>(
@@ -70,6 +74,10 @@ public func differentiableFunction<T, U, V, R>(
7074
/*vjp*/ vjp)
7175
}
7276

77+
//===----------------------------------------------------------------------===//
78+
// Derivative customization
79+
//===----------------------------------------------------------------------===//
80+
7381
/// Returns `x` like an identity function. When used in a context where `x` is
7482
/// being differentiated with respect to, this function will not produce any
7583
/// derivative at `x`.
@@ -91,6 +99,31 @@ public func withoutDerivative<T, R>(at x: T, in body: (T) -> R) -> R {
9199
body(x)
92100
}
93101

102+
public extension Differentiable {
103+
/// Applies the given closure to the derivative of `self`.
104+
///
105+
/// Returns `self` like an identity function. When the return value is used in
106+
/// a context where it is differentiated with respect to, applies the given
107+
/// closure to the derivative of the return value.
108+
@inlinable
109+
@differentiable(wrt: self)
110+
func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
111+
return self
112+
}
113+
114+
@inlinable
115+
@derivative(of: withDerivative)
116+
internal func _vjpWithDerivative(
117+
_ body: @escaping (inout TangentVector) -> Void
118+
) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
119+
return (self, { grad in
120+
var grad = grad
121+
body(&grad)
122+
return grad
123+
})
124+
}
125+
}
126+
94127
//===----------------------------------------------------------------------===//
95128
// Diagnostics
96129
//===----------------------------------------------------------------------===//
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import DifferentiationUnittest
5+
import StdlibUnittest
6+
7+
var DerivativeCustomizationTests = TestSuite("DerivativeCustomization")
8+
9+
DerivativeCustomizationTests.testWithLeakChecking("withDerivative") {
10+
do {
11+
var counter = 0
12+
func callback(_ x: inout Tracked<Float>) { counter += 1 }
13+
_ = gradient(at: 4) { (x: Tracked<Float>) -> Tracked<Float> in
14+
// Non-active value should not be differentiated, so `callback` should
15+
// not be called.
16+
_ = x.withDerivative(callback)
17+
return x.withDerivative(callback) + x.withDerivative(callback)
18+
}
19+
expectEqual(2, counter)
20+
}
21+
22+
expectEqual(
23+
30,
24+
gradient(at: 4) { (x: Tracked<Float>) in
25+
x.withDerivative { $0 = 10 } + x.withDerivative { $0 = 20 }
26+
})
27+
}
28+
29+
DerivativeCustomizationTests.testWithLeakChecking("withoutDerivative") {
30+
expectEqual(
31+
0,
32+
gradient(at: Tracked<Float>(4)) { x -> Tracked<Float> in
33+
withoutDerivative(at: x) { x in
34+
x * x * x
35+
}
36+
})
37+
38+
expectEqual(
39+
0,
40+
gradient(at: Tracked<Float>(4)) { x -> Tracked<Float> in
41+
let y = withoutDerivative(at: x)
42+
return y * y * y
43+
})
44+
45+
expectEqual(
46+
2,
47+
gradient(at: Tracked<Float>(4)) { x -> Tracked<Float> in
48+
let y = withoutDerivative(at: x)
49+
return x + y * y * y + x
50+
})
51+
}
52+
53+
runAllTests()

0 commit comments

Comments
 (0)