Skip to content

Commit 1890225

Browse files
author
Yonghong Song
committed
[BPF] Add jump table support with switch statements and computed goto
This patch adds jump table support. A new insn 'gotox <reg>' is added to allow an indirect jump through a register. The register holds the target address within the current section.

Code:

  int foo(int a, int b) {
    __label__ l1, l2, l3, l4;
    void *jt1[] = {[0]=&&l1, [1]=&&l2};
    void *jt2[] = {[0]=&&l3, [1]=&&l4};
    int ret = 0;
    goto *jt1[a % 2];
  l1: ret += 1;
  l2: ret += 3;
    goto *jt2[b % 2];
  l3: ret += 5;
  l4: ret += 7;
    return ret;
  }

Compilation command:

  clang --target=bpf -O2 -S test2.c

However, I observed that the above compilation command actually hangs in BranchFolding. If I apply the following change:

  bool BranchFolderLegacy::runOnMachineFunction(MachineFunction &MF) {
  +  if (true) return false;
     if (skipFunction(MF.getFunction()))
       return false;

then the compilation completes successfully. From a quick look at the debug trace, there appears to be an infinite loop in BranchFolding.

This patch is tested on top of commit:

  commit 58c3aff (origin/main, origin/HEAD, main)
  Author: Michał Górny <[email protected]>
  Date: Sun Jul 20 05:26:51 2025 +0200

    [libclc] Expose `prepare_builtins_*` variables in top-level CMakeLists (llvm#149657)
1 parent 58c3aff commit 1890225

File tree

7 files changed

+105
-2
lines changed

7 files changed

+105
-2
lines changed

llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
234234
.Case("callx", true)
235235
.Case("goto", true)
236236
.Case("gotol", true)
237+
.Case("gotox", true)
237238
.Case("may_goto", true)
238239
.Case("*", true)
239240
.Case("exit", true)

llvm/lib/Target/BPF/BPFISelLowering.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
3838
cl::Hidden, cl::init(false),
3939
cl::desc("Expand memcpy into load/store pairs in order"));
4040

41+
static cl::opt<unsigned> BPFMinimumJumpTableEntries(
42+
"bpf-min-jump-table-entries", cl::init(4), cl::Hidden,
43+
cl::desc("Set minimum number of entries to use a jump table on BPF"));
44+
4145
static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
4246
SDValue Val = {}) {
4347
std::string Str;
@@ -67,12 +71,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
6771

6872
setOperationAction(ISD::BR_CC, MVT::i64, Custom);
6973
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
70-
setOperationAction(ISD::BRIND, MVT::Other, Expand);
7174
setOperationAction(ISD::BRCOND, MVT::Other, Expand);
7275

7376
setOperationAction(ISD::TRAP, MVT::Other, Custom);
7477

75-
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom);
78+
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable,
79+
ISD::BlockAddress},
80+
MVT::i64, Custom);
7681

7782
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
7883
setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
@@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
159164

160165
setBooleanContents(ZeroOrOneBooleanContent);
161166
setMaxAtomicSizeInBitsSupported(64);
167+
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);
162168

163169
// Function alignments
164170
setMinFunctionAlignment(Align(8));
@@ -316,10 +322,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
316322
report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode()));
317323
case ISD::BR_CC:
318324
return LowerBR_CC(Op, DAG);
325+
case ISD::JumpTable:
326+
return LowerJumpTable(Op, DAG);
319327
case ISD::GlobalAddress:
320328
return LowerGlobalAddress(Op, DAG);
321329
case ISD::ConstantPool:
322330
return LowerConstantPool(Op, DAG);
331+
case ISD::BlockAddress:
332+
return LowerBlockAddress(Op, DAG);
323333
case ISD::SELECT_CC:
324334
return LowerSELECT_CC(Op, DAG);
325335
case ISD::SDIV:
@@ -780,6 +790,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
780790
return LowerCall(CLI, InVals);
781791
}
782792

793+
SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
794+
JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
795+
return getAddr(N, DAG);
796+
}
797+
783798
const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
784799
switch ((BPFISD::NodeType)Opcode) {
785800
case BPFISD::FIRST_NUMBER:
@@ -811,6 +826,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
811826
N->getOffset(), Flags);
812827
}
813828

829+
static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
830+
SelectionDAG &DAG, unsigned Flags) {
831+
return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
832+
Flags);
833+
}
834+
835+
static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
836+
SelectionDAG &DAG, unsigned Flags) {
837+
return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
838+
}
839+
814840
template <class NodeTy>
815841
SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
816842
unsigned Flags) const {
@@ -837,6 +863,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op,
837863
return getAddr(N, DAG);
838864
}
839865

866+
SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op,
867+
SelectionDAG &DAG) const {
868+
BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
869+
return getAddr(N, DAG);
870+
}
871+
840872
unsigned
841873
BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB,
842874
unsigned Reg, bool isSigned) const {

llvm/lib/Target/BPF/BPFISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class BPFTargetLowering : public TargetLowering {
8181
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
8282
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
8383
SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
84+
SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
85+
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
8486

8587
template <class NodeTy>
8688
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;

llvm/lib/Target/BPF/BPFInstrInfo.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
181181
if (!isUnpredicatedTerminator(*I))
182182
break;
183183

184+
// If a JX insn, we're done.
185+
if (I->getOpcode() == BPF::JX)
186+
break;
187+
184188
// A terminator that isn't a branch can't easily be handled
185189
// by this analysis.
186190
if (!I->isBranch())
@@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,
259263

260264
return Count;
261265
}
266+
267+
int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
268+
// The pattern looks like:
269+
// %0 = LD_imm64 %jump-table.0 ; load jump-table address
270+
// %1 = ADD_rr %0, $another_reg ; address + offset
271+
// %2 = LDD %1, 0 ; load the actual label
272+
// JX %2
273+
const MachineFunction &MF = *MI.getParent()->getParent();
274+
const MachineRegisterInfo &MRI = MF.getRegInfo();
275+
276+
Register Reg = MI.getOperand(0).getReg();
277+
if (!Reg.isVirtual())
278+
return -1;
279+
MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg);
280+
if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD)
281+
return -1;
282+
283+
Reg = Ldd->getOperand(1).getReg();
284+
if (!Reg.isVirtual())
285+
return -1;
286+
MachineInstr *Add = MRI.getUniqueVRegDef(Reg);
287+
if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr)
288+
return -1;
289+
290+
Reg = Add->getOperand(1).getReg();
291+
if (!Reg.isVirtual())
292+
return -1;
293+
MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg);
294+
if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64)
295+
return -1;
296+
297+
const MachineOperand &MO = LDimm64->getOperand(1);
298+
if (!MO.isJTI())
299+
return -1;
300+
301+
return MO.getIndex();
302+
}

llvm/lib/Target/BPF/BPFInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
5858
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
5959
const DebugLoc &DL,
6060
int *BytesAdded = nullptr) const override;
61+
62+
int getJumpTableIndex(const MachineInstr &MI) const override;
63+
6164
private:
6265
void expandMEMCPY(MachineBasicBlock::iterator) const;
6366

llvm/lib/Target/BPF/BPFInstrInfo.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
216216
let BPFClass = BPF_JMP;
217217
}
218218

219+
class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220+
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
221+
(outs),
222+
(ins GPR:$dst),
223+
!strconcat(OpcodeStr, " $dst"),
224+
Pattern> {
225+
bits<4> dst;
226+
227+
let Inst{51-48} = dst;
228+
let BPFClass = BPF_JMP;
229+
}
230+
219231
class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
220232
: TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
221233
(outs),
@@ -281,6 +293,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
281293
defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
282294
defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
283295
def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;
296+
297+
let isIndirectBranch = 1 in {
298+
def JX : JMP_IND<BPF_JA, "gotox", [(brind i64:$dst)]>;
299+
}
284300
}
285301

286302
// ALU instructions
@@ -851,6 +867,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in {
851867
// load 64-bit global addr into register
852868
def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>;
853869
def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>;
870+
def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>;
871+
def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>;
854872

855873
// 0xffffFFFF doesn't fit into simm32, optimize common case
856874
def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)),

llvm/lib/Target/BPF/BPFMCInstLower.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
7777
case MachineOperand::MO_ConstantPoolIndex:
7878
MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex()));
7979
break;
80+
case MachineOperand::MO_JumpTableIndex:
81+
MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex()));
82+
break;
83+
case MachineOperand::MO_BlockAddress:
84+
MCOp = LowerSymbolOperand(MO, Printer.GetBlockAddressSymbol(MO.getBlockAddress()));
85+
break;
8086
}
8187

8288
OutMI.addOperand(MCOp);

0 commit comments

Comments
 (0)