Skip to content

Commit 8a71608

Browse files
committed
[MLIR][LLVMIR] Allow calling and invoking aliases
1 parent a49705f commit 8a71608

File tree

3 files changed

+60
-20
lines changed

3 files changed

+60
-20
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
6565
Operation *op,
6666
SymbolTableCollection &symbolTable) {
6767
StringRef name = symbol.getValue();
68-
auto func =
69-
symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
70-
if (!func)
71-
return op->emitOpError("'")
72-
<< name << "' does not reference a valid LLVM function";
73-
if (func.isExternal())
74-
return op->emitOpError("'") << name << "' does not have a definition";
75-
return success();
68+
auto *symOp = symbolTable.lookupNearestSymbolFrom(op, symbol.getAttr());
69+
if (auto func = dyn_cast<LLVMFuncOp>(symOp)) {
70+
if (func.isExternal())
71+
return op->emitOpError("'") << name << "' does not have a definition";
72+
return success();
73+
}
74+
if (auto alias = dyn_cast<AliasOp>(symOp))
75+
return success();
76+
return op->emitOpError("'")
77+
<< name << "' does not reference a valid LLVM function";
7678
}
7779

7880
/// Returns a boolean type that has the same shape as `type`. It supports both
@@ -1184,10 +1186,12 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
11841186
fnType = fn.getFunctionType();
11851187
} else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
11861188
fnType = ifunc.getIFuncType();
1189+
} else if (auto alias = dyn_cast<AliasOp>(callee)) {
1190+
fnType = alias.getAliasType();
11871191
} else {
1188-
return emitOpError()
1189-
<< "'" << calleeName.getValue()
1190-
<< "' does not reference a valid LLVM function or IFunc";
1192+
return emitOpError()
1193+
<< "'" << calleeName.getValue()
1194+
<< "' does not reference a valid LLVM function or IFunc";
11911195
}
11921196
}
11931197

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -426,13 +426,20 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
426426
moduleTranslation.lookupFunction(attr.getValue())) {
427427
call = builder.CreateCall(function, operandsRef, opBundles);
428428
} else {
429-
Operation *moduleOp = parentLLVMModule(&opInst);
430-
Operation *ifuncOp =
431-
moduleTranslation.symbolTable().lookupSymbolIn(moduleOp, attr);
432-
llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
433-
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
434-
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
435-
call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
429+
auto &st = moduleTranslation.symbolTable().getSymbolTable(callOp);
430+
auto alias = st.lookup<AliasOp>(attr.getValue());
431+
if (alias) {
432+
llvm::GlobalValue *callee = moduleTranslation.lookupAlias(alias);
433+
llvm::FunctionType *ftype = llvm::cast<llvm::FunctionType>(moduleTranslation.convertType(alias.getAliasType()));
434+
call = builder.CreateCall(ftype, callee, operandsRef, opBundles);
435+
} else {
436+
Operation *moduleOp = parentLLVMModule(&opInst);
437+
Operation *ifuncOp = st.lookupSymbolIn(moduleOp, attr);
438+
llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
439+
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
440+
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
441+
call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
442+
}
436443
}
437444
} else {
438445
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
@@ -554,8 +561,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
554561
ArrayRef<llvm::Value *> operandsRef(operands);
555562
llvm::InvokeInst *result;
556563
if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
564+
llvm::GlobalValue *callee;
565+
llvm::FunctionType *ftype;
566+
if (auto *func = moduleTranslation.lookupFunction(attr.getValue())) {
567+
ftype = func->getFunctionType();
568+
callee = func;
569+
} else {
570+
auto &st = moduleTranslation.symbolTable().getSymbolTable(invOp);
571+
auto alias = st.lookup<AliasOp>(attr.getValue());
572+
callee = moduleTranslation.lookupAlias(alias);
573+
ftype = llvm::cast<llvm::FunctionType>(moduleTranslation.convertType(alias.getAliasType()));
574+
}
557575
result = builder.CreateInvoke(
558-
moduleTranslation.lookupFunction(attr.getValue()),
576+
ftype, callee,
559577
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
560578
moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
561579
opBundles);

mlir/test/Dialect/LLVMIR/alias.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ llvm.func internal @callee() -> !llvm.ptr attributes {dso_local} {
55
llvm.return %0 : !llvm.ptr
66
}
77

8-
llvm.mlir.alias external @foo_alias : !llvm.ptr {
8+
llvm.mlir.alias external @foo_alias : !llvm.func<ptr ()> {
99
%0 = llvm.mlir.addressof @callee : !llvm.ptr
1010
llvm.return %0 : !llvm.ptr
1111
}
@@ -15,6 +15,24 @@ llvm.mlir.alias external @_ZTV1D : !llvm.struct<(array<3 x ptr>)> {
1515
llvm.return %0 : !llvm.ptr
1616
}
1717

18+
llvm.func internal @caller() -> !llvm.ptr attributes {dso_local} {
19+
%0 = llvm.call @foo_alias() : () -> !llvm.ptr
20+
llvm.return %0 : !llvm.ptr
21+
}
22+
23+
llvm.func @__gxx_personality_v0(...) -> i32
24+
25+
llvm.func internal @invoker() -> !llvm.ptr attributes { dso_local, personality = @__gxx_personality_v0 } {
26+
%0 = llvm.mlir.constant("\01") : !llvm.array<1 x i8>
27+
%1 = llvm.mlir.zero : !llvm.ptr
28+
%2 = llvm.invoke @foo_alias() to ^bb1 unwind ^bb2 : () -> !llvm.ptr
29+
^bb1:
30+
llvm.return %1 : !llvm.ptr
31+
^bb2:
32+
%3 = llvm.landingpad cleanup (catch %1 : !llvm.ptr) (catch %1 : !llvm.ptr) (filter %0 : !llvm.array<1 x i8>) : !llvm.struct<(ptr, i32)>
33+
llvm.return %1 : !llvm.ptr
34+
}
35+
1836
// CHECK: llvm.mlir.alias external @foo_alias : !llvm.ptr {
1937
// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
2038
// CHECK: llvm.return %[[ADDR]] : !llvm.ptr

0 commit comments

Comments
 (0)