@@ -208,6 +208,8 @@ class SILPerformanceInliner {
208208 llvm::detail::DenseMapPair<swift::SILBasicBlock *, uint64_t >, true >
209209 &bbIt);
210210
211+ bool isAutoDiffLinearMapWithControlFlow (FullApplySite AI);
212+
211213 bool isProfitableToInline (
212214 FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
213215 int &NumCallerBlocks,
@@ -299,6 +301,69 @@ bool SILPerformanceInliner::profileBasedDecision(
299301 return true ;
300302}
301303
304+ // Checks if `FAI` can be traced back to a specifically named,
305+ // input enum function argument. If so, the callsite
306+ // containing function is a linear map in Swift Autodiff.
307+ bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow (
308+ FullApplySite FAI) {
309+ static const std::string LinearMapBranchTracingEnumPrefix = " _AD__" ;
310+
311+ auto val = FAI.getCallee ();
312+
313+ for (;;) {
314+ if (auto *inst = dyn_cast<SingleValueInstruction>(val)) {
315+ if (auto pi = Projection::isObjectProjection (val)) {
316+ // Extract a member from a struct/tuple/enum.
317+ val = pi->getOperand (0 );
318+ continue ;
319+ } else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
320+ val = ti->getOperand ();
321+ continue ;
322+ } else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
323+ val = cfi->getOperand ();
324+ continue ;
325+ } else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
326+ val = cvt->getOperand ();
327+ continue ;
328+ }
329+ return false ;
330+ } else if (auto *phiArg = dyn_cast<SILPhiArgument>(val)) {
331+ if (auto *predBB = phiArg->getParent ()->getSinglePredecessorBlock ()) {
332+ // The terminator of this predecessor block must either be a
333+ // (conditional) branch instruction or a switch_enum.
334+ if (auto *bi = dyn_cast<BranchInst>(predBB->getTerminator ())) {
335+ val = bi->getArg (phiArg->getIndex ());
336+ continue ;
337+ } else if (auto *cbi =
338+ dyn_cast<CondBranchInst>(predBB->getTerminator ())) {
339+ val = cbi->getArgForDestBB (phiArg->getParent (), phiArg->getIndex ());
340+ continue ;
341+ } else if (auto *sei =
342+ dyn_cast<SwitchEnumInst>(predBB->getTerminator ())) {
343+ val = sei->getOperand ();
344+ continue ;
345+ }
346+ return false ;
347+ }
348+ }
349+ break ;
350+ }
351+
352+ // If `val` now points to a function argument then we have successfully traced
353+ // the callee back to a function argument.
354+ //
355+ // We now need to check if this argument is an enum and named like an autodiff
356+ // branch tracing enum.
357+ if (auto *arg = dyn_cast<SILFunctionArgument>(val)) {
358+ if (auto *enumDecl = arg->getType ().getEnumOrBoundGenericEnum ()) {
359+ return enumDecl->getName ().str ().startswith (
360+ LinearMapBranchTracingEnumPrefix);
361+ }
362+ }
363+
364+ return false ;
365+ }
366+
302367bool SILPerformanceInliner::isProfitableToInline (
303368 FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
304369 int &NumCallerBlocks,
@@ -413,6 +478,19 @@ bool SILPerformanceInliner::isProfitableToInline(
413478 SILInstruction *def = constTracker.getDefInCaller (FAI.getCallee ());
414479 if (def && (isa<FunctionRefInst>(def) || isa<PartialApplyInst>(def)))
415480 BlockW.updateBenefit (Benefit, RemovedClosureBenefit);
481+ else if (isAutoDiffLinearMapWithControlFlow (FAI)) {
482+ // For linear maps in Swift Autodiff, callees may be passed as an
483+ // argument, however, they may be hidden behind a branch-tracing
484+ // enum (tracing execution flow of the original function).
485+ //
486+ // If we can establish that we are inside of a Swift Autodiff linear
487+ // map and that the branch tracing input enum is wrapping pullback
488+ // closures, then we can update this function's benefit with
489+ // `RemovedClosureBenefit` because inlining will (probably) eliminate
490+ // the closure.
491+ BlockW.updateBenefit (Benefit, RemovedClosureBenefit);
492+ }
493+
416494 // Check if inlining the callee would allow for further
417495 // optimizations like devirtualization or generic specialization.
418496 if (!def)
0 commit comments