@@ -208,6 +208,8 @@ class SILPerformanceInliner {
208
208
llvm::detail::DenseMapPair<swift::SILBasicBlock *, uint64_t >, true >
209
209
&bbIt);
210
210
211
+ bool isAutoDiffLinearMapWithControlFlow (FullApplySite AI);
212
+
211
213
bool isProfitableToInline (
212
214
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
213
215
int &NumCallerBlocks,
@@ -299,6 +301,69 @@ bool SILPerformanceInliner::profileBasedDecision(
299
301
return true ;
300
302
}
301
303
304
+ // Checks if `FAI` can be traced back to a specifically named,
305
+ // input enum function argument. If so, the callsite
306
+ // containing function is a linear map in Swift Autodiff.
307
+ bool SILPerformanceInliner::isAutoDiffLinearMapWithControlFlow (
308
+ FullApplySite FAI) {
309
+ static const std::string LinearMapBranchTracingEnumPrefix = " _AD__" ;
310
+
311
+ auto val = FAI.getCallee ();
312
+
313
+ for (;;) {
314
+ if (auto *inst = dyn_cast<SingleValueInstruction>(val)) {
315
+ if (auto pi = Projection::isObjectProjection (val)) {
316
+ // Extract a member from a struct/tuple/enum.
317
+ val = pi->getOperand (0 );
318
+ continue ;
319
+ } else if (auto ti = dyn_cast<ThinToThickFunctionInst>(inst)) {
320
+ val = ti->getOperand ();
321
+ continue ;
322
+ } else if (auto cfi = dyn_cast<ConvertFunctionInst>(inst)) {
323
+ val = cfi->getOperand ();
324
+ continue ;
325
+ } else if (auto cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(inst)) {
326
+ val = cvt->getOperand ();
327
+ continue ;
328
+ }
329
+ return false ;
330
+ } else if (auto *phiArg = dyn_cast<SILPhiArgument>(val)) {
331
+ if (auto *predBB = phiArg->getParent ()->getSinglePredecessorBlock ()) {
332
+ // The terminator of this predecessor block must either be a
333
+ // (conditional) branch instruction or a switch_enum.
334
+ if (auto *bi = dyn_cast<BranchInst>(predBB->getTerminator ())) {
335
+ val = bi->getArg (phiArg->getIndex ());
336
+ continue ;
337
+ } else if (auto *cbi =
338
+ dyn_cast<CondBranchInst>(predBB->getTerminator ())) {
339
+ val = cbi->getArgForDestBB (phiArg->getParent (), phiArg->getIndex ());
340
+ continue ;
341
+ } else if (auto *sei =
342
+ dyn_cast<SwitchEnumInst>(predBB->getTerminator ())) {
343
+ val = sei->getOperand ();
344
+ continue ;
345
+ }
346
+ return false ;
347
+ }
348
+ }
349
+ break ;
350
+ }
351
+
352
+ // If `val` now points to a function argument then we have successfully traced
353
+ // the callee back to a function argument.
354
+ //
355
+ // We now need to check if this argument is an enum and named like an autodiff
356
+ // branch tracing enum.
357
+ if (auto *arg = dyn_cast<SILFunctionArgument>(val)) {
358
+ if (auto *enumDecl = arg->getType ().getEnumOrBoundGenericEnum ()) {
359
+ return enumDecl->getName ().str ().startswith (
360
+ LinearMapBranchTracingEnumPrefix);
361
+ }
362
+ }
363
+
364
+ return false ;
365
+ }
366
+
302
367
bool SILPerformanceInliner::isProfitableToInline (
303
368
FullApplySite AI, Weight CallerWeight, ConstantTracker &callerTracker,
304
369
int &NumCallerBlocks,
@@ -413,6 +478,19 @@ bool SILPerformanceInliner::isProfitableToInline(
413
478
SILInstruction *def = constTracker.getDefInCaller (FAI.getCallee ());
414
479
if (def && (isa<FunctionRefInst>(def) || isa<PartialApplyInst>(def)))
415
480
BlockW.updateBenefit (Benefit, RemovedClosureBenefit);
481
+ else if (isAutoDiffLinearMapWithControlFlow (FAI)) {
482
+ // For linear maps in Swift Autodiff, callees may be passed as an
483
+ // argument, however, they may be hidden behind a branch-tracing
484
+ // enum (tracing execution flow of the original function).
485
+ //
486
+ // If we can establish that we are inside of a Swift Autodiff linear
487
+ // map and that the branch tracing input enum is wrapping pullback
488
+ // closures, then we can update this function's benefit with
489
+ // `RemovedClosureBenefit` because inlining will (probably) eliminate
490
+ // the closure.
491
+ BlockW.updateBenefit (Benefit, RemovedClosureBenefit);
492
+ }
493
+
416
494
// Check if inlining the callee would allow for further
417
495
// optimizations like devirtualization or generic specialization.
418
496
if (!def)
0 commit comments