@@ -36,6 +36,40 @@ public protocol Differentiable {
36
36
/// equivalent to exponential map, which moves `self` on the geodesic surface
37
37
/// along the given tangent vector.
38
38
mutating func move( along direction: TangentVector )
39
+
40
+ /// A closure that produces a zero tangent vector, capturing minimal
41
+ /// necessary information from `self`.
42
+ ///
43
+ /// `move(along: zeroTangentVectorInitializer())` should not modify
44
+ /// `self`.
45
+ ///
46
+ /// In some cases, the zero tangent vector of `self` is equal to
47
+ /// `TangentVector.zero`. In other cases, the zero tangent vector depends on
48
+ /// information in `self`, such as shape for an n-dimensional array type.
49
+ /// For differentiable programming, it is more memory-efficient to define a
50
+ /// custom `zeroTangentVectorInitializer` property which returns a closure
51
+ /// that captures and uses only the necessary information to create a zero
52
+ /// tangent vector. For example:
53
+ ///
54
+ /// struct Vector {
55
+ /// var scalars: [Float]
56
+ /// var count: Int { scalars.count }
57
+ /// init(scalars: [Float]) { ... }
58
+ /// init(repeating repeatedElement: Float, count: Int) { ... }
59
+ /// }
60
+ ///
61
+ /// extension Vector: AdditiveArithmetic { ... }
62
+ ///
63
+ /// extension Vector: Differentiable {
64
+ /// typealias TangentVector = Vector
65
+ ///
66
+ /// @noDerivative
67
+ /// var zeroTangentVectorInitializer: () -> TangentVector {
68
+ /// let count = self.count
69
+ /// return { TangentVector(repeating: 0, count: count) }
70
+ /// }
71
+ /// }
72
+ var zeroTangentVectorInitializer : ( ) -> TangentVector { get }
39
73
}
40
74
41
75
public extension Differentiable where TangentVector == Self {
@@ -44,3 +78,24 @@ public extension Differentiable where TangentVector == Self {
44
78
self += direction
45
79
}
46
80
}
81
+
82
+ public extension Differentiable {
83
+ // This is a temporary solution enabling the addition of
84
+ // `zeroTangentVectorInitializer` without implementing derived conformances.
85
+ // This property will produce incorrect results when tangent vectors depend
86
+ // on instance-specific information from `self`.
87
+ // TODO: Implement derived conformances and remove this default
88
+ // implementation.
89
+ @available ( * , deprecated, message: """
90
+ `zeroTangentVectorInitializer` derivation has not been implemented; this \
91
+ default implementation is not correct when tangent vectors depend on \
92
+ instance-specific information from `self` and should not be used
93
+ """ )
94
+ var zeroTangentVectorInitializer : ( ) -> TangentVector {
95
+ { TangentVector . zero }
96
+ }
97
+
98
+ /// A tangent vector initialized using `zeroTangentVectorInitializer`.
99
+ /// `move(along: zeroTangentVector)` should not modify `self`.
100
+ var zeroTangentVector : TangentVector { zeroTangentVectorInitializer ( ) }
101
+ }
0 commit comments