1
+ //===--- ArrayDifferentiation.swift -------------------------*- swift -*-===//
2
+ //
3
+ // This source file is part of the Swift.org open source project
4
+ //
5
+ // Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
6
+ // Licensed under Apache License v2.0 with Runtime Library Exception
7
+ //
8
+ // See https://swift.org/LICENSE.txt for license information
9
+ // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10
+ //
11
+ //===----------------------------------------------------------------------===//
12
+
13
+ import Swift
14
+
15
+ // TODO(TF-938): Add 'Element : Differentiable' constraint.
16
+ extension Array {
17
+ /// The view of an array as the differentiable product manifold of `Element`
18
+ /// multiplied with itself `count` times.
19
+ @frozen
20
+ public struct DifferentiableView {
21
+ var _base : [ Element ]
22
+ }
23
+ }
24
+
25
+ extension Array . DifferentiableView : Differentiable where Element : Differentiable {
26
+ /// The viewed array.
27
+ public var base : [ Element ] {
28
+ get { return _base }
29
+ _modify { yield & _base }
30
+ }
31
+
32
+ @usableFromInline
33
+ @derivative ( of: base)
34
+ func _vjpBase( ) ->
35
+ ( value: [ Element ] , pullback: ( Array < Element > . TangentVector ) -> TangentVector ) {
36
+ return ( base, { $0 } )
37
+ }
38
+
39
+ /// Creates a differentiable view of the given array.
40
+ public init ( _ base: [ Element ] ) { self . _base = base }
41
+
42
+ @usableFromInline
43
+ @derivative ( of: init ( _: ) )
44
+ static func _vjpInit( _ base: [ Element ] ) ->
45
+ ( value: Array . DifferentiableView , pullback: ( TangentVector ) -> TangentVector ) {
46
+ return ( Array . DifferentiableView ( base) , { $0 } )
47
+ }
48
+
49
+ public typealias TangentVector =
50
+ Array < Element . TangentVector > . DifferentiableView
51
+
52
+ public mutating func move( along direction: TangentVector ) {
53
+ precondition (
54
+ base. count == direction. base. count,
55
+ " cannot move Array.DifferentiableView with count \( base. count) along " +
56
+ " direction with different count \( direction. base. count) " )
57
+ for i in base. indices {
58
+ base [ i] . move ( along: direction. base [ i] )
59
+ }
60
+ }
61
+ }
62
+
63
+ extension Array . DifferentiableView : Equatable
64
+ where Element : Differentiable & Equatable {
65
+ public static func == (
66
+ lhs: Array . DifferentiableView ,
67
+ rhs: Array . DifferentiableView
68
+ ) -> Bool {
69
+ return lhs. base == rhs. base
70
+ }
71
+ }
72
+
73
+ extension Array . DifferentiableView : ExpressibleByArrayLiteral
74
+ where Element : Differentiable {
75
+ public init ( arrayLiteral elements: Element ... ) {
76
+ self . init ( elements)
77
+ }
78
+ }
79
+
80
+ extension Array . DifferentiableView : CustomStringConvertible
81
+ where Element : Differentiable {
82
+ public var description : String {
83
+ return base. description
84
+ }
85
+ }
86
+
87
+ /// Makes `Array.DifferentiableView` additive as the product space.
88
+ ///
89
+ /// Note that `Array.DifferentiableView([])` is the zero in the product spaces
90
+ /// of all counts.
91
+ extension Array . DifferentiableView : AdditiveArithmetic
92
+ where Element : AdditiveArithmetic & Differentiable {
93
+
94
+ public static var zero : Array . DifferentiableView {
95
+ return Array . DifferentiableView ( [ ] )
96
+ }
97
+
98
+ public static func + (
99
+ lhs: Array . DifferentiableView ,
100
+ rhs: Array . DifferentiableView
101
+ ) -> Array . DifferentiableView {
102
+ precondition (
103
+ lhs. base. count == 0 || rhs. base. count == 0 ||
104
+ lhs. base. count == rhs. base. count,
105
+ " cannot add Array.DifferentiableViews with different counts: " +
106
+ " \( lhs. base. count) and \( rhs. base. count) " )
107
+ if lhs. base. count == 0 {
108
+ return rhs
109
+ }
110
+ if rhs. base. count == 0 {
111
+ return lhs
112
+ }
113
+ return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( + ) )
114
+ }
115
+
116
+ public static func - (
117
+ lhs: Array . DifferentiableView ,
118
+ rhs: Array . DifferentiableView
119
+ ) -> Array . DifferentiableView {
120
+ precondition (
121
+ lhs. base. count == 0 || rhs. base. count == 0 ||
122
+ lhs. base. count == rhs. base. count,
123
+ " cannot subtract Array.DifferentiableViews with different counts: " +
124
+ " \( lhs. base. count) and \( rhs. base. count) " )
125
+ if lhs. base. count == 0 {
126
+ return rhs
127
+ }
128
+ if rhs. base. count == 0 {
129
+ return lhs
130
+ }
131
+ return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( - ) )
132
+ }
133
+
134
+ @inlinable
135
+ public subscript( _ index: Int ) -> Element {
136
+ if index < base. count {
137
+ return base [ index]
138
+ } else {
139
+ return Element . zero
140
+ }
141
+ }
142
+ }
143
+
144
+ /// Makes `Array` differentiable as the product manifold of `Element`
145
+ /// multiplied with itself `count` times.
146
+ extension Array : Differentiable where Element : Differentiable {
147
+ // In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
148
+ // Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
149
+ // `TangentVector` because `Array` already has a static `+` method with
150
+ // different semantics from `AdditiveArithmetic.+`. So we use
151
+ // `Array.DifferentiableView` for all these associated types.
152
+ public typealias TangentVector =
153
+ Array < Element . TangentVector > . DifferentiableView
154
+
155
+ public mutating func move( along direction: TangentVector ) {
156
+ var view = DifferentiableView ( self )
157
+ view. move ( along: direction)
158
+ self = view. base
159
+ }
160
+
161
+ /// A closure that produces a `TangentVector` of zeros with the same
162
+ /// `count` as `self`.
163
+ public var zeroTangentVectorInitializer : ( ) -> TangentVector {
164
+ { [ count = self . count] in
165
+ TangentVector ( . init( repeating: . zero, count: count) )
166
+ }
167
+ }
168
+ }
169
+
170
+ extension Array where Element : Differentiable {
171
+ @usableFromInline
172
+ @derivative ( of: subscript)
173
+ func _vjpSubscript( index: Int ) ->
174
+ ( value: Element , pullback: ( Element . TangentVector ) -> TangentVector )
175
+ {
176
+ func pullback( _ gradientIn: Element . TangentVector ) -> TangentVector {
177
+ var gradientOut = Array < Element . TangentVector > (
178
+ repeating: . zero,
179
+ count: count)
180
+ gradientOut [ index] = gradientIn
181
+ return TangentVector ( gradientOut)
182
+ }
183
+ return ( self [ index] , pullback)
184
+ }
185
+
186
+ @usableFromInline
187
+ @derivative ( of: + )
188
+ static func _vjpConcatenate( _ lhs: [ Element ] , _ rhs: [ Element ] ) ->
189
+ ( value: [ Element ] , pullback: ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
190
+ func pullback( _ gradientIn: TangentVector ) ->
191
+ ( TangentVector , TangentVector ) {
192
+ precondition (
193
+ gradientIn. base. count == lhs. count + rhs. count,
194
+ " + should receive gradient with count equal to sum of operand " +
195
+ " counts, but counts are: gradient \( gradientIn. base. count) , " +
196
+ " lhs \( lhs. count) , rhs \( rhs. count) " )
197
+ return (
198
+ TangentVector ( Array < Element . TangentVector > (
199
+ gradientIn. base [ 0 ..< lhs. count] ) ) ,
200
+ TangentVector ( Array < Element . TangentVector > (
201
+ gradientIn. base [ lhs. count... ] ) ) )
202
+ }
203
+ return ( lhs + rhs, pullback)
204
+ }
205
+ }
206
+
207
+ extension Array where Element: Differentiable {
208
+ @usableFromInline
209
+ @derivative ( of: append)
210
+ mutating func _vjpAppend( _ element: Element ) -> (
211
+ value: Void , pullback: ( inout TangentVector ) -> Element . TangentVector
212
+ ) {
213
+ let appendedElementIndex = count
214
+ defer { append ( element) }
215
+ return ( ( ) , { dself in dself. base [ appendedElementIndex] } )
216
+ }
217
+
218
+ @usableFromInline
219
+ @derivative ( of: append)
220
+ mutating func _jvpAppend( _ element: Element ) -> (
221
+ value: Void , differential: ( inout TangentVector , Element . TangentVector ) -> Void
222
+ ) {
223
+ append ( element)
224
+ return ( ( ) , { $0. base. append ( $1) } )
225
+ }
226
+ }
227
+
228
+ extension Array where Element: Differentiable {
229
+ @usableFromInline
230
+ @derivative ( of: init ( repeating: count: ) )
231
+ static func _vjpInit( repeating repeatedValue: Element , count: Int ) -> (
232
+ value: Self , pullback: ( TangentVector ) -> Element . TangentVector
233
+ ) {
234
+ ( value: Self ( repeating: repeatedValue, count: count) , pullback: { v in
235
+ v. base. reduce ( . zero, + )
236
+ } )
237
+ }
238
+ }
239
+
240
+ //===----------------------------------------------------------------------===//
241
+ // Differentiable higher order functions for collections
242
+ //===----------------------------------------------------------------------===//
243
+
244
+ public extension Array where Element: Differentiable {
245
+ @differentiable ( wrt: ( self , initialResult) )
246
+ func differentiableReduce< Result: Differentiable > (
247
+ _ initialResult: Result ,
248
+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
249
+ ) -> Result {
250
+ reduce ( initialResult, nextPartialResult)
251
+ }
252
+
253
+ @usableFromInline
254
+ @derivative ( of: differentiableReduce)
255
+ internal func _vjpDifferentiableReduce< Result: Differentiable > (
256
+ _ initialResult: Result ,
257
+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
258
+ ) -> ( value: Result ,
259
+ pullback: ( Result . TangentVector )
260
+ -> ( Array . TangentVector , Result . TangentVector ) ) {
261
+ var pullbacks :
262
+ [ ( Result . TangentVector ) -> ( Result . TangentVector , Element . TangentVector ) ]
263
+ = [ ]
264
+ let count = self . count
265
+ pullbacks. reserveCapacity ( count)
266
+ var result = initialResult
267
+ for element in self {
268
+ let ( y, pb) =
269
+ valueWithPullback ( at: result, element, in: nextPartialResult)
270
+ result = y
271
+ pullbacks. append ( pb)
272
+ }
273
+ return ( value: result, pullback: { tangent in
274
+ var resultTangent = tangent
275
+ var elementTangents = TangentVector ( [ ] )
276
+ elementTangents. base. reserveCapacity ( count)
277
+ for pullback in pullbacks. reversed ( ) {
278
+ let ( newResultTangent, elementTangent) = pullback ( resultTangent)
279
+ resultTangent = newResultTangent
280
+ elementTangents. base. append ( elementTangent)
281
+ }
282
+ return ( TangentVector ( elementTangents. base. reversed ( ) ) , resultTangent)
283
+ } )
284
+ }
285
+ }
286
+
287
+ public extension Array where Element: Differentiable {
288
+ @differentiable ( wrt: self )
289
+ func differentiableMap< Result: Differentiable > (
290
+ _ body: @differentiable ( Element ) -> Result
291
+ ) -> [ Result ] {
292
+ map ( body)
293
+ }
294
+
295
+ @usableFromInline
296
+ @derivative ( of: differentiableMap)
297
+ internal func _vjpDifferentiableMap< Result: Differentiable > (
298
+ _ body: @differentiable ( Element ) -> Result
299
+ ) -> ( value: [ Result ] ,
300
+ pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector ) {
301
+ var values : [ Result ] = [ ]
302
+ var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
303
+ for x in self {
304
+ let ( y, pb) = valueWithPullback ( at: x, in: body)
305
+ values. append ( y)
306
+ pullbacks. append ( pb)
307
+ }
308
+ func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
309
+ . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
310
+ }
311
+ return ( value: values, pullback: pullback)
312
+ }
313
+ }
0 commit comments