From 10135855aeb9bfca008dd1dfb62507f75ef52785 Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Tue, 12 Aug 2025 21:06:59 +0200 Subject: [PATCH] Add ops and lowering for variadic function calls --- src/mlir/cxx/mlir/CxxOps.td | 23 +++++++++++---- src/mlir/cxx/mlir/codegen.cc | 4 ++- src/mlir/cxx/mlir/codegen_expressions.cc | 8 ++++-- src/mlir/cxx/mlir/convert_type.cc | 11 ++++++- src/mlir/cxx/mlir/cxx_dialect.cc | 11 +++++-- src/mlir/cxx/mlir/cxx_dialect_conversions.cc | 30 +++++++++++++++++++- 6 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 578df640..bd0846cb 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -94,11 +94,25 @@ def Cxx_ClassType : Cxx_Type<"Class", "class", [MutableType]> { } +def Cxx_FunctionType : Cxx_Type<"Function", "function"> { + let parameters = (ins + ArrayRefParameter<"mlir::Type">:$inputs + , ArrayRefParameter<"mlir::Type">:$results + , "bool":$variadic + ); + + let assemblyFormat = "`<` $inputs `,` $results `,` $variadic `>`"; + + let extraClassDeclaration = [{ + auto clone(mlir::TypeRange inputs, mlir::TypeRange results) const -> FunctionType; + }]; +} + // ops def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type, + TypeAttrOf:$function_type, OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs); @@ -134,10 +148,9 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { def Cxx_CallOp : Cxx_Op<"call"> { let arguments = (ins - FlatSymbolRefAttr:$callee, - Variadic:$inputs, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs + FlatSymbolRefAttr:$callee + , Variadic:$inputs + , OptionalAttr>:$var_callee_type ); let results = (outs Optional:$result); diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index a050510e..fd29345d 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -117,7 +117,9 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol) resultTypes.push_back(convertType(returnType)); } - auto funcType = builder_.getFunctionType(inputTypes, resultTypes); + auto funcType = + mlir::cxx::FunctionType::get(builder_.getContext(), inputTypes, + resultTypes, functionType->isVariadic()); std::string name; diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index aaa8b57f..7f00e08e 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -522,8 +522,12 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast) } auto op = gen.builder_.create( - loc, resultTypes, funcOp.getSymName(), arguments, mlir::ArrayAttr{}, - mlir::ArrayAttr{}); + loc, resultTypes, funcOp.getSymName(), arguments, mlir::TypeAttr{}); + + if (functionType->isVariadic()) { + op.setVarCalleeType( + cast(gen.convertType(functionType))); + } return ExpressionResult{op.getResult()}; }; diff --git a/src/mlir/cxx/mlir/convert_type.cc b/src/mlir/cxx/mlir/convert_type.cc index 0306ea23..392eb2e8 100644 --- a/src/mlir/cxx/mlir/convert_type.cc +++ b/src/mlir/cxx/mlir/convert_type.cc @@ -258,7 +258,16 @@ auto Codegen::ConvertType::operator()(const RvalueReferenceType* type) } auto Codegen::ConvertType::operator()(const FunctionType* type) -> mlir::Type { - return getExprType(); + mlir::SmallVector inputs; + for (auto argType : type->parameterTypes()) { + inputs.push_back(gen.convertType(argType)); + } + mlir::SmallVector results; + if (!control()->is_void(type->returnType())) { + results.push_back(gen.convertType(type->returnType())); + } + return gen.builder_.getType(inputs, results, + type->isVariadic()); } auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { diff --git a/src/mlir/cxx/mlir/cxx_dialect.cc b/src/mlir/cxx/mlir/cxx_dialect.cc index e0b4b026..7ed2bf9c 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.cc +++ b/src/mlir/cxx/mlir/cxx_dialect.cc @@ -125,9 +125,10 @@ void CxxDialect::initialize() { } void FuncOp::print(OpAsmPrinter &p) { + const auto isVariadic = getFunctionType().getVariadic(); function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); + p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()); } auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult { @@ -160,6 +161,12 @@ auto StoreOp::verify() -> LogicalResult { return success(); } +auto FunctionType::clone(TypeRange inputs, TypeRange results) const + -> FunctionType { + return get(getContext(), llvm::to_vector(inputs), llvm::to_vector(results), + getVariadic()); +} + auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType { return Base::get(context, name); } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 61512b80..c738f2e6 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -80,7 +80,7 @@ class FuncOpLowering : public OpConversionPattern { ? LLVM::LLVMVoidType::get(getContext()) : resultTypes.front(); - const auto isVarArg = false; + const auto isVarArg = funcType.getVariadic(); auto llvmFuncType = LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg); @@ -159,6 +159,12 @@ class CallOpLowering : public OpConversionPattern { auto llvmCallOp = rewriter.create( op.getLoc(), resultTypes, adaptor.getCallee(), adaptor.getInputs()); + if (op.getVarCalleeType().has_value()) { + auto varCalleeType = + typeConverter->convertType(op.getVarCalleeType().value()); + llvmCallOp.setVarCalleeType(cast(varCalleeType)); + } + rewriter.replaceOp(op, llvmCallOp); return success(); } @@ -1251,6 +1257,28 @@ void CxxToLLVMLoweringPass::runOnOperation() { return LLVM::LLVMArrayType::get(elementType, size); }); + typeConverter.addConversion([&](cxx::FunctionType type) -> Type { + SmallVector inputs; + for (auto argType : type.getInputs()) { + auto convertedType = typeConverter.convertType(argType); + inputs.push_back(convertedType); + } + SmallVector results; + for (auto resultType : type.getResults()) { + auto convertedType = typeConverter.convertType(resultType); + results.push_back(convertedType); + } + if (results.size() > 1) { + return {}; + } + if (results.empty()) { + results.push_back(LLVM::LLVMVoidType::get(type.getContext())); + } + auto context = type.getContext(); + return LLVM::LLVMFunctionType::get(context, results.front(), inputs, + type.getVariadic()); + }); + DenseMap convertedClassTypes; typeConverter.addConversion([&](cxx::ClassType type) -> Type { if (auto it = convertedClassTypes.find(type);