Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1077,4 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
/// Folded into from LLVM::ZeroOp.
def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;

//===----------------------------------------------------------------------===//
// TailCallKindAttr
//===----------------------------------------------------------------------===//

def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
let parameters = (ins "TailCallKind":$TailCallKind);
let assemblyFormat = "`<` $TailCallKind `>`";
}

#endif // LLVMIR_ATTRDEFS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
// --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
using tailcallkind::TailCallKind;
using linkage::Linkage;
} // namespace LLVM
} // namespace mlir
Expand Down
29 changes: 29 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,35 @@ def CConv : DialectAttr<
"::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
}

//===----------------------------------------------------------------------===//
// TailCallKind
//===----------------------------------------------------------------------===//

def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;

def TailCallKindEnum : LLVM_EnumAttr<
"TailCallKind",
"::llvm::CallInst::TailCallKind",
"Tail Call Kind",
[TailCallKindNone, TailCallKindNoTailCall,
TailCallKindMustTail, TailCallKindTail]> {
let cppNamespace = "::mlir::LLVM::tailcallkind";
}

def TailCallKind : DialectAttr<
LLVM_Dialect,
CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
"LLVM Calling Convention specification"> {
let storageType = "::mlir::LLVM::TailCallKindAttr";
let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
let convertFromStorage = "$_self.getTailCallKind()";
let constBuilderCall =
"::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
}

//===----------------------------------------------------------------------===//
// DIEmissionKind
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
Expand Down
24 changes: 18 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

Expand Down Expand Up @@ -197,6 +198,7 @@ struct EnumTraits {};
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace

Expand Down Expand Up @@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
build(builder, state, results,
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
Expand All @@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
/*access_groups=*/nullptr,
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}

Expand All @@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
Expand All @@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
Expand Down Expand Up @@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
if (getCConv() != LLVM::CConv::C)
p << stringifyCConv(getCConv()) << ' ';

if(getTailCallKind() != LLVM::TailCallKind::None)
p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';

// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
if (isDirect)
Expand All @@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
p << " vararg(" << calleeType << ")";

p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getCConvAttrName(), "callee", "callee_type"});
{getCConvAttrName(), "callee", "callee_type",
getTailCallKindAttrName()});

p << " : ";
if (!isDirect)
Expand Down Expand Up @@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}

// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-arg-func-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
Expand All @@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
parser, result, LLVM::CConv::C)));

result.addAttribute(
getTailCallKindAttrName(result.name),
TailCallKindAttr::get(parser.getContext(),
parseOptionalLLVMKeyword<TailCallKind>(
parser, result, LLVM::TailCallKind::None)));

// Parse a function pointer for indirect calls.
if (parseOptionalCallFuncPtr(parser, operands))
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front());
}
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
moduleTranslation.setTBAAMetadata(callOp, call);
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
callOp = builder.create<CallOp>(loc, funcTy, operands);
}
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
llvm.return
}

// CHECK: llvm.func @tail_call_target() -> i32
llvm.func @tail_call_target() -> i32

// CHECK-LABEL: @test_none
llvm.func @test_none() -> i32 {
// CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
%0 = llvm.call none @tail_call_target() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_default
llvm.func @test_default() -> i32 {
// CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
%0 = llvm.call @tail_call_target() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_musttail
llvm.func @test_musttail() -> i32 {
// CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
%0 = llvm.call musttail @tail_call_target() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_tail
llvm.func @test_tail() -> i32 {
// CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
%0 = llvm.call tail @tail_call_target() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_notail
llvm.func @test_notail() -> i32 {
// CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
%0 = llvm.call notail @tail_call_target() : () -> i32
llvm.return %0 : i32
}
39 changes: 39 additions & 0 deletions mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// CHECK: declare i32 @foo()
llvm.func @foo() -> i32

// CHECK-LABEL: @test_none
llvm.func @test_none() -> i32 {
// CHECK-NEXT: call i32 @foo()
%0 = llvm.call none @foo() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_default
llvm.func @test_default() -> i32 {
// CHECK-NEXT: call i32 @foo()
%0 = llvm.call @foo() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_musttail
llvm.func @test_musttail() -> i32 {
// CHECK-NEXT: musttail call i32 @foo()
%0 = llvm.call musttail @foo() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_tail
llvm.func @test_tail() -> i32 {
// CHECK-NEXT: tail call i32 @foo()
%0 = llvm.call tail @foo() : () -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: @test_notail
llvm.func @test_notail() -> i32 {
// CHECK-NEXT: notail call i32 @foo()
%0 = llvm.call notail @foo() : () -> i32
llvm.return %0 : i32
}
35 changes: 35 additions & 0 deletions mlir/test/Target/LLVMIR/Import/tail-kind.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK: llvm.func @tailkind()
declare void @tailkind()

; CHECK-LABEL: @call_tailkind
define void @call_tailkind() {
; CHECK: llvm.call musttail @tailkind()
musttail call void @tailkind()
ret void
}

; // -----

; CHECK: llvm.func @tailkind()
declare void @tailkind()

; CHECK-LABEL: @call_tailkind
define void @call_tailkind() {
; CHECK: llvm.call tail @tailkind()
tail call void @tailkind()
ret void
}

; // -----

; CHECK: llvm.func @tailkind()
declare void @tailkind()

; CHECK-LABEL: @call_tailkind
define void @call_tailkind() {
; CHECK: llvm.call notail @tailkind()
notail call void @tailkind()
ret void
}