Skip to content

Commit 0124fef

Browse files
committed
[MLIR][LLVMIR] Allow calling and invoking aliases
1 parent 9f535c4 commit 0124fef

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,16 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
138138
Operation *op,
139139
SymbolTableCollection &symbolTable) {
140140
StringRef name = symbol.getValue();
141-
auto func =
142-
symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
143-
if (!func)
144-
return op->emitOpError("'")
145-
<< name << "' does not reference a valid LLVM function";
146-
if (func.isExternal())
147-
return op->emitOpError("'") << name << "' does not have a definition";
148-
return success();
141+
auto *symOp = symbolTable.lookupNearestSymbolFrom(op, symbol.getAttr());
142+
if (auto func = dyn_cast<LLVMFuncOp>(symOp)) {
143+
if (func.isExternal())
144+
return op->emitOpError("'") << name << "' does not have a definition";
145+
return success();
146+
}
147+
if (auto alias = dyn_cast<AliasOp>(symOp))
148+
return success();
149+
return op->emitOpError("'")
150+
<< name << "' does not reference a valid LLVM function";
149151
}
150152

151153
/// Returns a boolean type that has the same shape as `type`. It supports both
@@ -1241,14 +1243,16 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
12411243
return emitOpError()
12421244
<< "'" << calleeName.getValue()
12431245
<< "' does not reference a symbol in the current scope";
1244-
auto fn = dyn_cast<LLVMFuncOp>(callee);
1245-
if (!fn)
1246+
if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
1247+
if (failed(verifyCallOpDebugInfo(*this, fn)))
1248+
return failure();
1249+
fnType = fn.getFunctionType();
1250+
} else if (auto alias = dyn_cast<AliasOp>(callee)) {
1251+
fnType = alias.getAliasType();
1252+
} else {
12461253
return emitOpError() << "'" << calleeName.getValue()
12471254
<< "' does not reference a valid LLVM function";
1248-
1249-
if (failed(verifyCallOpDebugInfo(*this, fn)))
1250-
return failure();
1251-
fnType = fn.getFunctionType();
1255+
}
12521256
}
12531257

12541258
LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
302302
ArrayRef<llvm::Value *> operandsRef(operands);
303303
llvm::CallInst *call;
304304
if (auto attr = callOp.getCalleeAttr()) {
305+
llvm::GlobalValue *callee;
306+
llvm::FunctionType *ftype;
307+
if (auto *func = moduleTranslation.lookupFunction(attr.getValue())) {
308+
ftype = func->getFunctionType();
309+
callee = func;
310+
} else {
311+
auto &st = moduleTranslation.symbolTable().getSymbolTable(callOp);
312+
auto alias = st.lookup<AliasOp>(attr.getValue());
313+
callee = moduleTranslation.lookupAlias(alias);
314+
ftype = llvm::cast<llvm::FunctionType>(moduleTranslation.convertType(alias.getAliasType()));
315+
}
305316
call =
306-
builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
307-
operandsRef, opBundles);
317+
builder.CreateCall(ftype, callee, operandsRef, opBundles);
308318
} else {
309319
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
310320
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
@@ -415,8 +425,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
415425
ArrayRef<llvm::Value *> operandsRef(operands);
416426
llvm::InvokeInst *result;
417427
if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
428+
llvm::GlobalValue *callee;
429+
llvm::FunctionType *ftype;
430+
if (auto *func = moduleTranslation.lookupFunction(attr.getValue())) {
431+
ftype = func->getFunctionType();
432+
callee = func;
433+
} else {
434+
auto &st = moduleTranslation.symbolTable().getSymbolTable(invOp);
435+
auto alias = st.lookup<AliasOp>(attr.getValue());
436+
callee = moduleTranslation.lookupAlias(alias);
437+
ftype = llvm::cast<llvm::FunctionType>(moduleTranslation.convertType(alias.getAliasType()));
438+
}
418439
result = builder.CreateInvoke(
419-
moduleTranslation.lookupFunction(attr.getValue()),
440+
ftype, callee,
420441
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
421442
moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
422443
opBundles);

mlir/test/Dialect/LLVMIR/alias.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ llvm.func internal @callee() -> !llvm.ptr attributes {dso_local} {
55
llvm.return %0 : !llvm.ptr
66
}
77

8-
llvm.mlir.alias external @foo_alias : !llvm.ptr {
8+
llvm.mlir.alias external @foo_alias : !llvm.func<ptr ()> {
99
%0 = llvm.mlir.addressof @callee : !llvm.ptr
1010
llvm.return %0 : !llvm.ptr
1111
}
@@ -15,6 +15,24 @@ llvm.mlir.alias external @_ZTV1D : !llvm.struct<(array<3 x ptr>)> {
1515
llvm.return %0 : !llvm.ptr
1616
}
1717

18+
llvm.func internal @caller() -> !llvm.ptr attributes {dso_local} {
19+
%0 = llvm.call @foo_alias() : () -> !llvm.ptr
20+
llvm.return %0 : !llvm.ptr
21+
}
22+
23+
llvm.func @__gxx_personality_v0(...) -> i32
24+
25+
llvm.func internal @invoker() -> !llvm.ptr attributes { dso_local, personality = @__gxx_personality_v0 } {
26+
%0 = llvm.mlir.constant("\01") : !llvm.array<1 x i8>
27+
%1 = llvm.mlir.zero : !llvm.ptr
28+
%2 = llvm.invoke @foo_alias() to ^bb1 unwind ^bb2 : () -> !llvm.ptr
29+
^bb1:
30+
llvm.return %1 : !llvm.ptr
31+
^bb2:
32+
%3 = llvm.landingpad cleanup (catch %1 : !llvm.ptr) (catch %1 : !llvm.ptr) (filter %0 : !llvm.array<1 x i8>) : !llvm.struct<(ptr, i32)>
33+
llvm.return %1 : !llvm.ptr
34+
}
35+
1836
// CHECK: llvm.mlir.alias external @foo_alias : !llvm.ptr {
1937
// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
2038
// CHECK: llvm.return %[[ADDR]] : !llvm.ptr

0 commit comments

Comments
 (0)