Skip to content

Commit 06b2066

Browse files
committed
[AutoDiff] Conform Array's and Optional's TangentVector to CustomReflectable
Resolves rdar://88542240.
1 parent 3708676 commit 06b2066

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ where Element: Differentiable {
110110
}
111111
}
112112

113+
extension Array.DifferentiableView: CustomReflectable {
114+
public var customMirror: Mirror {
115+
return base.customMirror
116+
}
117+
}
118+
113119
/// Makes `Array.DifferentiableView` additive as the product space.
114120
///
115121
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces

stdlib/public/Differentiation/OptionalDifferentiation.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,9 @@ extension Optional: Differentiable where Wrapped: Differentiable {
5757
}
5858
}
5959
}
60+
61+
extension Optional.TangentVector: CustomReflectable {
62+
public var customMirror: Mirror {
63+
return value.customMirror
64+
}
65+
}

test/AutoDiff/stdlib/optional.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,15 @@ OptionalDifferentiationTests.test("Optional.TangentVector operations") {
7676
}
7777
}
7878

79+
OptionalDifferentiationTests.test("Optional.TangentVector reflection") {
80+
let tan = Optional<Float>.TangentVector(42)
81+
let children = Array(Mirror(reflecting: tan).children)
82+
expectEqual(1, children.count)
83+
// We test `==` first because `as?` will flatten optionals.
84+
expectTrue(type(of: children[0].value) == Float.self)
85+
if let child = expectNotNil(children[0].value as? Float) {
86+
expectEqual(42, child)
87+
}
88+
}
89+
7990
runAllTests()

test/AutoDiff/validation-test/array.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,4 +454,15 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.move") {
454454
expectEqual(z, [])
455455
}
456456

457+
ArrayAutoDiffTests.test("Array.DifferentiableView reflection") {
458+
let tan = [Float].DifferentiableView([41, 42])
459+
let children = Array(Mirror(reflecting: tan).children)
460+
expectEqual(2, children.count)
461+
if let child1 = expectNotNil(children[0].value as? Float),
462+
let child2 = expectNotNil(children[1].value as? Float) {
463+
expectEqual(41, child1)
464+
expectEqual(42, child2)
465+
}
466+
}
467+
457468
runAllTests()

0 commit comments

Comments
 (0)