diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 0aa6ca87..765d80bc 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -21,9 +21,10 @@ include "mlir/IR/OpBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" def Cxx_Dialect : Dialect { let name = "cxx"; @@ -103,9 +104,6 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { let regions = (region AnyRegion:$body); - let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs)>]; - let extraClassDeclaration = [{ auto getArgumentTypes() -> ArrayRef { return getFunctionType().getInputs(); } auto getResultTypes() -> ArrayRef { return getFunctionType().getResults(); } @@ -113,7 +111,6 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { }]; let hasCustomAssemblyFormat = 1; - let skipDefaultBuilders = 1; } def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { @@ -128,6 +125,17 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { let hasVerifier = 0; } +def Cxx_CallOp : Cxx_Op<"call"> { + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + + let results = (outs AnyType); +} + def Cxx_AllocaOp : Cxx_Op<"alloca"> { let arguments = (ins); diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index ddc18434..71f0a88d 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -22,8 +22,10 @@ // cxx #include +#include #include #include +#include // mlir #include @@ -79,6 +81,52 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { return allocaOp; } +auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol) + -> mlir::cxx::FuncOp { + if (auto it = funcOps_.find(functionSymbol); it != funcOps_.end()) { + return it->second; + } + + const auto functionType = type_cast(functionSymbol->type()); + const auto returnType = functionType->returnType(); + const auto needsExitValue = !control()->is_void(returnType); + + std::vector inputTypes; + std::vector resultTypes; + + for (auto paramTy : functionType->parameterTypes()) { + inputTypes.push_back(convertType(paramTy)); + } + + if (needsExitValue) { + resultTypes.push_back(convertType(returnType)); + } + + auto funcType = builder_.getFunctionType(inputTypes, resultTypes); + + std::string name; + + if (functionSymbol->hasCLinkage()) { + name = to_string(functionSymbol->name()); + } else { + ExternalNameEncoder encoder; + name = encoder.encode(functionSymbol); + } + + const auto loc = getLocation(functionSymbol->location()); + + auto guard = mlir::OpBuilder::InsertionGuard(builder_); + + builder_.setInsertionPointToStart(module_.getBody()); + + auto func = builder_.create( + loc, name, funcType, mlir::ArrayAttr{}, mlir::ArrayAttr{}); + + funcOps_.insert_or_assign(functionSymbol, func); + + return func; +} + auto Codegen::getLocation(SourceLocation location) -> mlir::Location { auto [filename, line, column] = unit_->tokenStartPosition(location); diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index a4df8cac..fe79a204 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -259,6 +259,9 @@ class Codegen { [[nodiscard]] auto currentBlockMightHaveTerminator() -> bool; + [[nodiscard]] auto findOrCreateFunction(FunctionSymbol* functionSymbol) + -> mlir::cxx::FuncOp; + [[nodiscard]] auto findOrCreateLocal(Symbol* symbol) -> std::optional; @@ -308,6 +311,7 @@ class Codegen { mlir::cxx::AllocaOp exitValue_; std::unordered_map classNames_; std::unordered_map locals_; + std::unordered_map funcOps_; Loop loop_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index ad9748af..739a00ea 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -316,46 +316,20 @@ auto Codegen::DeclarationVisitor::operator()(OpaqueEnumDeclarationAST* ast) auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) -> DeclarationResult { auto functionSymbol = ast->symbol; + + auto func = gen.findOrCreateFunction(functionSymbol); const auto functionType = type_cast(functionSymbol->type()); const auto returnType = functionType->returnType(); const auto needsExitValue = !gen.control()->is_void(returnType); - std::vector inputTypes; - std::vector resultTypes; - - for (auto paramTy : functionType->parameterTypes()) { - inputTypes.push_back(gen.convertType(paramTy)); - } - - if (needsExitValue) { - resultTypes.push_back(gen.convertType(returnType)); - } - - auto funcType = gen.builder_.getFunctionType(inputTypes, resultTypes); - - std::vector path; - for (Symbol* symbol = ast->symbol; symbol; - symbol = symbol->enclosingSymbol()) { - if (!symbol->name()) continue; - path.push_back(to_string(symbol->name())); - } - - std::string name; + auto loc = gen.getLocation(ast->firstSourceLocation()); - if (ast->symbol->hasCLinkage()) { - name = to_string(ast->symbol->name()); - } else { - ExternalNameEncoder encoder; - name = encoder.encode(ast->symbol); + // Add the function body. + auto entryBlock = gen.builder_.createBlock(&func.getBody()); + for (const auto& input : func.getFunctionType().getInputs()) { + entryBlock->addArgument(input, loc); } - auto guard = mlir::OpBuilder::InsertionGuard(gen.builder_); - - const auto loc = gen.getLocation(ast->symbol->location()); - - std::unordered_map locals; - auto func = gen.builder_.create(loc, name, funcType); - auto entryBlock = &func.front(); auto exitBlock = gen.builder_.createBlock(&func.getBody()); mlir::cxx::AllocaOp exitValue; @@ -370,6 +344,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) exitValue = gen.builder_.create(exitValueLoc, ptrType); } + std::unordered_map locals; + // function state std::swap(gen.function_, func); std::swap(gen.exitBlock_, exitBlock); diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 49dc7c41..44903236 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -465,6 +465,43 @@ auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast) -> ExpressionResult { + auto check_direct_call = [&]() -> ExpressionResult { + auto func = ast->baseExpression; + + while (auto nested = ast_cast(func)) { + func = nested->expression; + } + + auto id = ast_cast(func); + if (!id) return {}; + + auto functionSymbol = symbol_cast(id->symbol); + + if (!functionSymbol) return {}; + + auto funcOp = gen.findOrCreateFunction(functionSymbol); + + mlir::SmallVector arguments; + for (auto node : ListView{ast->expressionList}) { + auto value = gen.expression(node); + arguments.push_back(value.value); + } + + auto loc = gen.getLocation(ast->lparenLoc); + + auto functionType = type_cast(functionSymbol->type()); + auto resultType = gen.convertType(functionType->returnType()); + auto op = gen.builder_.create( + loc, resultType, funcOp.getSymName(), arguments, mlir::ArrayAttr{}, + mlir::ArrayAttr{}); + + return {op}; + }; + + if (auto op = check_direct_call(); op.value) { + return op; + } + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); diff --git a/src/mlir/cxx/mlir/cxx_dialect.cc b/src/mlir/cxx/mlir/cxx_dialect.cc index d197d393..899ae6e8 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.cc +++ b/src/mlir/cxx/mlir/cxx_dialect.cc @@ -122,11 +122,6 @@ void CxxDialect::initialize() { addInterface(); } -void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, - FunctionType type, ArrayRef attrs) { - buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); -} - void FuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 0637a5c4..76a7038f 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -81,6 +81,10 @@ class FuncOpLowering : public OpConversionPattern { auto func = rewriter.create(op.getLoc(), op.getSymName(), llvmFuncType); + if (op.getBody().empty()) { + func.setLinkage(LLVM::linkage::Linkage::External); + } + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); @@ -101,6 +105,39 @@ class ReturnOpLowering : public OpConversionPattern { } }; +class CallOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + + SmallVector argumentTypes; + for (auto argType : op.getOperandTypes()) { + auto convertedType = typeConverter->convertType(argType); + if (!convertedType) { + return rewriter.notifyMatchFailure( + op, "failed to convert call argument type"); + } + argumentTypes.push_back(convertedType); + } + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert call result types"); + } + + auto llvmCallOp = rewriter.create( + op.getLoc(), resultType, adaptor.getCallee(), adaptor.getInputs()); + + rewriter.replaceOp(op, llvmCallOp.getResults()); + return success(); + } +}; + class AllocaOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -908,7 +945,8 @@ void CxxToLLVMLoweringPass::runOnOperation() { RewritePatternSet patterns(context); // function operations - patterns.insert(typeConverter, context); + patterns.insert( + typeConverter, context); // memory operations patterns.insert(