Skip to content

Commit b926c18

Browse files
authored
Fix adjoint generation of store_borrow (#61431)
Apparently #60467 changed the semantics of store_borrow as it started to produce a value. This change was not documented in SIL spec and not all places were updated to new semantics. Now the adjoint of store_borrow should be generated for the value of instruction itself, not the destination address
1 parent 40d1146 commit b926c18

File tree

2 files changed

+141
-1
lines changed

2 files changed

+141
-1
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,7 @@ class PullbackCloner::Implementation final
14951495
}
14961496
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
14971497
visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(),
1498-
sbi->getDest());
1498+
sbi);
14991499
}
15001500

15011501
/// Handle `copy_addr` instruction.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import DifferentiationUnittest
6+
7+
var StoreBorrowAdjTest = TestSuite("StoreBorrowAdjTest")
8+
9+
public struct ConstantTimeAccessor<Element>: Differentiable where Element: Differentiable, Element: AdditiveArithmetic {
10+
public struct TangentVector: Differentiable, AdditiveArithmetic {
11+
public typealias TangentVector = ConstantTimeAccessor.TangentVector
12+
public var _base: [Element.TangentVector]
13+
public var accessed: Element.TangentVector
14+
15+
public init(_base: [Element.TangentVector], accessed: Element.TangentVector) {
16+
self._base = _base
17+
self.accessed = accessed
18+
}
19+
}
20+
21+
@usableFromInline
22+
var _values: [Element]
23+
24+
public var accessed: Element
25+
26+
@inlinable
27+
@differentiable(reverse)
28+
public init(_ values: [Element], accessed: Element = .zero) {
29+
self._values = values
30+
self.accessed = accessed
31+
}
32+
33+
@inlinable
34+
@differentiable(reverse)
35+
public var array: [Element] { return _values }
36+
37+
@noDerivative
38+
public var count: Int { return _values.count }
39+
}
40+
41+
public extension ConstantTimeAccessor {
42+
@inlinable
43+
@derivative(of: init(_:accessed:))
44+
static func _vjpInit(_ values: [Element],
45+
accessed: Element = .zero)
46+
-> (value: ConstantTimeAccessor, pullback: (TangentVector) -> (Array<Element>.TangentVector, Element.TangentVector)) {
47+
return (ConstantTimeAccessor(values, accessed: accessed), { v in
48+
let base: Array<Element>.TangentVector
49+
if v._base.count < values.count {
50+
base = Array<Element>
51+
.TangentVector(v._base + Array<Element.TangentVector>(repeating: .zero, count: values.count - v._base.count))
52+
}
53+
else {
54+
base = Array<Element>.TangentVector(v._base)
55+
}
56+
57+
return (base, v.accessed)
58+
})
59+
}
60+
61+
@inlinable
62+
@derivative(of: array)
63+
func vjpArray() -> (value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
64+
func pullback(v: Array<Element>.TangentVector) -> TangentVector {
65+
var base: [Element.TangentVector]
66+
let localZero = Element.TangentVector.zero
67+
if v.base.allSatisfy({ $0 == localZero }) {
68+
base = []
69+
}
70+
else {
71+
base = v.base
72+
}
73+
return TangentVector(_base: base, accessed: Element.TangentVector.zero)
74+
}
75+
return (_values, pullback)
76+
}
77+
78+
mutating func move(by offset: TangentVector) {
79+
self.accessed.move(by: offset.accessed)
80+
_values.move(by: Array<Element>.TangentVector(offset._base))
81+
}
82+
}
83+
84+
public extension ConstantTimeAccessor.TangentVector {
85+
@inlinable
86+
static func + (lhs: Self, rhs: Self) -> Self {
87+
if rhs._base.isEmpty {
88+
return lhs
89+
}
90+
else if lhs._base.isEmpty {
91+
return rhs
92+
}
93+
else {
94+
var base = zip(lhs._base, rhs._base).map(+)
95+
if lhs._base.count < rhs._base.count {
96+
base.append(contentsOf: rhs._base.suffix(from: lhs._base.count))
97+
}
98+
else if lhs._base.count > rhs._base.count {
99+
base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
100+
}
101+
102+
return Self(_base: base, accessed: lhs.accessed + rhs.accessed)
103+
}
104+
}
105+
106+
@inlinable
107+
static func - (lhs: Self, rhs: Self) -> Self {
108+
if rhs._base.isEmpty {
109+
return lhs
110+
}
111+
else {
112+
var base = zip(lhs._base, rhs._base).map(-)
113+
if lhs._base.count < rhs._base.count {
114+
base.append(contentsOf: rhs._base.suffix(from: lhs._base.count).map { .zero - $0 })
115+
}
116+
else if lhs._base.count > rhs._base.count {
117+
base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
118+
}
119+
120+
return Self(_base: base, accessed: lhs.accessed - rhs.accessed)
121+
}
122+
}
123+
124+
@inlinable
125+
static var zero: Self { Self(_base: [], accessed: .zero) }
126+
}
127+
128+
StoreBorrowAdjTest.test("NonZeroGrad") {
129+
@differentiable(reverse)
130+
func testInits(input: [Float]) -> Float {
131+
let internalAccessor = ConstantTimeAccessor(input)
132+
let internalArray = internalAccessor.array
133+
return internalArray[1]
134+
}
135+
136+
let grad = gradient(at: [42.0, 146.0, 73.0], of: testInits)
137+
expectEqual(grad[1], 1.0)
138+
}
139+
140+
runAllTests()

0 commit comments

Comments
 (0)