1111// ===----------------------------------------------------------------------===//
1212
1313#include " llvm/CodeGen/TargetInstrInfo.h"
14+ #include " llvm/ADT/SmallSet.h"
1415#include " llvm/ADT/StringExtras.h"
1516#include " llvm/BinaryFormat/Dwarf.h"
1617#include " llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
4243 " disable-sched-hazard" , cl::Hidden, cl::init(false ),
4344 cl::desc(" Disable hazard detection during preRA scheduling" ));
4445
// Hidden command-line switch, on by default. When it is off,
// getAccumulatorReassociationPatterns() returns false before inspecting the
// root instruction, so no accumulator-chain pattern is ever reported.
46+ static cl::opt<bool > EnableAccReassociation (
47+ " acc-reassoc" , cl::Hidden, cl::init(true ),
48+ cl::desc(" Enable reassociation of accumulation chains" ));
49+
// Hidden command-line threshold, 8 by default. A collected accumulator chain
// with fewer entries than this value is rejected by
// getAccumulatorReassociationPatterns().
50+ static cl::opt<unsigned int >
51+ MinAccumulatorDepth (" acc-min-depth" , cl::Hidden, cl::init(8 ),
52+ cl::desc(" Minimum length of accumulator chains "
53+ " required for the optimization to kick in" ));
54+
// Hidden command-line limit, 3 by default. The ACC_CHAIN rewrite in
// genAlternativeCodeSequence() asserts this is greater than 1 and uses it as
// an upper bound (together with Log2_32 of the chain depth) on how many
// parallel accumulators the chain is split into.
55+ static cl::opt<unsigned int > MaxAccumulatorWidth (
56+ " acc-max-width" , cl::Hidden, cl::init(3 ),
57+ cl::desc(" Maximum number of branches in the accumulator tree" ));
58+
4559TargetInstrInfo::~TargetInstrInfo () = default ;
4660
4761const TargetRegisterClass*
@@ -897,6 +911,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
897911 hasReassociableSibling (Inst, Commuted);
898912}
899913
914+ // Utility routine that checks if \param MO is defined by an
915+ // \param CombineOpc instruction in the basic block \param MBB.
916+ // If \param CombineOpc is not provided, the OpCode check will
917+ // be skipped.
918+ static bool canCombine (MachineBasicBlock &MBB, MachineOperand &MO,
919+ unsigned CombineOpc = 0 ) {
920+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
921+ MachineInstr *MI = nullptr ;
922+
923+ if (MO.isReg () && MO.getReg ().isVirtual ())
924+ MI = MRI.getUniqueVRegDef (MO.getReg ());
925+ // And it needs to be in the trace (otherwise, it won't have a depth).
926+ if (!MI || MI->getParent () != &MBB ||
927+ ((unsigned )MI->getOpcode () != CombineOpc && CombineOpc != 0 ))
928+ return false ;
929+ // Must only used by the user we combine with.
930+ if (!MRI.hasOneNonDBGUse (MI->getOperand (0 ).getReg ()))
931+ return false ;
932+
933+ return true ;
934+ }
935+
936+ // A chain of accumulation instructions will be selected IFF:
937+ // 1. All the accumulation instructions in the chain have the same opcode,
938+ // besides the first that has a slightly different opcode because it does
939+ // not accumulate into a register.
940+ // 2. All the instructions in the chain are combinable (have a single use
941+ // which itself is part of the chain).
942+ // 3. Meets the required minimum length.
943+ void TargetInstrInfo::getAccumulatorChain (
944+ MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
945+ // Walk up the chain of accumulation instructions and collect them in the
946+ // vector.
947+ MachineBasicBlock &MBB = *CurrentInstr->getParent ();
948+ const MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
949+ unsigned AccumulatorOpcode = CurrentInstr->getOpcode ();
950+ std::optional<unsigned > ChainStartOpCode =
951+ getAccumulationStartOpcode (AccumulatorOpcode);
952+
953+ if (!ChainStartOpCode.has_value ())
954+ return ;
955+
956+ // Push the first accumulator result to the start of the chain.
957+ Chain.push_back (CurrentInstr->getOperand (0 ).getReg ());
958+
959+ // Collect the accumulator input register from all instructions in the chain.
960+ while (CurrentInstr &&
961+ canCombine (MBB, CurrentInstr->getOperand (1 ), AccumulatorOpcode)) {
962+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
963+ CurrentInstr = MRI.getUniqueVRegDef (CurrentInstr->getOperand (1 ).getReg ());
964+ }
965+
966+ // Add the instruction at the top of the chain.
967+ if (CurrentInstr->getOpcode () == AccumulatorOpcode &&
968+ canCombine (MBB, CurrentInstr->getOperand (1 )))
969+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
970+ }
971+
972+ // / Find chains of accumulations that can be rewritten as a tree for increased
973+ // / ILP.
974+ bool TargetInstrInfo::getAccumulatorReassociationPatterns (
975+ MachineInstr &Root, SmallVectorImpl<unsigned > &Patterns) const {
976+ if (!EnableAccReassociation)
977+ return false ;
978+
979+ unsigned Opc = Root.getOpcode ();
980+ if (!isAccumulationOpcode (Opc))
981+ return false ;
982+
983+ // Verify that this is the end of the chain.
984+ MachineBasicBlock &MBB = *Root.getParent ();
985+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
986+ if (!MRI.hasOneNonDBGUser (Root.getOperand (0 ).getReg ()))
987+ return false ;
988+
989+ auto User = MRI.use_instr_begin (Root.getOperand (0 ).getReg ());
990+ if (User->getOpcode () == Opc)
991+ return false ;
992+
993+ // Walk up the use chain and collect the reduction chain.
994+ SmallVector<Register, 32 > Chain;
995+ getAccumulatorChain (&Root, Chain);
996+
997+ // Reject chains which are too short to be worth modifying.
998+ if (Chain.size () < MinAccumulatorDepth)
999+ return false ;
1000+
1001+ // Check if the MBB this instruction is a part of contains any other chains.
1002+ // If so, don't apply it.
1003+ SmallSetVector<Register, 32 > ReductionChain (Chain.begin (), Chain.end ());
1004+ for (const auto &I : MBB) {
1005+ if (I.getOpcode () == Opc &&
1006+ !ReductionChain.contains (I.getOperand (0 ).getReg ()))
1007+ return false ;
1008+ }
1009+
1010+ Patterns.push_back (MachineCombinerPattern::ACC_CHAIN);
1011+ return true ;
1012+ }
1013+
1014+ // Reduce branches of the accumulator tree by adding them together.
1015+ void TargetInstrInfo::reduceAccumulatorTree (
1016+ SmallVectorImpl<Register> &RegistersToReduce,
1017+ SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1018+ MachineInstr &Root, MachineRegisterInfo &MRI,
1019+ DenseMap<unsigned , unsigned > &InstrIdxForVirtReg,
1020+ Register ResultReg) const {
1021+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
1022+ SmallVector<Register, 8 > NewRegs;
1023+
1024+ // Get the opcode for the reduction instruction we will need to build.
1025+ // If for some reason it is not defined, early exit and don't apply this.
1026+ unsigned ReduceOpCode = getReduceOpcodeForAccumulator (Root.getOpcode ());
1027+
1028+ for (unsigned int i = 1 ; i <= (RegistersToReduce.size () / 2 ); i += 2 ) {
1029+ auto RHS = RegistersToReduce[i - 1 ];
1030+ auto LHS = RegistersToReduce[i];
1031+ Register Dest;
1032+ // If we are reducing 2 registers, reuse the original result register.
1033+ if (RegistersToReduce.size () == 2 )
1034+ Dest = ResultReg;
1035+ // Otherwise, create a new virtual register to hold the partial sum.
1036+ else {
1037+ auto NewVR = MRI.createVirtualRegister (
1038+ MRI.getRegClass (Root.getOperand (0 ).getReg ()));
1039+ Dest = NewVR;
1040+ NewRegs.push_back (Dest);
1041+ InstrIdxForVirtReg.insert (std::make_pair (Dest, InsInstrs.size ()));
1042+ }
1043+
1044+ // Create the new reduction instruction.
1045+ MachineInstrBuilder MIB =
1046+ BuildMI (MF, MIMetadata (Root), TII->get (ReduceOpCode), Dest)
1047+ .addReg (RHS, getKillRegState (true ))
1048+ .addReg (LHS, getKillRegState (true ));
1049+ // Copy any flags needed from the original instruction.
1050+ MIB->setFlags (Root.getFlags ());
1051+ InsInstrs.push_back (MIB);
1052+ }
1053+
1054+ // If the number of registers to reduce is odd, add the remaining register to
1055+ // the vector of registers to reduce.
1056+ if (RegistersToReduce.size () % 2 != 0 )
1057+ NewRegs.push_back (RegistersToReduce[RegistersToReduce.size () - 1 ]);
1058+
1059+ RegistersToReduce = NewRegs;
1060+ }
1061+
9001062// The concept of the reassociation pass is that these operations can benefit
9011063// from this kind of transformation:
9021064//
@@ -936,6 +1098,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
9361098 }
9371099 return true ;
9381100 }
1101+ if (getAccumulatorReassociationPatterns (Root, Patterns))
1102+ return true ;
9391103
9401104 return false ;
9411105}
@@ -947,7 +1111,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
9471111
9481112CombinerObjective
9491113TargetInstrInfo::getCombinerObjective (unsigned Pattern) const {
950- return CombinerObjective::Default;
1114+ switch (Pattern) {
1115+ case MachineCombinerPattern::ACC_CHAIN:
1116+ return CombinerObjective::MustReduceDepth;
1117+ default :
1118+ return CombinerObjective::Default;
1119+ }
9511120}
9521121
9531122std::pair<unsigned , unsigned >
@@ -1250,19 +1419,98 @@ void TargetInstrInfo::genAlternativeCodeSequence(
12501419 SmallVectorImpl<MachineInstr *> &DelInstrs,
12511420 DenseMap<unsigned , unsigned > &InstIdxForVirtReg) const {
12521421 MachineRegisterInfo &MRI = Root.getMF ()->getRegInfo ();
1422+ MachineBasicBlock &MBB = *Root.getParent ();
1423+ MachineFunction &MF = *MBB.getParent ();
1424+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
12531425
1254- // Select the previous instruction in the sequence based on the input pattern.
1255- std::array<unsigned , 5 > OperandIndices;
1256- getReassociateOperandIndices (Root, Pattern, OperandIndices);
1257- MachineInstr *Prev =
1258- MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1426+ switch (Pattern) {
1427+ case MachineCombinerPattern::REASSOC_AX_BY:
1428+ case MachineCombinerPattern::REASSOC_AX_YB:
1429+ case MachineCombinerPattern::REASSOC_XA_BY:
1430+ case MachineCombinerPattern::REASSOC_XA_YB: {
1431+ // Select the previous instruction in the sequence based on the input
1432+ // pattern.
1433+ std::array<unsigned , 5 > OperandIndices;
1434+ getReassociateOperandIndices (Root, Pattern, OperandIndices);
1435+ MachineInstr *Prev =
1436+ MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1437+
1438+ // Don't reassociate if Prev and Root are in different blocks.
1439+ if (Prev->getParent () != Root.getParent ())
1440+ return ;
12591441
1260- // Don't reassociate if Prev and Root are in different blocks.
1261- if (Prev->getParent () != Root.getParent ())
1262- return ;
1442+ reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1443+ InstIdxForVirtReg);
1444+ break ;
1445+ }
1446+ case MachineCombinerPattern::ACC_CHAIN: {
1447+ SmallVector<Register, 32 > ChainRegs;
1448+ getAccumulatorChain (&Root, ChainRegs);
1449+ unsigned int Depth = ChainRegs.size ();
1450+ assert (MaxAccumulatorWidth > 1 &&
1451+ " Max accumulator width set to illegal value" );
1452+ unsigned int MaxWidth = Log2_32 (Depth) < MaxAccumulatorWidth
1453+ ? Log2_32 (Depth)
1454+ : MaxAccumulatorWidth;
1455+
1456+ // Walk down the chain and rewrite it as a tree.
1457+ for (auto IndexedReg : llvm::enumerate (llvm::reverse (ChainRegs))) {
1458+ // No need to rewrite the first node, it is already perfect as it is.
1459+ if (IndexedReg.index () == 0 )
1460+ continue ;
1461+
1462+ MachineInstr *Instr = MRI.getUniqueVRegDef (IndexedReg.value ());
1463+ MachineInstrBuilder MIB;
1464+ Register AccReg;
1465+ if (IndexedReg.index () < MaxWidth) {
1466+ // Now we need to create new instructions for the first row.
1467+ AccReg = Instr->getOperand (0 ).getReg ();
1468+ unsigned OpCode = getAccumulationStartOpcode (Root.getOpcode ());
1469+
1470+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (OpCode), AccReg)
1471+ .addReg (Instr->getOperand (2 ).getReg (),
1472+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1473+ .addReg (Instr->getOperand (3 ).getReg (),
1474+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1475+ } else {
1476+ // For the remaining cases, we need to use an output register of one of
1477+ // the newly inserted instuctions as operand 1
1478+ AccReg = Instr->getOperand (0 ).getReg () == Root.getOperand (0 ).getReg ()
1479+ ? MRI.createVirtualRegister (
1480+ MRI.getRegClass (Root.getOperand (0 ).getReg ()))
1481+ : Instr->getOperand (0 ).getReg ();
1482+ assert (IndexedReg.index () >= MaxWidth);
1483+ auto AccumulatorInput =
1484+ ChainRegs[Depth - (IndexedReg.index () - MaxWidth) - 1 ];
1485+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (Instr->getOpcode ()),
1486+ AccReg)
1487+ .addReg (AccumulatorInput, getKillRegState (true ))
1488+ .addReg (Instr->getOperand (2 ).getReg (),
1489+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1490+ .addReg (Instr->getOperand (3 ).getReg (),
1491+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1492+ }
12631493
1264- reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1265- InstIdxForVirtReg);
1494+ MIB->setFlags (Instr->getFlags ());
1495+ InstIdxForVirtReg.insert (std::make_pair (AccReg, InsInstrs.size ()));
1496+ InsInstrs.push_back (MIB);
1497+ DelInstrs.push_back (Instr);
1498+ }
1499+
1500+ SmallVector<Register, 8 > RegistersToReduce;
1501+ for (unsigned i = (InsInstrs.size () - MaxWidth); i < InsInstrs.size ();
1502+ ++i) {
1503+ auto Reg = InsInstrs[i]->getOperand (0 ).getReg ();
1504+ RegistersToReduce.push_back (Reg);
1505+ }
1506+
1507+ while (RegistersToReduce.size () > 1 )
1508+ reduceAccumulatorTree (RegistersToReduce, InsInstrs, MF, Root, MRI,
1509+ InstIdxForVirtReg, Root.getOperand (0 ).getReg ());
1510+
1511+ break ;
1512+ }
1513+ }
12661514}
12671515
12681516MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy () const {
0 commit comments