1
- //===--- ArrayDifferentiation.swift -------------------------*- swift -*-===//
1
+ //===--- ArrayDifferentiation.swift --------------------------- *- swift -*-===//
2
2
//
3
3
// This source file is part of the Swift.org open source project
4
4
//
5
- // Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
5
+ // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6
6
// Licensed under Apache License v2.0 with Runtime Library Exception
7
7
//
8
8
// See https://swift.org/LICENSE.txt for license information
12
12
13
13
import Swift
14
14
15
- // TODO(TF-938): Add 'Element : Differentiable' constraint.
15
+ //===----------------------------------------------------------------------===//
16
+ // Protocol conformances
17
+ //===----------------------------------------------------------------------===//
18
+
19
+ // TODO(TF-938): Add `Element: Differentiable` requirement.
16
20
extension Array {
17
21
/// The view of an array as the differentiable product manifold of `Element`
18
22
/// multiplied with itself `count` times.
@@ -22,7 +26,8 @@ extension Array {
22
26
}
23
27
}
24
28
25
- extension Array . DifferentiableView : Differentiable where Element : Differentiable {
29
+ extension Array . DifferentiableView : Differentiable
30
+ where Element: Differentiable {
26
31
/// The viewed array.
27
32
public var base : [ Element ] {
28
33
get { return _base }
@@ -31,8 +36,9 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
31
36
32
37
@usableFromInline
33
38
@derivative ( of: base)
34
- func _vjpBase( ) ->
35
- ( value: [ Element ] , pullback: ( Array < Element > . TangentVector ) -> TangentVector ) {
39
+ func _vjpBase( ) -> (
40
+ value: [ Element ] , pullback: ( Array < Element > . TangentVector ) -> TangentVector
41
+ ) {
36
42
return ( base, { $0 } )
37
43
}
38
44
@@ -41,8 +47,9 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
41
47
42
48
@usableFromInline
43
49
@derivative ( of: init ( _: ) )
44
- static func _vjpInit( _ base: [ Element ] ) ->
45
- ( value: Array . DifferentiableView , pullback: ( TangentVector ) -> TangentVector ) {
50
+ static func _vjpInit( _ base: [ Element ] ) -> (
51
+ value: Array . DifferentiableView , pullback: ( TangentVector ) -> TangentVector
52
+ ) {
46
53
return ( Array . DifferentiableView ( base) , { $0 } )
47
54
}
48
55
@@ -52,16 +59,16 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
52
59
public mutating func move( along direction: TangentVector ) {
53
60
precondition (
54
61
base. count == direction. base. count,
55
- " cannot move Array.DifferentiableView with count \( base. count) along " +
56
- " direction with different count \( direction. base. count) " )
62
+ " cannot move Array.DifferentiableView with count \( base. count) along "
63
+ + " direction with different count \( direction. base. count) " )
57
64
for i in base. indices {
58
65
base [ i] . move ( along: direction. base [ i] )
59
66
}
60
67
}
61
68
}
62
69
63
- extension Array . DifferentiableView : Equatable
64
- where Element : Differentiable & Equatable {
70
+ extension Array . DifferentiableView : Equatable
71
+ where Element: Differentiable & Equatable {
65
72
public static func == (
66
73
lhs: Array . DifferentiableView ,
67
74
rhs: Array . DifferentiableView
@@ -70,15 +77,15 @@ extension Array.DifferentiableView : Equatable
70
77
}
71
78
}
72
79
73
- extension Array . DifferentiableView : ExpressibleByArrayLiteral
74
- where Element : Differentiable {
80
+ extension Array . DifferentiableView : ExpressibleByArrayLiteral
81
+ where Element: Differentiable {
75
82
public init ( arrayLiteral elements: Element ... ) {
76
83
self . init ( elements)
77
84
}
78
85
}
79
86
80
- extension Array . DifferentiableView : CustomStringConvertible
81
- where Element : Differentiable {
87
+ extension Array . DifferentiableView : CustomStringConvertible
88
+ where Element: Differentiable {
82
89
public var description : String {
83
90
return base. description
84
91
}
@@ -88,8 +95,8 @@ extension Array.DifferentiableView : CustomStringConvertible
88
95
///
89
96
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
90
97
/// of all counts.
91
- extension Array . DifferentiableView : AdditiveArithmetic
92
- where Element : AdditiveArithmetic & Differentiable {
98
+ extension Array . DifferentiableView : AdditiveArithmetic
99
+ where Element: AdditiveArithmetic & Differentiable {
93
100
94
101
public static var zero : Array . DifferentiableView {
95
102
return Array . DifferentiableView ( [ ] )
@@ -100,10 +107,10 @@ extension Array.DifferentiableView : AdditiveArithmetic
100
107
rhs: Array . DifferentiableView
101
108
) -> Array . DifferentiableView {
102
109
precondition (
103
- lhs. base. count == 0 || rhs. base. count == 0 ||
104
- lhs. base. count == rhs. base. count,
105
- " cannot add Array.DifferentiableViews with different counts: " +
106
- " \( lhs. base. count) and \( rhs. base. count) " )
110
+ lhs. base. count == 0 || rhs. base. count == 0
111
+ || lhs. base. count == rhs. base. count,
112
+ " cannot add Array.DifferentiableViews with different counts: "
113
+ + " \( lhs. base. count) and \( rhs. base. count) " )
107
114
if lhs. base. count == 0 {
108
115
return rhs
109
116
}
@@ -118,10 +125,10 @@ extension Array.DifferentiableView : AdditiveArithmetic
118
125
rhs: Array . DifferentiableView
119
126
) -> Array . DifferentiableView {
120
127
precondition (
121
- lhs. base. count == 0 || rhs. base. count == 0 ||
122
- lhs. base. count == rhs. base. count,
123
- " cannot subtract Array.DifferentiableViews with different counts: " +
124
- " \( lhs. base. count) and \( rhs. base. count) " )
128
+ lhs. base. count == 0 || rhs. base. count == 0
129
+ || lhs. base. count == rhs. base. count,
130
+ " cannot subtract Array.DifferentiableViews with different counts: "
131
+ + " \( lhs. base. count) and \( rhs. base. count) " )
125
132
if lhs. base. count == 0 {
126
133
return rhs
127
134
}
@@ -143,7 +150,7 @@ extension Array.DifferentiableView : AdditiveArithmetic
143
150
144
151
/// Makes `Array` differentiable as the product manifold of `Element`
145
152
/// multiplied with itself `count` times.
146
- extension Array : Differentiable where Element : Differentiable {
153
+ extension Array : Differentiable where Element: Differentiable {
147
154
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
148
155
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
149
156
// `TangentVector` because `Array` already has a static `+` method with
@@ -167,14 +174,18 @@ extension Array : Differentiable where Element : Differentiable {
167
174
}
168
175
}
169
176
170
- extension Array where Element : Differentiable {
177
+ //===----------------------------------------------------------------------===//
178
+ // Derivatives
179
+ //===----------------------------------------------------------------------===//
180
+
181
+ extension Array where Element: Differentiable {
171
182
@usableFromInline
172
183
@derivative ( of: subscript)
173
- func _vjpSubscript( index: Int ) ->
174
- ( value: Element , pullback: ( Element . TangentVector ) -> TangentVector )
175
- {
184
+ func _vjpSubscript( index: Int ) -> (
185
+ value: Element , pullback: ( Element . TangentVector ) -> TangentVector
186
+ ) {
176
187
func pullback( _ gradientIn: Element . TangentVector ) -> TangentVector {
177
- var gradientOut = Array < Element . TangentVector > (
188
+ var gradientOut = [ Element . TangentVector] (
178
189
repeating: . zero,
179
190
count: count)
180
191
gradientOut [ index] = gradientIn
@@ -185,22 +196,27 @@ extension Array where Element : Differentiable {
185
196
186
197
@usableFromInline
187
198
@derivative ( of: + )
188
- static func _vjpConcatenate( _ lhs: [ Element ] , _ rhs: [ Element ] ) ->
189
- ( value: [ Element ] , pullback: ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
190
- func pullback( _ gradientIn: TangentVector ) ->
191
- ( TangentVector , TangentVector ) {
192
- precondition (
193
- gradientIn. base. count == lhs. count + rhs. count,
194
- " + should receive gradient with count equal to sum of operand " +
195
- " counts, but counts are: gradient \( gradientIn. base. count) , " +
196
- " lhs \( lhs. count) , rhs \( rhs. count) " )
197
- return (
198
- TangentVector ( Array < Element . TangentVector > (
199
+ static func _vjpConcatenate( _ lhs: [ Element ] , _ rhs: [ Element ] ) -> (
200
+ value: [ Element ] ,
201
+ pullback: ( TangentVector ) -> ( TangentVector , TangentVector )
202
+ ) {
203
+ func pullback( _ gradientIn: TangentVector ) -> ( TangentVector , TangentVector )
204
+ {
205
+ precondition (
206
+ gradientIn. base. count == lhs. count + rhs. count,
207
+ " + should receive gradient with count equal to sum of operand "
208
+ + " counts, but counts are: gradient \( gradientIn. base. count) , "
209
+ + " lhs \( lhs. count) , rhs \( rhs. count) " )
210
+ return (
211
+ TangentVector (
212
+ [ Element . TangentVector] (
199
213
gradientIn. base [ 0 ..< lhs. count] ) ) ,
200
- TangentVector ( Array < Element . TangentVector > (
201
- gradientIn. base [ lhs. count... ] ) ) )
202
- }
203
- return ( lhs + rhs, pullback)
214
+ TangentVector (
215
+ [ Element . TangentVector] (
216
+ gradientIn. base [ lhs. count... ] ) )
217
+ )
218
+ }
219
+ return ( lhs + rhs, pullback)
204
220
}
205
221
}
206
222
@@ -218,7 +234,8 @@ extension Array where Element: Differentiable {
218
234
@usableFromInline
219
235
@derivative ( of: append)
220
236
mutating func _jvpAppend( _ element: Element ) -> (
221
- value: Void , differential: ( inout TangentVector , Element . TangentVector ) -> Void
237
+ value: Void ,
238
+ differential: ( inout TangentVector , Element . TangentVector ) -> Void
222
239
) {
223
240
append ( element)
224
241
return ( ( ) , { $0. base. append ( $1) } )
@@ -231,19 +248,22 @@ extension Array where Element: Differentiable {
231
248
static func _vjpInit( repeating repeatedValue: Element , count: Int ) -> (
232
249
value: Self , pullback: ( TangentVector ) -> Element . TangentVector
233
250
) {
234
- ( value: Self ( repeating: repeatedValue, count: count) , pullback: { v in
235
- v. base. reduce ( . zero, + )
236
- } )
251
+ (
252
+ value: Self ( repeating: repeatedValue, count: count) ,
253
+ pullback: { v in
254
+ v. base. reduce ( . zero, + )
255
+ }
256
+ )
237
257
}
238
258
}
239
259
240
260
//===----------------------------------------------------------------------===//
241
261
// Differentiable higher order functions for collections
242
262
//===----------------------------------------------------------------------===//
243
263
244
- public extension Array where Element: Differentiable {
264
+ extension Array where Element: Differentiable {
245
265
@differentiable ( wrt: ( self , initialResult) )
246
- func differentiableReduce< Result: Differentiable > (
266
+ public func differentiableReduce< Result: Differentiable > (
247
267
_ initialResult: Result ,
248
268
_ nextPartialResult: @differentiable ( Result , Element ) -> Result
249
269
) -> Result {
@@ -255,12 +275,14 @@ public extension Array where Element: Differentiable {
255
275
internal func _vjpDifferentiableReduce< Result: Differentiable > (
256
276
_ initialResult: Result ,
257
277
_ nextPartialResult: @differentiable ( Result , Element ) -> Result
258
- ) -> ( value: Result ,
259
- pullback: ( Result . TangentVector )
260
- -> ( Array . TangentVector , Result . TangentVector ) ) {
278
+ ) -> (
279
+ value: Result ,
280
+ pullback: ( Result . TangentVector )
281
+ -> ( Array . TangentVector , Result . TangentVector )
282
+ ) {
261
283
var pullbacks :
262
- [ ( Result . TangentVector ) -> ( Result . TangentVector , Element . TangentVector ) ]
263
- = [ ]
284
+ [ ( Result . TangentVector ) -> ( Result . TangentVector , Element . TangentVector ) ] =
285
+ [ ]
264
286
let count = self . count
265
287
pullbacks. reserveCapacity ( count)
266
288
var result = initialResult
@@ -270,23 +292,26 @@ public extension Array where Element: Differentiable {
270
292
result = y
271
293
pullbacks. append ( pb)
272
294
}
273
- return ( value: result, pullback: { tangent in
274
- var resultTangent = tangent
275
- var elementTangents = TangentVector ( [ ] )
276
- elementTangents. base. reserveCapacity ( count)
277
- for pullback in pullbacks. reversed ( ) {
278
- let ( newResultTangent, elementTangent) = pullback ( resultTangent)
279
- resultTangent = newResultTangent
280
- elementTangents. base. append ( elementTangent)
295
+ return (
296
+ value: result,
297
+ pullback: { tangent in
298
+ var resultTangent = tangent
299
+ var elementTangents = TangentVector ( [ ] )
300
+ elementTangents. base. reserveCapacity ( count)
301
+ for pullback in pullbacks. reversed ( ) {
302
+ let ( newResultTangent, elementTangent) = pullback ( resultTangent)
303
+ resultTangent = newResultTangent
304
+ elementTangents. base. append ( elementTangent)
305
+ }
306
+ return ( TangentVector ( elementTangents. base. reversed ( ) ) , resultTangent)
281
307
}
282
- return ( TangentVector ( elementTangents. base. reversed ( ) ) , resultTangent)
283
- } )
308
+ )
284
309
}
285
310
}
286
311
287
- public extension Array where Element: Differentiable {
312
+ extension Array where Element: Differentiable {
288
313
@differentiable ( wrt: self )
289
- func differentiableMap< Result: Differentiable > (
314
+ public func differentiableMap< Result: Differentiable > (
290
315
_ body: @differentiable ( Element ) -> Result
291
316
) -> [ Result ] {
292
317
map ( body)
@@ -296,8 +321,10 @@ public extension Array where Element: Differentiable {
296
321
@derivative ( of: differentiableMap)
297
322
internal func _vjpDifferentiableMap< Result: Differentiable > (
298
323
_ body: @differentiable ( Element ) -> Result
299
- ) -> ( value: [ Result ] ,
300
- pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector ) {
324
+ ) -> (
325
+ value: [ Result ] ,
326
+ pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector
327
+ ) {
301
328
var values : [ Result ] = [ ]
302
329
var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
303
330
for x in self {
@@ -310,4 +337,4 @@ public extension Array where Element: Differentiable {
310
337
}
311
338
return ( value: values, pullback: pullback)
312
339
}
313
- }
340
+ }
0 commit comments