Skip to content

Commit fcfacdd

Browse files
authored
[AutoDiff] Register VJPs for SIMD subscript(_: Int) setters. (swiftlang#32747)
1 parent ab5e810 commit fcfacdd

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ where
4545
extension SIMD${n}
4646
where
4747
Scalar: Differentiable & BinaryFloatingPoint,
48-
Scalar.TangentVector: BinaryFloatingPoint
48+
Scalar.TangentVector == Scalar
4949
{
5050
// NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
5151
@inlinable
5252
@derivative(of: subscript(_:))
53-
internal func _vjpSubscript(index: Int)
53+
internal func _vjpSubscript(_ index: Int)
5454
-> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
5555
{
5656
return (self[index], { v in
5757
var zeros = Self.zero
58-
zeros[index] = Scalar(v)
58+
zeros[index] = v
5959
return zeros
6060
})
6161
}
@@ -69,6 +69,19 @@ where
6969
return .init(v[index])
7070
})
7171
}
72+
73+
@inlinable
74+
@derivative(of: subscript(_:).set)
75+
internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
76+
-> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
77+
{
78+
self[index] = newValue
79+
return ((), { dSelf in
80+
let dNewValue = dSelf[index]
81+
dSelf[index] = 0
82+
return dNewValue
83+
})
84+
}
7285
}
7386

7487
%end
@@ -421,9 +434,9 @@ where
421434
@inlinable
422435
@derivative(of: sum)
423436
func _jvpSum() -> (
424-
value: Scalar, differential (TangentVector) -> Scalar.TangentVector
437+
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
425438
) {
426-
return (sum(), { v in v.sum() }
439+
return (sum(), { v in Scalar.TangentVector(v.sum()) })
427440
}
428441
}
429442
*/

test/AutoDiff/stdlib/simd.swift

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ SIMDTests.test("Negate") {
5656
expectEqual(-g, pb1(g))
5757
}
5858

59-
SIMDTests.test("subscript") {
59+
SIMDTests.test("Subscript") {
6060
let a = SIMD4<Float>(1, 2, 3, 4)
6161

6262
func foo1(x: SIMD4<Float>) -> Float {
@@ -68,6 +68,35 @@ SIMDTests.test("subscript") {
6868
expectEqual(SIMD4<Float>(0, 0, 0, 7), pb1(7))
6969
}
7070

71+
SIMDTests.test("SubscriptSetter") {
72+
let a = SIMD4<Float>(1, 2, 3, 4)
73+
let ones = SIMD4<Float>(1, 1, 1, 1)
74+
75+
// A wrapper around `subscript(_: Int).set`.
76+
func subscriptSet(
77+
_ simd: SIMD4<Float>, index: Int, newScalar: Float
78+
) -> SIMD4<Float> {
79+
var result = simd
80+
result[index] = newScalar
81+
return result
82+
}
83+
84+
let (val1, pb1) = valueWithPullback(at: a, 5, in: { subscriptSet($0, index: 2, newScalar: $1) })
85+
expectEqual(SIMD4<Float>(1, 2, 5, 4), val1)
86+
expectEqual((SIMD4<Float>(1, 1, 0, 1), 1), pb1(ones))
87+
88+
func doubled(_ x: SIMD4<Float>) -> SIMD4<Float> {
89+
var result = x
90+
for i in withoutDerivative(at: x.indices) {
91+
result[i] = x[i] * 2
92+
}
93+
return result
94+
}
95+
let (val2, pb2) = valueWithPullback(at: a, in: doubled)
96+
expectEqual(SIMD4<Float>(2, 4, 6, 8), val2)
97+
expectEqual(SIMD4<Float>(2, 2, 2, 2), pb2(ones))
98+
}
99+
71100
SIMDTests.test("Addition") {
72101
let a = SIMD4<Float>(1, 2, 3, 4)
73102
let g = SIMD4<Float>(1, 1, 1, 1)

0 commit comments

Comments
 (0)