Skip to content

Commit f973aa1

Browse files
committed
[Autodiff] Modify inliner logic to award inlining benefits to linear maps w/ control-flow
For linear maps containing control-flow, closures (representing the pullbacks of intermediate values) may be passed as arguments, however, they may be hidden behind a branch-tracing enum (tracing execution flow of the original function). Such linear maps did not use to get inlining benefits as the compiler could not see that the intermediate pullback closures were actually part of the input. This change modifies the inliner logic to correctly award inlining benefits to linear maps containing control-flow, by checking if a "callee" in the linear map actually traces back to an input closure that was received as part of a branch-tracing enum input argument. Fixes #68945
1 parent e2a210b commit f973aa1

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

lib/SILOptimizer/Transforms/PerformanceInliner.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ class SILPerformanceInliner {
208208
llvm::detail::DenseMapPair<swift::SILBasicBlock *, uint64_t>, true>
209209
&bbIt);
210210

211+
bool isAutoDiffLinearMapWithControlFlow(FullApplySite AI);
212+
211213
bool isProfitableToInline(
212214
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
213215
int &NumCallerBlocks,
@@ -299,6 +301,69 @@ bool SILPerformanceInliner::profileBasedDecision(
299301
return true;
300302
}
301303

304+
// Checks if `FAI` can be traced back to a specifically named,
305+
// input enum function argument. If so, the callsite
306+
// containing function is a linear map in Swift Autodiff.
307+
bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow(
308+
FullApplySite FAI) {
309+
static const std::string LinearMapBranchTracingEnumPrefix = "_AD__";
310+
311+
auto val = FAI.getCallee();
312+
313+
for (;;) {
314+
if (auto *inst = dyn_cast<SingleValueInstruction>(val)) {
315+
if (auto pi = Projection::isObjectProjection(val)) {
316+
// Extract a member from a struct/tuple/enum.
317+
val = pi->getOperand(0);
318+
continue;
319+
} else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
320+
val = ti->getOperand();
321+
continue;
322+
} else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
323+
val = cfi->getOperand();
324+
continue;
325+
} else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
326+
val = cvt->getOperand();
327+
continue;
328+
}
329+
return false;
330+
} else if (auto *phiArg = dyn_cast<SILPhiArgument>(val)) {
331+
if (auto *predBB = phiArg->getParent()->getSinglePredecessorBlock()) {
332+
// The terminator of this predecessor block must either be a
333+
// (conditional) branch instruction or a switch_enum.
334+
if (auto *bi = dyn_cast<BranchInst>(predBB->getTerminator())) {
335+
val = bi->getArg(phiArg->getIndex());
336+
continue;
337+
} else if (auto *cbi =
338+
dyn_cast<CondBranchInst>(predBB->getTerminator())) {
339+
val = cbi->getArgForDestBB(phiArg->getParent(), phiArg->getIndex());
340+
continue;
341+
} else if (auto *sei =
342+
dyn_cast<SwitchEnumInst>(predBB->getTerminator())) {
343+
val = sei->getOperand();
344+
continue;
345+
}
346+
return false;
347+
}
348+
}
349+
break;
350+
}
351+
352+
// If `val` now points to a function argument then we have successfully traced
353+
// the callee back to a function argument.
354+
//
355+
// We now need to check if this argument is an enum and named like an autodiff
356+
// branch tracing enum.
357+
if (auto *arg = dyn_cast<SILFunctionArgument>(val)) {
358+
if (auto *enumDecl = arg->getType().getEnumOrBoundGenericEnum()) {
359+
return enumDecl->getName().str().startswith(
360+
LinearMapBranchTracingEnumPrefix);
361+
}
362+
}
363+
364+
return false;
365+
}
366+
302367
bool SILPerformanceInliner::isProfitableToInline(
303368
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
304369
int &NumCallerBlocks,
@@ -413,6 +478,19 @@ bool SILPerformanceInliner::isProfitableToInline(
413478
SILInstruction *def = constTracker.getDefInCaller(FAI.getCallee());
414479
if (def && (isa<FunctionRefInst>(def) || isa<PartialApplyInst>(def)))
415480
BlockW.updateBenefit(Benefit, RemovedClosureBenefit);
481+
else if (isAutoDiffLinearMapWithControlFlow(FAI)) {
482+
// For linear maps in Swift Autodiff, callees may be passed as an
483+
// argument, however, they may be hidden behind a branch-tracing
484+
// enum (tracing execution flow of the original function).
485+
//
486+
// If we can establish that we are inside of a Swift Autodiff linear
487+
// map and that the branch tracing input enum is wrapping pullback
488+
// closures, then we can update this function's benefit with
489+
// `RemovedClosureBenefit` because inlining will (probably) eliminate
490+
// the closure.
491+
BlockW.updateBenefit(Benefit, RemovedClosureBenefit);
492+
}
493+
416494
// Check if inlining the callee would allow for further
417495
// optimizations like devirtualization or generic specialization.
418496
if (!def)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// VJP and pullback inlining tests.
2+
3+
// RUN: %target-swift-frontend -emit-sil -O -verify -Xllvm -debug-only=sil-inliner %s 2>&1 | %FileCheck %s
4+
5+
// REQUIRES: asserts
6+
// REQUIRES: swift_in_compiler
7+
8+
import _Differentiation
9+
#if os(Linux)
10+
import Glibc
11+
#else
12+
import Foundation
13+
#endif
14+
15+
16+
// ======================== Pullback w/ control-flow ======================== //
17+
18+
@differentiable(reverse)
19+
@_silgen_name("pb_with_control_flow")
20+
func pb_with_control_flow(_ x: Float) -> Float {
21+
if (x > 0) {
22+
return sin(x) * cos(x)
23+
} else {
24+
return sin(x) + cos(x)
25+
}
26+
}
27+
28+
@inline(never)
29+
@_silgen_name("caller_of_pb_with_control_flow")
30+
func caller_of_pb_with_control_flow() -> Float {
31+
gradient(at: Float(1), of: pb_with_control_flow)
32+
}
33+
34+
// CHECK: decision {{{.*}}, b=70, {{.*}}} pb_with_control_flowTJpSpSr
35+
// CHECK-NEXT: "pb_with_control_flowTJpSpSr" inlined into "caller_of_pb_with_control_flow"
36+
37+
38+
@differentiable(reverse)
39+
func double(x: Float) -> Float {
40+
return x + x
41+
}
42+
43+
@differentiable(reverse)
44+
func square(x: Float) -> Float {
45+
return x * x
46+
}
47+
48+
@differentiable(reverse)
49+
@_silgen_name("more_complex_pb_with_control_flow")
50+
func more_complex_pb_with_control_flow(x: Float) -> Float {
51+
if (x > 0) {
52+
if ((x+1) < 5) {
53+
if (x*2 > 4) {
54+
let y = square(x: x)
55+
if (y >= x) {
56+
let d = double(x: x)
57+
return x - (d*y)
58+
} else {
59+
let e = square(x: y)
60+
return x + (e*y)
61+
}
62+
}
63+
}
64+
} else {
65+
let y = double(x: x)
66+
return x * y
67+
}
68+
69+
return x*3
70+
}
71+
72+
@inline(never)
73+
@_silgen_name("caller_of_more_complex_pb_with_control_flow")
74+
func caller_of_more_complex_pb_with_control_flow() -> Float {
75+
gradient(at: Float(1), of: more_complex_pb_with_control_flow)
76+
}
77+
78+
// CHECK: decision {{{.*}}, b=70, {{.*}}} more_complex_pb_with_control_flowTJpSpSr
79+
// CHECK-NEXT: "more_complex_pb_with_control_flowTJpSpSr" inlined into "caller_of_more_complex_pb_with_control_flow"

0 commit comments

Comments
 (0)