@@ -42,6 +42,14 @@ where Element: Differentiable {
42
42
return ( base, { $0 } )
43
43
}
44
44
45
+ @usableFromInline
46
+ @derivative ( of: base)
47
+ func _jvpBase( ) -> (
48
+ value: [ Element ] , differential: ( Array < Element > . TangentVector ) -> TangentVector
49
+ ) {
50
+ return ( base, { $0 } )
51
+ }
52
+
45
53
/// Creates a differentiable view of the given array.
46
54
public init ( _ base: [ Element ] ) { self . _base = base }
47
55
@@ -53,6 +61,14 @@ where Element: Differentiable {
53
61
return ( Array . DifferentiableView ( base) , { $0 } )
54
62
}
55
63
64
+ @usableFromInline
65
+ @derivative ( of: init ( _: ) )
66
+ static func _jvpInit( _ base: [ Element ] ) -> (
67
+ value: Array . DifferentiableView , differential: ( TangentVector ) -> TangentVector
68
+ ) {
69
+ return ( Array . DifferentiableView ( base) , { $0 } )
70
+ }
71
+
56
72
public typealias TangentVector =
57
73
Array < Element . TangentVector > . DifferentiableView
58
74
@@ -191,6 +207,17 @@ extension Array where Element: Differentiable {
191
207
return ( self [ index] , pullback)
192
208
}
193
209
210
+ @usableFromInline
211
+ @derivative ( of: subscript)
212
+ func _jvpSubscript( index: Int ) -> (
213
+ value: Element , differential: ( TangentVector ) -> Element . TangentVector
214
+ ) {
215
+ func differential( _ v: TangentVector ) -> Element . TangentVector {
216
+ return v [ index]
217
+ }
218
+ return ( self [ index] , differential)
219
+ }
220
+
194
221
@usableFromInline
195
222
@derivative ( of: + )
196
223
static func _vjpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
@@ -210,8 +237,26 @@ extension Array where Element: Differentiable {
210
237
}
211
238
return ( lhs + rhs, pullback)
212
239
}
240
+
241
+ @usableFromInline
242
+ @derivative ( of: + )
243
+ static func _jvpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
244
+ value: Self ,
245
+ differential: ( TangentVector , TangentVector ) -> TangentVector
246
+ ) {
247
+ func differential( _ l: TangentVector , _ r: TangentVector ) -> TangentVector {
248
+ precondition (
249
+ l. base. count == lhs. count && r. base. count == rhs. count, """
250
+ Tangent vectors with invalid count; expected to equal the \
251
+ operand counts \( lhs. count) and \( rhs. count)
252
+ """ )
253
+ return . init( l. base + r. base)
254
+ }
255
+ return ( lhs + rhs, differential)
256
+ }
213
257
}
214
258
259
+
215
260
extension Array where Element: Differentiable {
216
261
@usableFromInline
217
262
@derivative ( of: append)
@@ -277,6 +322,17 @@ extension Array where Element: Differentiable {
277
322
}
278
323
)
279
324
}
325
+
326
+ @usableFromInline
327
+ @derivative ( of: init ( repeating: count: ) )
328
+ static func _jvpInit( repeating repeatedValue: Element , count: Int ) -> (
329
+ value: Self , differential: ( Element . TangentVector ) -> TangentVector
330
+ ) {
331
+ (
332
+ value: Self ( repeating: repeatedValue, count: count) ,
333
+ differential: { v in TangentVector ( . init( repeating: v, count: count) ) }
334
+ )
335
+ }
280
336
}
281
337
282
338
//===----------------------------------------------------------------------===//
@@ -312,6 +368,27 @@ extension Array where Element: Differentiable {
312
368
}
313
369
return ( value: values, pullback: pullback)
314
370
}
371
+
372
+ @inlinable
373
+ @derivative ( of: differentiableMap)
374
+ internal func _jvpDifferentiableMap< Result: Differentiable > (
375
+ _ body: @differentiable ( Element ) -> Result
376
+ ) -> (
377
+ value: [ Result ] ,
378
+ differential: ( Array . TangentVector ) -> Array < Result > . TangentVector
379
+ ) {
380
+ var values : [ Result ] = [ ]
381
+ var differentials : [ ( Element . TangentVector ) -> Result . TangentVector ] = [ ]
382
+ for x in self {
383
+ let ( y, df) = valueWithDifferential ( at: x, in: body)
384
+ values. append ( y)
385
+ differentials. append ( df)
386
+ }
387
+ func differential( _ tans: Array . TangentVector ) -> Array < Result > . TangentVector {
388
+ . init( zip ( tans. base, differentials) . map { tan, df in df ( tan) } )
389
+ }
390
+ return ( value: values, differential: differential)
391
+ }
315
392
}
316
393
317
394
extension Array where Element: Differentiable {
@@ -361,4 +438,33 @@ extension Array where Element: Differentiable {
361
438
}
362
439
)
363
440
}
441
+
442
+ @inlinable
443
+ @derivative ( of: differentiableReduce, wrt: ( self , initialResult) )
444
+ func _jvpDifferentiableReduce< Result: Differentiable > (
445
+ _ initialResult: Result ,
446
+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
447
+ ) -> ( value: Result ,
448
+ differential: ( Array . TangentVector , Result . TangentVector )
449
+ -> Result . TangentVector ) {
450
+ var differentials :
451
+ [ ( Result . TangentVector , Element . TangentVector ) -> Result . TangentVector ]
452
+ = [ ]
453
+ let count = self . count
454
+ differentials. reserveCapacity ( count)
455
+ var result = initialResult
456
+ for element in self {
457
+ let ( y, df) =
458
+ valueWithDifferential ( at: result, element, in: nextPartialResult)
459
+ result = y
460
+ differentials. append ( df)
461
+ }
462
+ return ( value: result, differential: { dSelf, dInitial in
463
+ var dResult = dInitial
464
+ for (dElement, df) in zip ( dSelf. base, differentials) {
465
+ dResult = df ( dResult, dElement)
466
+ }
467
+ return dResult
468
+ } )
469
+ }
364
470
}
0 commit comments