Skip to content

Commit 65ab642

Browse files
committed
[AutoDiff upstream] Add Differentiable.zeroTangentVector.
Add `Differentiable.zeroTangentVectorInitializer` protocol requirement and `Differentiable.zeroTangentVector` default implementation.
1 parent 790e1a1 commit 65ab642

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ extension Array: Differentiable where Element: Differentiable {
164164
view.move(along: direction)
165165
self = view.base
166166
}
167+
168+
/// A closure that produces a `TangentVector` of zeros with the same
169+
/// `count` as `self`.
170+
public var zeroTangentVectorInitializer: () -> TangentVector {
171+
{ [count = self.count] in
172+
TangentVector(.init(repeating: .zero, count: count))
173+
}
174+
}
167175
}
168176

169177
//===----------------------------------------------------------------------===//

stdlib/public/Differentiation/Differentiable.swift

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@ public protocol Differentiable {
3636
/// equivalent to exponential map, which moves `self` on the geodesic surface
3737
/// along the given tangent vector.
3838
mutating func move(along direction: TangentVector)
39+
40+
/// A closure that produces a zero tangent vector, capturing minimal
41+
/// necessary information from `self`.
42+
///
43+
/// `move(along: zeroTangentVectorInitializer())` should not modify
44+
/// `self`.
45+
///
46+
/// In some cases, the zero tangent vector of `self` is equal to
47+
/// `TangentVector.zero`. In other cases, the zero tangent vector depends on
48+
/// information in `self`, such as shape for an n-dimensional array type.
49+
/// For differentiable programming, it is more memory-efficient to define a
50+
/// custom `zeroTangentVectorInitializer` property which returns a closure
51+
/// that captures and uses only the necessary information to create a zero
52+
/// tangent vector. For example:
53+
///
54+
/// struct Vector {
55+
/// var scalars: [Float]
56+
/// var count: Int { scalars.count }
57+
/// init(scalars: [Float]) { ... }
58+
/// init(repeating repeatedElement: Float, count: Int) { ... }
59+
/// }
60+
///
61+
/// extension Vector: AdditiveArithmetic { ... }
62+
///
63+
/// extension Vector: Differentiable {
64+
/// typealias TangentVector = Vector
65+
///
66+
/// @noDerivative
67+
/// var zeroTangentVectorInitializer: () -> TangentVector {
68+
/// let count = self.count
69+
/// return { TangentVector(repeating: 0, count: count) }
70+
/// }
71+
/// }
72+
var zeroTangentVectorInitializer: () -> TangentVector { get }
3973
}
4074

4175
public extension Differentiable where TangentVector == Self {
@@ -44,3 +78,24 @@ public extension Differentiable where TangentVector == Self {
4478
self += direction
4579
}
4680
}
81+
82+
public extension Differentiable {
83+
// This is a temporary solution enabling the addition of
84+
// `zeroTangentVectorInitializer` without implementing derived conformances.
85+
// This property will produce incorrect results when tangent vectors depend
86+
// on instance-specific information from `self`.
87+
// TODO: Implement derived conformances and remove this default
88+
// implementation.
89+
@available(*, deprecated, message: """
90+
`zeroTangentVectorInitializer` derivation has not been implemented; this \
91+
default implementation is not correct when tangent vectors depend on \
92+
instance-specific information from `self` and should not be used
93+
""")
94+
var zeroTangentVectorInitializer: () -> TangentVector {
95+
{ TangentVector.zero }
96+
}
97+
98+
/// A tangent vector initialized using `zeroTangentVectorInitializer`.
99+
/// `move(along: zeroTangentVector)` should not modify `self`.
100+
var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() }
101+
}

test/AutoDiff/stdlib/array.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,10 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.base") {
401401
backprop(FloatArrayTan([1, 2, 3, 4])))
402402
}
403403

404-
// TODO: Upstream `Differentiable.zeroTangentVector` and implementations.
405-
/*
406404
ArrayAutoDiffTests.test("Array.zeroTangentVector") {
407405
let count = 10
408406
let array: [Float] = Array((0..<count).map(Float.init))
409407
expectEqual(array.zeroTangentVector.base, Array(repeating: 0, count: count))
410408
}
411-
*/
412409

413410
runAllTests()

0 commit comments

Comments
 (0)