Skip to content

Commit 9697dd8

Browse files
committed
[AutoDiff] Add .zero tangent vector handling to Array.DifferentiableView.+ and .move
Implement correct zero tangent vector, i.e. empty array [], handling during the backward pass of Array.DifferentiableView.+ and move methods. The precondition is no longer triggered by .zero/empty arrays. Fixes SR-14297
1 parent 57516a9 commit 9697dd8

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,15 @@ where Element: Differentiable {
7373
Array<Element.TangentVector>.DifferentiableView
7474

7575
public mutating func move(by offset: TangentVector) {
76+
if offset.base.isEmpty {
77+
return
78+
}
7679
precondition(
7780
base.count == offset.base.count, """
7881
Count mismatch: \(base.count) ('self') and \(offset.base.count) \
7982
('direction')
8083
""")
81-
for i in base.indices {
84+
for i in offset.base.indices {
8285
base[i].move(by: offset.base[i])
8386
}
8487
}
@@ -217,10 +220,13 @@ extension Array where Element: Differentiable {
217220
pullback: (TangentVector) -> (TangentVector, TangentVector)
218221
) {
219222
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
223+
if v.base.isEmpty {
224+
return (.zero, .zero)
225+
}
220226
precondition(
221227
v.base.count == lhs.count + rhs.count, """
222-
Tangent vector with invalid count; expected to equal the sum of \
223-
operand counts \(lhs.count) and \(rhs.count)
228+
Tangent vector with invalid count \(v.base.count); expected to \
229+
equal the sum of operand counts \(lhs.count) and \(rhs.count)
224230
""")
225231
return (
226232
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),

0 commit comments

Comments
 (0)