Skip to content

Commit 12e3a6e

Browse files
committed
[AutoDiff] Modify inlining logic to award inlining benefits to VJPs
Similar to #69029 but for VJPs.
1 parent 7afa4cf commit 12e3a6e

File tree

4 files changed

+83
-17
lines changed

4 files changed

+83
-17
lines changed

include/swift/SILOptimizer/Utils/PerformanceInlinerUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ inline bool isOptimizableSemanticFunction(SILFunction *function) {
7070
/// within another semantic function, or from a "trivial" wrapper.
7171
bool isNestedSemanticCall(FullApplySite apply);
7272

73+
// Strips down simple function conversion operations until a base SILValue is
74+
// reached.
75+
//
76+
// Returns a nullptr if `val` is not a function conversion instruction.
77+
SILValue stripFunctionConversions(SILValue val);
7378
} // end swift namespace
7479

7580
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Transforms/PerformanceInliner.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ class SILPerformanceInliner {
213213

214214
bool isAutoDiffLinearMapWithControlFlow(FullApplySite AI);
215215

216+
bool isTupleWithAllocsOrPartialApplies(SILValue retVal);
217+
216218
bool isProfitableToInline(
217219
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
218220
int &NumCallerBlocks,
@@ -319,14 +321,8 @@ bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow(
319321
// Extract a member from a struct/tuple/enum.
320322
val = pi->getOperand(0);
321323
continue;
322-
} else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
323-
val = ti->getOperand();
324-
continue;
325-
} else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
326-
val = cfi->getOperand();
327-
continue;
328-
} else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
329-
val = cvt->getOperand();
324+
} else if (auto base = stripFunctionConversions(inst)) {
325+
val = base;
330326
continue;
331327
}
332328
return false;
@@ -367,6 +363,29 @@ bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow(
367363
return false;
368364
}
369365

366+
// Checks if the given value is a tuple containing allocated objects
367+
// or partial applies.
368+
//
369+
// Returns true if the number of allocated objects or partial applies is
370+
// greater than 0, and false otherwise.
371+
//
372+
// Returns false if the value is not a tuple.
373+
bool SILPerformanceInliner::isTupleWithAllocsOrPartialApplies(SILValue val) {
374+
if (auto *ti = dyn_cast<TupleInst>(val)) {
375+
for (auto i : range(ti->getNumOperands())) {
376+
SILValue val = ti->getOperand(i);
377+
378+
if (auto base = stripFunctionConversions(val))
379+
val = base;
380+
381+
if (isa<AllocationInst>(val) || isa<PartialApplyInst>(val))
382+
return true;
383+
}
384+
}
385+
386+
return false;
387+
}
388+
370389
bool SILPerformanceInliner::isProfitableToInline(
371390
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
372391
int &NumCallerBlocks,
@@ -483,6 +502,9 @@ bool SILPerformanceInliner::isProfitableToInline(
483502
if (def && (isa<FunctionRefInst>(def) || isa<PartialApplyInst>(def)))
484503
BlockW.updateBenefit(Benefit, RemovedClosureBenefit);
485504
else if (isAutoDiffLinearMapWithControlFlow(FAI)) {
505+
// TODO: Do we need to tweak inlining benefits given to pullbacks
506+
// (with and without control-flow)?
507+
486508
// For linear maps in Swift Autodiff, callees may be passed as an
487509
// argument, however, they may be hidden behind a branch-tracing
488510
// enum (tracing execution flow of the original function).
@@ -587,7 +609,7 @@ bool SILPerformanceInliner::isProfitableToInline(
587609
// Inlining functions which return an allocated object or partial_apply
588610
// most likely has a benefit in the caller, because e.g. it can enable
589611
// de-virtualization.
590-
if (isa<AllocationInst>(retVal) || isa<PartialApplyInst>(retVal)) {
612+
if (isa<AllocationInst>(retVal) || isa<PartialApplyInst>(retVal) || isTupleWithAllocsOrPartialApplies(retVal)) {
591613
BlockW.updateBenefit(Benefit, RemovedCallBenefit + 10);
592614
returnsAllocation = true;
593615
}

lib/SILOptimizer/Utils/PerformanceInlinerUtils.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,30 @@ static SILValue getMember(SILInstruction *inst, ProjectionPath &projStack) {
117117
return SILValue();
118118
}
119119

120+
SILValue swift::stripFunctionConversions(SILValue val) {
121+
SILValue result = nullptr;
122+
123+
for (;;) {
124+
if (auto ti = dyn_cast<ThinToThickFunctionInst>(val)) {
125+
val = ti->getOperand();
126+
result = val;
127+
continue;
128+
} else if (auto cfi = dyn_cast<ConvertFunctionInst>(val)) {
129+
val = cfi->getOperand();
130+
result = val;
131+
continue;
132+
} else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(val)) {
133+
val = cvt->getOperand();
134+
result = val;
135+
continue;
136+
} else {
137+
break;
138+
}
139+
}
140+
141+
return result;
142+
}
143+
120144
SILInstruction *ConstantTracker::getDef(SILValue val,
121145
ProjectionPath &projStack) {
122146

@@ -137,14 +161,8 @@ SILInstruction *ConstantTracker::getDef(SILValue val,
137161
// A value loaded from memory.
138162
val = loadedVal;
139163
continue;
140-
} else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
141-
val = ti->getOperand();
142-
continue;
143-
} else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
144-
val = cfi->getOperand();
145-
continue;
146-
} else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
147-
val = cvt->getOperand();
164+
} else if (auto base = stripFunctionConversions(inst)) {
165+
val = base;
148166
continue;
149167
}
150168
return inst;

test/AutoDiff/SILOptimizer/vjp_and_pullback_inlining.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@ import Glibc
1212
import Foundation
1313
#endif
1414

15+
// ======================== VJPs ======================== //
16+
@differentiable(reverse)
17+
@_silgen_name("simple_vjp")
18+
func simple_vjp(x: Float) -> Float {
19+
let a = x * x;
20+
let b = x + x;
21+
let c = x * a;
22+
let d = a + b;
23+
let e = b * c;
24+
25+
return a * b / c + d - e ;
26+
}
27+
28+
@inline(never)
29+
@_silgen_name("caller_of_simple_vjp")
30+
func caller_of_simple_vjp() -> Float {
31+
gradient(at: Float(4), of: simple_vjp)
32+
}
33+
34+
// CHECK: decision {{{.*}}, b=30, {{.*}}} simple_vjpTJrSpSr
35+
// CHECK-NEXT: "simple_vjpTJrSpSr" inlined into "caller_of_simple_vjp"
1536

1637
// ======================== Pullback w/ control-flow ======================== //
1738

0 commit comments

Comments
 (0)