|
| 1 | +// RUN: %target-build-swift %s |
| 2 | +// RUN: %target-swift-frontend -c -g -Xllvm -verify-di-holes=true %s |
| 3 | + |
| 4 | +// rdar://74876596 ([SR-14290]: SIL verification fails when differentiating a function of [[Double]]) |
| 5 | + |
| 6 | +import _Differentiation |
| 7 | + |
| 8 | +let values: [[Double]] = [[0, 0], [0, 0]] |
| 9 | +let const = 1.12345 |
| 10 | +let result = add(const, to: values) |
| 11 | + |
| 12 | +@differentiable(reverse) |
| 13 | +func add(_ const: Double, to values: [[Double]]) -> [[Double]] { |
| 14 | + var result = values |
| 15 | + for i in withoutDerivative(at: values.indices) { |
| 16 | + for j in withoutDerivative(at: values.indices) { |
| 17 | + result.updated(at: i, j, with: values[i][j] + const) |
| 18 | + } |
| 19 | + } |
| 20 | + return result |
| 21 | +} |
| 22 | + |
| 23 | +extension Array where Element == [Double] { |
| 24 | + @differentiable(reverse) |
| 25 | + mutating func updated(at i: Int, _ j: Int, with newValue: Double) { |
| 26 | + self[i][j] = newValue |
| 27 | + } |
| 28 | + |
| 29 | + @derivative(of: updated) |
| 30 | + mutating func vjpUpdated(at i: Int, _ j: Int, with newValue: Double) |
| 31 | + -> (value: Void, pullback: (inout TangentVector) -> (Double.TangentVector)) { |
| 32 | + self.updated(at: i, j, with: newValue) |
| 33 | + |
| 34 | + func pullback(dSelf: inout TangentVector) -> (Double.TangentVector) { |
| 35 | + let dElement = dSelf[i][j] |
| 36 | + dSelf.base[i].base[j] = 0 |
| 37 | + return dElement |
| 38 | + } |
| 39 | + let value: Void = () |
| 40 | + |
| 41 | + return (value, pullback) |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | + |
0 commit comments