Skip to content

Commit a7fc4de

Browse files
committed
[MLIR] Fix import of invokes with mismatched variadic types
This resolves the same issue addressed in llvm#124286 for call operations and refactors the common conversion code for both call and invoke instructions.
1 parent afbce5d commit a7fc4de

File tree

3 files changed

+57
-32
lines changed

3 files changed

+57
-32
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,15 @@ class ModuleImport {
323323
/// MLIR as a value.
324324
FailureOr<SmallVector<Value>>
325325
convertCallOperands(llvm::CallBase *callInst, bool allowInlineAsm = false);
326-
/// Converts the parameter attributes attached to `func` and adds them to the
327-
/// `funcOp`.
326+
/// Converts the callee's function type. For direct calls, it converts the
327+
/// actual function type, which may differ from the called operand type in
328+
/// variadic functions. For indirect calls, it converts the function type
329+
/// associated with the call instruction.
330+
LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
331+
/// Returns the callee name, or an empty symbol if the call is not direct.
332+
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
333+
/// Converts the parameter attributes attached to `func` and adds them to
334+
/// the `funcOp`.
328335
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
329336
OpBuilder &builder);
330337
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,26 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
15191519
return operands;
15201520
}
15211521

1522+
LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1523+
llvm::Value *calledOperand = callInst->getCalledOperand();
1524+
Type converted = [&] {
1525+
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526+
return convertType(callee->getFunctionType());
1527+
return convertType(callInst->getFunctionType());
1528+
}();
1529+
1530+
if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1531+
return funcTy;
1532+
return {};
1533+
}
1534+
1535+
FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
1536+
llvm::Value *calledOperand = callInst->getCalledOperand();
1537+
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1538+
return SymbolRefAttr::get(context, callee->getName());
1539+
return {};
1540+
}
1541+
15221542
LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
15231543
if (succeeded(iface.convertIntrinsic(builder, inst, *this)))
15241544
return success();
@@ -1623,25 +1643,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16231643
else
16241644
mapNoResultOp(inst, callOp);
16251645
} else {
1626-
auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1627-
// Retrieve the real function type. For direct calls, use the callee's
1628-
// function type, as it may differ from the operand type in the case of
1629-
// variadic functions. For indirect calls, use the call function type.
1630-
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1631-
return convertType(callee->getFunctionType());
1632-
return convertType(callInst->getFunctionType());
1633-
}());
1634-
1646+
auto funcTy = convertFunctionType(callInst);
16351647
if (!funcTy)
16361648
return failure();
16371649

1638-
auto callOp = [&]() -> CallOp {
1639-
if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1640-
auto name = SymbolRefAttr::get(context, callee->getName());
1641-
return builder.create<CallOp>(loc, funcTy, name, *operands);
1642-
}
1643-
return builder.create<CallOp>(loc, funcTy, *operands);
1644-
}();
1650+
FlatSymbolRefAttr calleeName = convertCalleeName(callInst);
1651+
auto callOp = builder.create<CallOp>(loc, funcTy, calleeName, *operands);
16451652

16461653
// Handle function attributes.
16471654
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
@@ -1725,26 +1732,19 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17251732
unwindArgs)))
17261733
return failure();
17271734

1728-
auto funcTy =
1729-
dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
1735+
auto funcTy = convertFunctionType(invokeInst);
17301736
if (!funcTy)
17311737
return failure();
17321738

1739+
FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
1740+
17331741
// Create the invoke operation. Normal destination block arguments will be
17341742
// added later on to handle the case in which the operation result is
17351743
// included in this list.
1736-
InvokeOp invokeOp;
1737-
if (llvm::Function *callee = invokeInst->getCalledFunction()) {
1738-
invokeOp = builder.create<InvokeOp>(
1739-
loc, funcTy,
1740-
SymbolRefAttr::get(builder.getContext(), callee->getName()),
1741-
*operands, directNormalDest, ValueRange(),
1742-
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1743-
} else {
1744-
invokeOp = builder.create<InvokeOp>(
1745-
loc, funcTy, /*callee=*/nullptr, *operands, directNormalDest,
1746-
ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1747-
}
1744+
auto invokeOp = builder.create<InvokeOp>(
1745+
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
1746+
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1747+
17481748
invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
17491749
if (!invokeInst->getType()->isVoidTy())
17501750
mapValue(inst, invokeOp.getResults().front());

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,21 @@ define void @fence() {
702702
fence syncscope("") seq_cst
703703
ret void
704704
}
705+
706+
; // -----
707+
708+
; CHECK-LABEL: @f
709+
define void @f() personality ptr @__gxx_personality_v0 {
710+
entry:
711+
; CHECK: llvm.invoke @g() to ^bb1 unwind ^bb2 vararg(!llvm.func<void (...)>) : () -> ()
712+
invoke void @g() to label %bb1 unwind label %bb2
713+
bb1:
714+
ret void
715+
bb2:
716+
%0 = landingpad i32 cleanup
717+
unreachable
718+
}
719+
720+
declare void @g(...)
721+
722+
declare i32 @__gxx_personality_v0(...)

0 commit comments

Comments
 (0)