File tree Expand file tree Collapse file tree 4 files changed +34
-0
lines changed
stdlib/public/Differentiation Expand file tree Collapse file tree 4 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -110,6 +110,12 @@ where Element: Differentiable {
110
110
}
111
111
}
112
112
113
+ extension Array . DifferentiableView : CustomReflectable {
114
+ public var customMirror : Mirror {
115
+ return base. customMirror
116
+ }
117
+ }
118
+
113
119
/// Makes `Array.DifferentiableView` additive as the product space.
114
120
///
115
121
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
Original file line number Diff line number Diff line change @@ -57,3 +57,9 @@ extension Optional: Differentiable where Wrapped: Differentiable {
57
57
}
58
58
}
59
59
}
60
+
61
+ extension Optional . TangentVector : CustomReflectable {
62
+ public var customMirror : Mirror {
63
+ return value. customMirror
64
+ }
65
+ }
Original file line number Diff line number Diff line change @@ -76,4 +76,15 @@ OptionalDifferentiationTests.test("Optional.TangentVector operations") {
76
76
}
77
77
}
78
78
79
+ OptionalDifferentiationTests . test ( " Optional.TangentVector reflection " ) {
80
+ let tan = Optional< Float> . TangentVector( 42 )
81
+ let children = Array ( Mirror ( reflecting: tan) . children)
82
+ expectEqual ( 1 , children. count)
83
+ // We test `==` first because `as?` will flatten optionals.
84
+ expectTrue ( type ( of: children [ 0 ] . value) == Float . self)
85
+ if let child = expectNotNil ( children [ 0 ] . value as? Float ) {
86
+ expectEqual ( 42 , child)
87
+ }
88
+ }
89
+
79
90
runAllTests ( )
Original file line number Diff line number Diff line change @@ -454,4 +454,15 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.move") {
454
454
expectEqual ( z, [ ] )
455
455
}
456
456
457
+ ArrayAutoDiffTests . test ( " Array.DifferentiableView reflection " ) {
458
+ let tan = [ Float ] . DifferentiableView ( [ 41 , 42 ] )
459
+ let children = Array ( Mirror ( reflecting: tan) . children)
460
+ expectEqual ( 2 , children. count)
461
+ if let child1 = expectNotNil ( children [ 0 ] . value as? Float ) ,
462
+ let child2 = expectNotNil ( children [ 1 ] . value as? Float ) {
463
+ expectEqual ( 41 , child1)
464
+ expectEqual ( 42 , child2)
465
+ }
466
+ }
467
+
457
468
runAllTests ( )
You can’t perform that action at this time.
0 commit comments