Skip to content

Commit f5cce47

Browse files
authored
Make autdiff more robust in presence of unreachable blocks (#71356)
Unreachable blocks possess some challenges to autodiff since in reverse pass (pullback generation) we need to execute the function backwards, pushing the values from return BB back to entry block. As a result, unreachable blocks might become reachable from the return BB and this might cause all kind of issues.
1 parent 387156a commit f5cce47

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
143143
decl->setInterfaceType(astCtx.TheRawPointerType);
144144
} else { // Otherwise the payload is the linear map tuple.
145145
auto *linearMapStructTy = getLinearMapTupleType(predBB);
146-
assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
146+
// Do not create entries for unreachable predecessors
147+
if (!linearMapStructTy)
148+
continue;
147149
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
148150
decl->setInterfaceType(
149151
canLinearMapStructTy->hasArchetype()

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2087,7 +2087,10 @@ bool PullbackCloner::Implementation::run() {
20872087
originalBlocks.push_back(BB);
20882088

20892089
for (auto *nextBB : BB->getPredecessorBlocks()) {
2090-
workqueue.pushIfNotVisited(nextBB);
2090+
// If there is no linear map tuple for predecessor BB, then BB is
2091+
// unreachable from function entry. Do not run pullback cloner on it.
2092+
if (getPullbackInfo().getLinearMapTupleType(nextBB))
2093+
workqueue.pushIfNotVisited(nextBB);
20912094
}
20922095
}
20932096
}
@@ -2803,6 +2806,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
28032806
SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap;
28042807
SmallDenseMap<SILBasicBlock *, SILBasicBlock *> origPredpullbackSuccBBMap;
28052808
for (auto *predBB : bb->getPredecessorBlocks()) {
2809+
// If there is no linear map tuple for predecessor BB, then BB is
2810+
// unreachable from function entry. There is no branch tracing enum for it
2811+
// as well, so we should not create any branching to it in the pullback.
2812+
if (!getPullbackInfo().getLinearMapTupleType(predBB))
2813+
continue;
28062814
auto *pullbackSuccBB =
28072815
buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap);
28082816
origPredpullbackSuccBBMap[predBB] = pullbackSuccBB;
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
// REQUIRES: swift_in_compiler
3+
4+
// https://github.com/apple/swift/issues/71164
5+
6+
// There are few unreachble blocks created due to mandatory boolean constant
7+
// propagation
8+
// Ensure we do not ceeate linear map types for unreachable BB and that they
9+
// are skipped during VJP and pullback generation
10+
import _Differentiation
11+
12+
struct A: Differentiable {}
13+
14+
struct B: Differentiable {
15+
@differentiable(reverse)
16+
func c(b: B) -> A {
17+
while true {
18+
if true {
19+
break
20+
}
21+
};
22+
23+
return A()
24+
}
25+
26+
@differentiable(reverse)
27+
func d(b: B) -> A {
28+
while true {
29+
if true {
30+
return c(b : b)
31+
}
32+
if true {
33+
break
34+
}
35+
if false {
36+
return c(b : b)
37+
} else {
38+
break
39+
}
40+
};
41+
return A()
42+
}
43+
44+
}

0 commit comments

Comments
 (0)