Skip to content

Commit e0b84da

Browse files
authored
Merge pull request swiftlang#36527 from vojtamolda/main
2 parents ae35835 + 9697dd8 commit e0b84da

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,15 @@ where Element: Differentiable {
7373
Array<Element.TangentVector>.DifferentiableView
7474

7575
public mutating func move(by offset: TangentVector) {
76+
if offset.base.isEmpty {
77+
return
78+
}
7679
precondition(
7780
base.count == offset.base.count, """
7881
Count mismatch: \(base.count) ('self') and \(offset.base.count) \
7982
('direction')
8083
""")
81-
for i in base.indices {
84+
for i in offset.base.indices {
8285
base[i].move(by: offset.base[i])
8386
}
8487
}
@@ -217,10 +220,13 @@ extension Array where Element: Differentiable {
217220
pullback: (TangentVector) -> (TangentVector, TangentVector)
218221
) {
219222
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
223+
if v.base.isEmpty {
224+
return (.zero, .zero)
225+
}
220226
precondition(
221227
v.base.count == lhs.count + rhs.count, """
222-
Tangent vector with invalid count; expected to equal the sum of \
223-
operand counts \(lhs.count) and \(rhs.count)
228+
Tangent vector with invalid count \(v.base.count); expected to \
229+
equal the sum of operand counts \(lhs.count) and \(rhs.count)
224230
""")
225231
return (
226232
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),

test/AutoDiff/validation-test/array.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,10 @@ ArrayAutoDiffTests.test("Array.+") {
344344
}
345345
let v = FloatArrayTan([4, -5, 6])
346346
expectEqual(v, pullback(at: [1, 2, 3], of: identity)(v))
347+
348+
let v1: [Float] = [1, 1]
349+
let v2: [Float] = [1, 1, 1]
350+
expectEqual((.zero, .zero), pullback(at: v1, v2, of: +)(.zero))
347351
}
348352

349353
ArrayAutoDiffTests.test("Array.+=") {
@@ -438,4 +442,14 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.base") {
438442
backprop(FloatArrayTan([1, 2, 3, 4])))
439443
}
440444

445+
ArrayAutoDiffTests.test("Array.DifferentiableView.move") {
446+
var v: [Float] = [1, 2, 3]
447+
v.move(by: .zero)
448+
expectEqual(v, [1, 2, 3])
449+
450+
var z: [Float] = []
451+
z.move(by: .zero)
452+
expectEqual(z, [])
453+
}
454+
441455
runAllTests()

0 commit comments

Comments
 (0)