Skip to content

Commit d52dddc

Browse files
authored
[AutoDiff] Fix adjoint propagation for active bb address arguments (#42393)
Fix adjoint propagation for active bb address arguments: ensure they are accumulated in the proper buffer. Fixes second case of SR-16094
1 parent 75981e7 commit d52dddc

File tree

2 files changed

+139
-3
lines changed

2 files changed

+139
-3
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,10 @@ class PullbackCloner::Implementation final
549549
if (auto adjProj = getAdjointProjection(origBB, originalValue))
550550
return (bufferMap[{origBB, originalValue}] = adjProj);
551551

552+
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for"
553+
<< originalValue
554+
<< "in bb" << origBB->getDebugID() << '\n');
555+
552556
auto bufType = getRemappedTangentType(originalValue->getType());
553557
// Set insertion point for local allocation builder: before the last local
554558
// allocation, or at the start of the pullback function's entry if no local
@@ -583,6 +587,12 @@ class PullbackCloner::Implementation final
583587
assert(originalValue->getFunction() == &getOriginal());
584588
assert(rhsAddress->getFunction() == &getPullback());
585589
auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
590+
591+
LLVM_DEBUG(getADDebugStream() << "Adding"
592+
<< rhsAddress << "to adjoint of "
593+
<< originalValue
594+
<< "in bb" << origBB->getDebugID() << '\n');
595+
586596
builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress);
587597
}
588598

@@ -2340,7 +2350,9 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
23402350
for (auto activeValue : predBBActiveValues) {
23412351
LLVM_DEBUG(getADDebugStream()
23422352
<< "Propagating adjoint of active value " << activeValue
2343-
<< " to predecessors' pullback blocks\n");
2353+
<< "from bb" << origBB->getDebugID()
2354+
<< " to predecessors' (bb" << origPredBB->getDebugID()
2355+
<< ") pullback blocks\n");
23442356
switch (getTangentValueCategory(activeValue)) {
23452357
case SILValueCategory::Object: {
23462358
auto activeValueAdj = getAdjointValue(origBB, activeValue);
@@ -2511,14 +2523,13 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25112523
case SILValueCategory::Address: {
25122524
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
25132525
for (auto pair : incomingValues) {
2514-
auto *predBB = std::get<0>(pair);
25152526
auto incomingValue = std::get<1>(pair);
25162527
// Handle `switch_enum` on `Optional`.
25172528
auto termInst = bbArg->getSingleTerminator();
25182529
if (isSwitchEnumInstOnOptional(termInst))
25192530
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
25202531
else
2521-
addToAdjointBuffer(predBB, incomingValue, bbArgAdjBuf, pbLoc);
2532+
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
25222533
}
25232534
break;
25242535
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
import StdlibUnittest
6+
7+
var PullbackTests = TestSuite("Pullback")
8+
9+
extension Dictionary: Differentiable where Value: Differentiable {
10+
public typealias TangentVector = [Key: Value.TangentVector]
11+
public mutating func move(by direction: TangentVector) {
12+
for (componentKey, componentDirection) in direction {
13+
func fatalMissingComponent() -> Value {
14+
fatalError("missing component \(componentKey) in moved Dictionary")
15+
}
16+
self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
17+
}
18+
}
19+
20+
public var zeroTangentVectorInitializer: () -> TangentVector {
21+
let listOfKeys = self.keys // capturing only what's needed, not the entire self, in order to not waste memory
22+
func initializer() -> Self.TangentVector {
23+
return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
24+
}
25+
return initializer
26+
}
27+
}
28+
29+
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
30+
public static func + (_ lhs: Self, _ rhs: Self) -> Self {
31+
return lhs.merging(rhs, uniquingKeysWith: +)
32+
}
33+
34+
public static func - (_ lhs: Self, _ rhs: Self) -> Self {
35+
return lhs.merging(rhs.mapValues { .zero - $0 }, uniquingKeysWith: +)
36+
}
37+
38+
public static var zero: Self { [:] }
39+
}
40+
41+
extension Dictionary where Value: Differentiable {
42+
// get
43+
@usableFromInline
44+
@derivative(of: subscript(_:))
45+
func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
46+
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
47+
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
48+
return (self[key], { v in
49+
if let value = v.value {
50+
return [key: value]
51+
}
52+
else {
53+
return .zero
54+
}
55+
})
56+
}
57+
}
58+
59+
public extension Dictionary where Value: Differentiable {
60+
@differentiable(reverse)
61+
mutating func set(_ key: Key, to newValue: Value) {
62+
self[key] = newValue
63+
}
64+
65+
@derivative(of: set)
66+
mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
67+
self.set(key, to: newValue)
68+
69+
let forwardCount = self.count
70+
let forwardKeys = self.keys // may be heavy to capture all of these, not sure how to do without them though
71+
72+
return ((), { v in
73+
// manual zero tangent initialization
74+
if v.count < forwardCount {
75+
v = Self.TangentVector()
76+
forwardKeys.forEach { v[$0] = .zero }
77+
}
78+
79+
if let dElement = v[key] {
80+
v[key] = .zero
81+
return dElement
82+
}
83+
else { // should this fail?
84+
v[key] = .zero
85+
return .zero
86+
}
87+
})
88+
}
89+
}
90+
91+
92+
PullbackTests.test("ConcreteType") {
93+
func getD(from newValues: [String: Double], at key: String) -> Double? {
94+
if newValues.keys.contains(key) {
95+
return newValues[key]
96+
}
97+
return nil
98+
}
99+
100+
@differentiable(reverse)
101+
func testFunctionD(newValues: [String: Double]) -> Double {
102+
return getD(from: newValues, at: "s1")!
103+
}
104+
105+
expectEqual(pullback(at: ["s1": 1.0], of: testFunctionD)(2), ["s1" : 2.0])
106+
}
107+
108+
PullbackTests.test("GenericType") {
109+
func getG<DataType>(from newValues: [String: DataType], at key: String) -> DataType?
110+
where DataType: Differentiable {
111+
if newValues.keys.contains(key) {
112+
return newValues[key]
113+
}
114+
return nil
115+
}
116+
117+
@differentiable(reverse)
118+
func testFunctionG(newValues: [String: Double]) -> Double {
119+
return getG(from: newValues, at: "s1")!
120+
}
121+
122+
expectEqual(pullback(at: ["s1": 1.0], of: testFunctionG)(2), ["s1" : 2.0])
123+
}
124+
125+
runAllTests()

0 commit comments

Comments
 (0)