Skip to content

Commit 0ada2e2

Browse files
authored
Merge pull request swiftlang#69029 from jkshtj/main
[Autodiff] Modify inliner logic to award inlining benefits to linear …
2 parents 051bf4d + f973aa1 commit 0ada2e2

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

lib/SILOptimizer/Transforms/PerformanceInliner.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ class SILPerformanceInliner {
208208
llvm::detail::DenseMapPair<swift::SILBasicBlock *, uint64_t>, true>
209209
&bbIt);
210210

211+
bool isAutoDiffLinearMapWithControlFlow(FullApplySite AI);
212+
211213
bool isProfitableToInline(
212214
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
213215
int &NumCallerBlocks,
@@ -299,6 +301,69 @@ bool SILPerformanceInliner::profileBasedDecision(
299301
return true;
300302
}
301303

304+
// Checks if `FAI` can be traced back to a specifically named,
305+
// input enum function argument. If so, the callsite
306+
// containing function is a linear map in Swift Autodiff.
307+
bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow(
308+
FullApplySite FAI) {
309+
static const std::string LinearMapBranchTracingEnumPrefix = "_AD__";
310+
311+
auto val = FAI.getCallee();
312+
313+
for (;;) {
314+
if (auto *inst = dyn_cast<SingleValueInstruction>(val)) {
315+
if (auto pi = Projection::isObjectProjection(val)) {
316+
// Extract a member from a struct/tuple/enum.
317+
val = pi->getOperand(0);
318+
continue;
319+
} else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
320+
val = ti->getOperand();
321+
continue;
322+
} else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
323+
val = cfi->getOperand();
324+
continue;
325+
} else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
326+
val = cvt->getOperand();
327+
continue;
328+
}
329+
return false;
330+
} else if (auto *phiArg = dyn_cast<SILPhiArgument>(val)) {
331+
if (auto *predBB = phiArg->getParent()->getSinglePredecessorBlock()) {
332+
// The terminator of this predecessor block must either be a
333+
// (conditional) branch instruction or a switch_enum.
334+
if (auto *bi = dyn_cast<BranchInst>(predBB->getTerminator())) {
335+
val = bi->getArg(phiArg->getIndex());
336+
continue;
337+
} else if (auto *cbi =
338+
dyn_cast<CondBranchInst>(predBB->getTerminator())) {
339+
val = cbi->getArgForDestBB(phiArg->getParent(), phiArg->getIndex());
340+
continue;
341+
} else if (auto *sei =
342+
dyn_cast<SwitchEnumInst>(predBB->getTerminator())) {
343+
val = sei->getOperand();
344+
continue;
345+
}
346+
return false;
347+
}
348+
}
349+
break;
350+
}
351+
352+
// If `val` now points to a function argument then we have successfully traced
353+
// the callee back to a function argument.
354+
//
355+
// We now need to check if this argument is an enum and named like an autodiff
356+
// branch tracing enum.
357+
if (auto *arg = dyn_cast<SILFunctionArgument>(val)) {
358+
if (auto *enumDecl = arg->getType().getEnumOrBoundGenericEnum()) {
359+
return enumDecl->getName().str().startswith(
360+
LinearMapBranchTracingEnumPrefix);
361+
}
362+
}
363+
364+
return false;
365+
}
366+
302367
bool SILPerformanceInliner::isProfitableToInline(
303368
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
304369
int &NumCallerBlocks,
@@ -413,6 +478,19 @@ bool SILPerformanceInliner::isProfitableToInline(
413478
SILInstruction *def = constTracker.getDefInCaller(FAI.getCallee());
414479
if (def && (isa<FunctionRefInst>(def) || isa<PartialApplyInst>(def)))
415480
BlockW.updateBenefit(Benefit, RemovedClosureBenefit);
481+
else if (isAutoDiffLinearMapWithControlFlow(FAI)) {
482+
// For linear maps in Swift Autodiff, callees may be passed as an
483+
// argument, however, they may be hidden behind a branch-tracing
484+
// enum (tracing execution flow of the original function).
485+
//
486+
// If we can establish that we are inside of a Swift Autodiff linear
487+
// map and that the branch tracing input enum is wrapping pullback
488+
// closures, then we can update this function's benefit with
489+
// `RemovedClosureBenefit` because inlining will (probably) eliminate
490+
// the closure.
491+
BlockW.updateBenefit(Benefit, RemovedClosureBenefit);
492+
}
493+
416494
// Check if inlining the callee would allow for further
417495
// optimizations like devirtualization or generic specialization.
418496
if (!def)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// VJP and pullback inlining tests.
2+
3+
// RUN: %target-swift-frontend -emit-sil -O -verify -Xllvm -debug-only=sil-inliner %s 2>&1 | %FileCheck %s
4+
5+
// REQUIRES: asserts
6+
// REQUIRES: swift_in_compiler
7+
8+
import _Differentiation
9+
#if os(Linux)
10+
import Glibc
11+
#else
12+
import Foundation
13+
#endif
14+
15+
16+
// ======================== Pullback w/ control-flow ======================== //
17+
18+
@differentiable(reverse)
19+
@_silgen_name("pb_with_control_flow")
20+
func pb_with_control_flow(_ x: Float) -> Float {
21+
if (x > 0) {
22+
return sin(x) * cos(x)
23+
} else {
24+
return sin(x) + cos(x)
25+
}
26+
}
27+
28+
@inline(never)
29+
@_silgen_name("caller_of_pb_with_control_flow")
30+
func caller_of_pb_with_control_flow() -> Float {
31+
gradient(at: Float(1), of: pb_with_control_flow)
32+
}
33+
34+
// CHECK: decision {{{.*}}, b=70, {{.*}}} pb_with_control_flowTJpSpSr
35+
// CHECK-NEXT: "pb_with_control_flowTJpSpSr" inlined into "caller_of_pb_with_control_flow"
36+
37+
38+
@differentiable(reverse)
39+
func double(x: Float) -> Float {
40+
return x + x
41+
}
42+
43+
@differentiable(reverse)
44+
func square(x: Float) -> Float {
45+
return x * x
46+
}
47+
48+
@differentiable(reverse)
49+
@_silgen_name("more_complex_pb_with_control_flow")
50+
func more_complex_pb_with_control_flow(x: Float) -> Float {
51+
if (x > 0) {
52+
if ((x+1) < 5) {
53+
if (x*2 > 4) {
54+
let y = square(x: x)
55+
if (y >= x) {
56+
let d = double(x: x)
57+
return x - (d*y)
58+
} else {
59+
let e = square(x: y)
60+
return x + (e*y)
61+
}
62+
}
63+
}
64+
} else {
65+
let y = double(x: x)
66+
return x * y
67+
}
68+
69+
return x*3
70+
}
71+
72+
@inline(never)
73+
@_silgen_name("caller_of_more_complex_pb_with_control_flow")
74+
func caller_of_more_complex_pb_with_control_flow() -> Float {
75+
gradient(at: Float(1), of: more_complex_pb_with_control_flow)
76+
}
77+
78+
// CHECK: decision {{{.*}}, b=70, {{.*}}} more_complex_pb_with_control_flowTJpSpSr
79+
// CHECK-NEXT: "more_complex_pb_with_control_flowTJpSpSr" inlined into "caller_of_more_complex_pb_with_control_flow"

0 commit comments

Comments
 (0)