Skip to content

Commit 646b231

Browse files
committed
Add ops and lowering for variadic function calls
1 parent 68f981c commit 646b231

File tree

6 files changed

+75
-12
lines changed

6 files changed

+75
-12
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,25 @@ def Cxx_ClassType : Cxx_Type<"Class", "class", [MutableType]> {
9494

9595
}
9696

97+
def Cxx_FunctionType : Cxx_Type<"Function", "function"> {
98+
let parameters = (ins
99+
ArrayRefParameter<"mlir::Type">:$inputs
100+
, ArrayRefParameter<"mlir::Type">:$results
101+
, "bool":$variadic
102+
);
103+
104+
let assemblyFormat = "`<` $inputs `,` $results `,` $variadic `>`";
105+
106+
let extraClassDeclaration = [{
107+
auto clone(mlir::TypeRange inputs, mlir::TypeRange results) const -> FunctionType;
108+
}];
109+
}
110+
97111
// ops
98112

99113
def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
100114
let arguments = (ins SymbolNameAttr:$sym_name,
101-
TypeAttrOf<FunctionType>:$function_type,
115+
TypeAttrOf<Cxx_FunctionType>:$function_type,
102116
OptionalAttr<DictArrayAttr>:$arg_attrs,
103117
OptionalAttr<DictArrayAttr>:$res_attrs);
104118

@@ -134,10 +148,9 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
134148

135149
def Cxx_CallOp : Cxx_Op<"call"> {
136150
let arguments = (ins
137-
FlatSymbolRefAttr:$callee,
138-
Variadic<AnyType>:$inputs,
139-
OptionalAttr<DictArrayAttr>:$arg_attrs,
140-
OptionalAttr<DictArrayAttr>:$res_attrs
151+
FlatSymbolRefAttr:$callee
152+
, Variadic<AnyType>:$inputs
153+
, OptionalAttr<TypeAttrOf<Cxx_FunctionType>>:$var_callee_type
141154
);
142155

143156
let results = (outs Optional<AnyType>:$result);

src/mlir/cxx/mlir/codegen.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
117117
resultTypes.push_back(convertType(returnType));
118118
}
119119

120-
auto funcType = builder_.getFunctionType(inputTypes, resultTypes);
120+
auto funcType =
121+
mlir::cxx::FunctionType::get(builder_.getContext(), inputTypes,
122+
resultTypes, functionType->isVariadic());
121123

122124
std::string name;
123125

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,12 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
522522
}
523523

524524
auto op = gen.builder_.create<mlir::cxx::CallOp>(
525-
loc, resultTypes, funcOp.getSymName(), arguments, mlir::ArrayAttr{},
526-
mlir::ArrayAttr{});
525+
loc, resultTypes, funcOp.getSymName(), arguments, mlir::TypeAttr{});
526+
527+
if (functionType->isVariadic()) {
528+
op.setVarCalleeType(
529+
cast<mlir::cxx::FunctionType>(gen.convertType(functionType)));
530+
}
527531

528532
return ExpressionResult{op.getResult()};
529533
};

src/mlir/cxx/mlir/convert_type.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,16 @@ auto Codegen::ConvertType::operator()(const RvalueReferenceType* type)
258258
}
259259

260260
auto Codegen::ConvertType::operator()(const FunctionType* type) -> mlir::Type {
261-
return getExprType();
261+
mlir::SmallVector<mlir::Type> inputs;
262+
for (auto argType : type->parameterTypes()) {
263+
inputs.push_back(gen.convertType(argType));
264+
}
265+
mlir::SmallVector<mlir::Type> results;
266+
if (!control()->is_void(type->returnType())) {
267+
results.push_back(gen.convertType(type->returnType()));
268+
}
269+
return gen.builder_.getType<mlir::cxx::FunctionType>(inputs, results,
270+
type->isVariadic());
262271
}
263272

264273
auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {

src/mlir/cxx/mlir/cxx_dialect.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ void CxxDialect::initialize() {
125125
}
126126

127127
void FuncOp::print(OpAsmPrinter &p) {
128+
const auto isVariadic = getFunctionType().getVariadic();
128129
function_interface_impl::printFunctionOp(
129-
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
130-
getArgAttrsAttrName(), getResAttrsAttrName());
130+
p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(),
131+
getResAttrsAttrName());
131132
}
132133

133134
auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult {
@@ -160,6 +161,12 @@ auto StoreOp::verify() -> LogicalResult {
160161
return success();
161162
}
162163

164+
auto FunctionType::clone(TypeRange inputs, TypeRange results) const
165+
-> FunctionType {
166+
return get(getContext(), llvm::to_vector(inputs), llvm::to_vector(results),
167+
getVariadic());
168+
}
169+
163170
auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType {
164171
return Base::get(context, name);
165172
}

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
8080
? LLVM::LLVMVoidType::get(getContext())
8181
: resultTypes.front();
8282

83-
const auto isVarArg = false;
83+
const auto isVarArg = funcType.getVariadic();
8484

8585
auto llvmFuncType =
8686
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);
@@ -159,6 +159,12 @@ class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
159159
auto llvmCallOp = rewriter.create<LLVM::CallOp>(
160160
op.getLoc(), resultTypes, adaptor.getCallee(), adaptor.getInputs());
161161

162+
if (op.getVarCalleeType().has_value()) {
163+
auto varCalleeType =
164+
typeConverter->convertType(op.getVarCalleeType().value());
165+
llvmCallOp.setVarCalleeType(cast<LLVM::LLVMFunctionType>(varCalleeType));
166+
}
167+
162168
rewriter.replaceOp(op, llvmCallOp);
163169
return success();
164170
}
@@ -1251,6 +1257,28 @@ void CxxToLLVMLoweringPass::runOnOperation() {
12511257
return LLVM::LLVMArrayType::get(elementType, size);
12521258
});
12531259

1260+
typeConverter.addConversion([&](cxx::FunctionType type) -> Type {
1261+
SmallVector<Type> inputs;
1262+
for (auto argType : type.getInputs()) {
1263+
auto convertedType = typeConverter.convertType(argType);
1264+
inputs.push_back(convertedType);
1265+
}
1266+
SmallVector<Type> results;
1267+
for (auto resultType : type.getResults()) {
1268+
auto convertedType = typeConverter.convertType(resultType);
1269+
results.push_back(convertedType);
1270+
}
1271+
if (results.size() > 1) {
1272+
return {};
1273+
}
1274+
if (results.empty()) {
1275+
results.push_back(LLVM::LLVMVoidType::get(type.getContext()));
1276+
}
1277+
auto context = type.getContext();
1278+
return LLVM::LLVMFunctionType::get(context, results.front(), inputs,
1279+
type.getVariadic());
1280+
});
1281+
12541282
DenseMap<cxx::ClassType, Type> convertedClassTypes;
12551283
typeConverter.addConversion([&](cxx::ClassType type) -> Type {
12561284
if (auto it = convertedClassTypes.find(type);

0 commit comments

Comments
 (0)