Skip to content

Commit 8045bf9

Browse files
committed
[FuncSpec] Support function specialization across multiple arguments.
The current implementation of Function Specialization does not allow specializing more than one arguments per function call, which is a limitation I am lifting with this patch. My main challenge was to choose the most suitable ADT for storing the specializations. We need an associative container for binding all the actual arguments of a specialization to the function call. We also need a consistent iteration order across executions. Lastly we want to be able to sort the entries by Gain and reject the least profitable ones. MapVector fits the bill but not quite; erasing elements is expensive and using stable_sort messes up the indices to the underlying vector. I am therefore using the underlying vector directly after calculating the Gain. Differential Revision: https://reviews.llvm.org/D119880
1 parent 4ca111d commit 8045bf9

File tree

5 files changed

+306
-95
lines changed

5 files changed

+306
-95
lines changed

llvm/include/llvm/Transforms/Utils/SCCPSolver.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,14 @@ class SCCPSolver {
151151
/// Return a reference to the set of argument tracked functions.
152152
SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions();
153153

154-
/// Mark the constant argument of a new function specialization. \p F points
155-
/// to the cloned function and \p Arg represents the constant argument as a
156-
/// pair of {formal,actual} values (the formal argument is associated with the
157-
/// original function definition). All other arguments of the specialization
158-
/// inherit the lattice state of their corresponding values in the original
159-
/// function.
160-
void markArgInFuncSpecialization(Function *F, const ArgInfo &Arg);
154+
/// Mark the constant arguments of a new function specialization. \p F points
155+
/// to the cloned function and \p Args contains a list of constant arguments
156+
/// represented as pairs of {formal,actual} values (the formal argument is
157+
/// associated with the original function definition). All other arguments of
158+
/// the specialization inherit the lattice state of their corresponding values
159+
/// in the original function.
160+
void markArgInFuncSpecialization(Function *F,
161+
const SmallVectorImpl<ArgInfo> &Args);
161162

162163
/// Mark all of the blocks in function \p F non-executable. Clients can used
163164
/// this method to erase a function from the module (e.g., if it has been

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 97 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ static cl::opt<bool> SpecializeOnAddresses(
9999
"func-specialization-on-address", cl::init(false), cl::Hidden,
100100
cl::desc("Enable function specialization on the address of global values"));
101101

102-
// TODO: This needs checking to see the impact on compile-times, which is why
103-
// this is off by default for now.
102+
// Disabled by default as it can significantly increase compilation times.
103+
// Running nikic's compile time tracker on x86 with instruction count as the
104+
// metric shows 3-4% regression for SPASS while being neutral for all other
105+
// benchmarks of the llvm test suite.
106+
//
107+
// https://llvm-compile-time-tracker.com
108+
// https://github.com/nikic/llvm-compile-time-tracker
104109
static cl::opt<bool> EnableSpecializationForLiteralConstant(
105110
"function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
106111
cl::desc("Enable specialization of functions that take a literal constant "
@@ -110,17 +115,17 @@ namespace {
110115
// Bookkeeping struct to pass data from the analysis and profitability phase
111116
// to the actual transform helper functions.
112117
struct SpecializationInfo {
113-
ArgInfo Arg; // Stores the {formal,actual} argument pair.
114-
InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
115-
116-
SpecializationInfo(Argument *A, Constant *C, InstructionCost G)
117-
: Arg(A, C), Gain(G){};
118+
SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs.
119+
InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
118120
};
119121
} // Anonymous namespace
120122

121123
using FuncList = SmallVectorImpl<Function *>;
122-
using ConstList = SmallVector<Constant *>;
123-
using SpecializationList = SmallVector<SpecializationInfo>;
124+
using CallArgBinding = std::pair<CallBase *, Constant *>;
125+
using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>;
126+
// We are using MapVector because it guarantees deterministic iteration
127+
// order across executions.
128+
using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>;
124129

125130
// Helper to check if \p LV is either a constant or a constant
126131
// range with a single element. This should cover exactly the same cases as the
@@ -307,17 +312,15 @@ class FunctionSpecializer {
307312
LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
308313
<< F->getName() << " is " << Cost << "\n");
309314

310-
SpecializationList Specializations;
311-
calculateGains(F, Cost, Specializations);
312-
if (Specializations.empty()) {
313-
LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
315+
SmallVector<CallSpecBinding, 8> Specializations;
316+
if (!calculateGains(F, Cost, Specializations)) {
317+
LLVM_DEBUG(dbgs() << "FnSpecialization: No possible constants found\n");
314318
continue;
315319
}
316320

317-
for (SpecializationInfo &S : Specializations) {
318-
specializeFunction(F, S, WorkList);
319-
Changed = true;
320-
}
321+
Changed = true;
322+
for (auto &Entry : Specializations)
323+
specializeFunction(F, Entry.second, WorkList);
321324
}
322325

323326
updateSpecializedFuncs(Candidates, WorkList);
@@ -392,72 +395,79 @@ class FunctionSpecializer {
392395
return Clone;
393396
}
394397

395-
/// This function decides whether it's worthwhile to specialize function \p F
396-
/// based on the known constant values its arguments can take on, i.e. it
397-
/// calculates a gain and returns a list of actual arguments that are deemed
398-
/// profitable to specialize. Specialization is performed on the first
399-
/// interesting argument. Specializations based on additional arguments will
400-
/// be evaluated on following iterations of the main IPSCCP solve loop.
401-
void calculateGains(Function *F, InstructionCost Cost,
402-
SpecializationList &WorkList) {
398+
/// This function decides whether it's worthwhile to specialize function
399+
/// \p F based on the known constant values its arguments can take on. It
400+
/// only discovers potential specialization opportunities without actually
401+
/// applying them.
402+
///
403+
/// \returns true if any specializations have been found.
404+
bool calculateGains(Function *F, InstructionCost Cost,
405+
SmallVectorImpl<CallSpecBinding> &WorkList) {
406+
SpecializationMap Specializations;
403407
// Determine if we should specialize the function based on the values the
404408
// argument can take on. If specialization is not profitable, we continue
405409
// on to the next argument.
406410
for (Argument &FormalArg : F->args()) {
407411
// Determine if this argument is interesting. If we know the argument can
408412
// take on any constant values, they are collected in Constants.
409-
ConstList ActualArgs;
413+
SmallVector<CallArgBinding, 8> ActualArgs;
410414
if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
411415
LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
412416
<< FormalArg.getNameOrAsOperand()
413417
<< " is not interesting\n");
414418
continue;
415419
}
416420

417-
for (auto *ActualArg : ActualArgs) {
418-
InstructionCost Gain =
419-
ForceFunctionSpecialization
420-
? 1
421-
: getSpecializationBonus(&FormalArg, ActualArg) - Cost;
421+
for (const auto &Entry : ActualArgs) {
422+
CallBase *Call = Entry.first;
423+
Constant *ActualArg = Entry.second;
422424

423-
if (Gain <= 0)
424-
continue;
425-
WorkList.push_back({&FormalArg, ActualArg, Gain});
426-
}
425+
auto I = Specializations.insert({Call, SpecializationInfo()});
426+
SpecializationInfo &S = I.first->second;
427427

428-
if (WorkList.empty())
429-
continue;
430-
431-
// Sort the candidates in descending order.
432-
llvm::stable_sort(WorkList, [](const SpecializationInfo &L,
433-
const SpecializationInfo &R) {
434-
return L.Gain > R.Gain;
435-
});
436-
437-
// Truncate the worklist to 'MaxClonesThreshold' candidates if
438-
// necessary.
439-
if (WorkList.size() > MaxClonesThreshold) {
440-
LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
441-
<< "the maximum number of clones threshold.\n"
442-
<< "FnSpecialization: Truncating worklist to "
443-
<< MaxClonesThreshold << " candidates.\n");
444-
WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
428+
if (I.second)
429+
S.Gain = ForceFunctionSpecialization ? 1 : 0 - Cost;
430+
if (!ForceFunctionSpecialization)
431+
S.Gain += getSpecializationBonus(&FormalArg, ActualArg);
432+
S.Args.push_back({&FormalArg, ActualArg});
445433
}
434+
}
435+
436+
// Remove unprofitable specializations.
437+
Specializations.remove_if(
438+
[](const auto &Entry) { return Entry.second.Gain <= 0; });
439+
440+
// Clear the MapVector and return the underlying vector.
441+
WorkList = Specializations.takeVector();
442+
443+
// Sort the candidates in descending order.
444+
llvm::stable_sort(WorkList, [](const auto &L, const auto &R) {
445+
return L.second.Gain > R.second.Gain;
446+
});
447+
448+
// Truncate the worklist to 'MaxClonesThreshold' candidates if necessary.
449+
if (WorkList.size() > MaxClonesThreshold) {
450+
LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
451+
<< "the maximum number of clones threshold.\n"
452+
<< "FnSpecialization: Truncating worklist to "
453+
<< MaxClonesThreshold << " candidates.\n");
454+
WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
455+
}
446456

447-
LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
448-
<< F->getName() << "\n";
449-
for (SpecializationInfo &S
450-
: WorkList) {
457+
LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
458+
<< F->getName() << "\n";
459+
for (const auto &Entry
460+
: WorkList) {
461+
dbgs() << "FnSpecialization: Gain = " << Entry.second.Gain
462+
<< "\n";
463+
for (const ArgInfo &Arg : Entry.second.Args)
451464
dbgs() << "FnSpecialization: FormalArg = "
452-
<< S.Arg.Formal->getNameOrAsOperand()
465+
<< Arg.Formal->getNameOrAsOperand()
453466
<< ", ActualArg = "
454-
<< S.Arg.Actual->getNameOrAsOperand()
455-
<< ", Gain = " << S.Gain << "\n";
456-
});
467+
<< Arg.Actual->getNameOrAsOperand() << "\n";
468+
});
457469

458-
// FIXME: Only one argument per function.
459-
break;
460-
}
470+
return !WorkList.empty();
461471
}
462472

463473
bool isCandidateFunction(Function *F) {
@@ -490,12 +500,12 @@ class FunctionSpecializer {
490500
Function *Clone = cloneCandidateFunction(F, Mappings);
491501

492502
// Rewrite calls to the function so that they call the clone instead.
493-
rewriteCallSites(Clone, S.Arg, Mappings);
503+
rewriteCallSites(Clone, S.Args, Mappings);
494504

495505
// Initialize the lattice state of the arguments of the function clone,
496506
// marking the argument on which we specialized the function constant
497507
// with the given value.
498-
Solver.markArgInFuncSpecialization(Clone, S.Arg);
508+
Solver.markArgInFuncSpecialization(Clone, S.Args);
499509

500510
// Mark all the specialized functions
501511
WorkList.push_back(Clone);
@@ -641,7 +651,8 @@ class FunctionSpecializer {
641651
///
642652
/// \returns true if the function should be specialized on the given
643653
/// argument.
644-
bool isArgumentInteresting(Argument *A, ConstList &Constants) {
654+
bool isArgumentInteresting(Argument *A,
655+
SmallVectorImpl<CallArgBinding> &Constants) {
645656
// For now, don't attempt to specialize functions based on the values of
646657
// composite types.
647658
if (!A->getType()->isSingleValueType() || A->user_empty())
@@ -681,7 +692,8 @@ class FunctionSpecializer {
681692

682693
/// Collect in \p Constants all the constant values that argument \p A can
683694
/// take on.
684-
void getPossibleConstants(Argument *A, ConstList &Constants) {
695+
void getPossibleConstants(Argument *A,
696+
SmallVectorImpl<CallArgBinding> &Constants) {
685697
Function *F = A->getParent();
686698

687699
// Iterate over all the call sites of the argument's parent function.
@@ -723,23 +735,24 @@ class FunctionSpecializer {
723735

724736
if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
725737
EnableSpecializationForLiteralConstant))
726-
Constants.push_back(cast<Constant>(V));
738+
Constants.push_back({&CS, cast<Constant>(V)});
727739
}
728740
}
729741

730742
/// Rewrite calls to function \p F to call function \p Clone instead.
731743
///
732744
/// This function modifies calls to function \p F as long as the actual
733-
/// argument matches the one in \p Arg. Note that for recursive calls we
734-
/// need to compare against the cloned formal argument.
745+
/// arguments match those in \p Args. Note that for recursive calls we
746+
/// need to compare against the cloned formal arguments.
735747
///
736748
/// Callsites that have been marked with the MinSize function attribute won't
737749
/// be specialized and rewritten.
738-
void rewriteCallSites(Function *Clone, const ArgInfo &Arg,
750+
void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args,
739751
ValueToValueMapTy &Mappings) {
740-
Function *F = Arg.Formal->getParent();
741-
unsigned ArgNo = Arg.Formal->getArgNo();
742-
SmallVector<CallBase *, 4> CallSitesToRewrite;
752+
assert(!Args.empty() && "Specialization without arguments");
753+
Function *F = Args[0].Formal->getParent();
754+
755+
SmallVector<CallBase *, 8> CallSitesToRewrite;
743756
for (auto *U : F->users()) {
744757
if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
745758
continue;
@@ -758,9 +771,16 @@ class FunctionSpecializer {
758771
<< "\n");
759772
if (/* recursive call */
760773
(CS->getFunction() == Clone &&
761-
CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]) ||
774+
all_of(Args,
775+
[CS, &Mappings](const ArgInfo &Arg) {
776+
unsigned ArgNo = Arg.Formal->getArgNo();
777+
return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal];
778+
})) ||
762779
/* normal call */
763-
CS->getArgOperand(ArgNo) == Arg.Actual) {
780+
all_of(Args, [CS](const ArgInfo &Arg) {
781+
unsigned ArgNo = Arg.Formal->getArgNo();
782+
return CS->getArgOperand(ArgNo) == Arg.Actual;
783+
})) {
764784
CS->setCalledFunction(Clone);
765785
Solver.markOverdefined(CS);
766786
}
@@ -891,7 +911,7 @@ bool llvm::runFunctionSpecialization(
891911
// Initially resolve the constants in all the argument tracked functions.
892912
RunSCCPSolver(FuncDecls);
893913

894-
SmallVector<Function *, 2> WorkList;
914+
SmallVector<Function *, 8> WorkList;
895915
unsigned I = 0;
896916
while (FuncSpecializationMaxIters != I++ &&
897917
FS.specializeFunctions(FuncDecls, WorkList)) {

llvm/lib/Transforms/Utils/SCCPSolver.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
450450
return TrackingIncomingArguments;
451451
}
452452

453-
void markArgInFuncSpecialization(Function *F, const ArgInfo &Arg);
453+
void markArgInFuncSpecialization(Function *F,
454+
const SmallVectorImpl<ArgInfo> &Args);
454455

455456
void markFunctionUnreachable(Function *F) {
456457
for (auto &BB : *F)
@@ -524,21 +525,24 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
524525
return nullptr;
525526
}
526527

527-
void SCCPInstVisitor::markArgInFuncSpecialization(Function *F,
528-
const ArgInfo &Arg) {
529-
assert(F->arg_size() == Arg.Formal->getParent()->arg_size() &&
528+
void SCCPInstVisitor::markArgInFuncSpecialization(
529+
Function *F, const SmallVectorImpl<ArgInfo> &Args) {
530+
assert(!Args.empty() && "Specialization without arguments");
531+
assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() &&
530532
"Functions should have the same number of arguments");
531533

534+
auto Iter = Args.begin();
532535
Argument *NewArg = F->arg_begin();
533-
Argument *OldArg = Arg.Formal->getParent()->arg_begin();
536+
Argument *OldArg = Args[0].Formal->getParent()->arg_begin();
534537
for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) {
535538

536539
LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
537540
<< NewArg->getNameOrAsOperand() << "\n");
538541

539-
if (OldArg == Arg.Formal) {
542+
if (OldArg == Iter->Formal) {
540543
// Mark the argument constants in the new function.
541-
markConstant(NewArg, Arg.Actual);
544+
markConstant(NewArg, Iter->Actual);
545+
++Iter;
542546
} else if (ValueState.count(OldArg)) {
543547
// For the remaining arguments in the new function, copy the lattice state
544548
// over from the old function.
@@ -1717,8 +1721,9 @@ SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
17171721
return Visitor->getArgumentTrackedFunctions();
17181722
}
17191723

1720-
void SCCPSolver::markArgInFuncSpecialization(Function *F, const ArgInfo &Arg) {
1721-
Visitor->markArgInFuncSpecialization(F, Arg);
1724+
void SCCPSolver::markArgInFuncSpecialization(
1725+
Function *F, const SmallVectorImpl<ArgInfo> &Args) {
1726+
Visitor->markArgInFuncSpecialization(F, Args);
17221727
}
17231728

17241729
void SCCPSolver::markFunctionUnreachable(Function *F) {

llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ entry:
4646
; CHECK-NEXT: entry:
4747
; CHECK-NEXT: %0 = load i32, i32* @A, align 4
4848
; CHECK-NEXT: %add = add nsw i32 %x, %0
49-
; CHECK-NEXT: %1 = load i32, i32* %c, align 4
49+
; CHECK-NEXT: %1 = load i32, i32* @C, align 4
5050
; CHECK-NEXT: %add1 = add nsw i32 %add, %1
5151
; CHECK-NEXT: ret i32 %add1
5252
; CHECK-NEXT: }
@@ -55,7 +55,7 @@ entry:
5555
; CHECK-NEXT: entry:
5656
; CHECK-NEXT: %0 = load i32, i32* @B, align 4
5757
; CHECK-NEXT: %add = add nsw i32 %x, %0
58-
; CHECK-NEXT: %1 = load i32, i32* %c, align 4
58+
; CHECK-NEXT: %1 = load i32, i32* @D, align 4
5959
; CHECK-NEXT: %add1 = add nsw i32 %add, %1
6060
; CHECK-NEXT: ret i32 %add1
6161
; CHECK-NEXT: }

0 commit comments

Comments
 (0)