diff --git a/src/frontend/cxx/frontend.cc b/src/frontend/cxx/frontend.cc index 9eace9f6..4e02d18a 100644 --- a/src/frontend/cxx/frontend.cc +++ b/src/frontend/cxx/frontend.cc @@ -40,6 +40,7 @@ #ifdef CXX_WITH_MLIR #include #include +#include #endif #include @@ -376,6 +377,11 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool { auto ir = codegen(unit.ast()); + if (failed(lowerToMLIR(ir.module))) { + std::cerr << "cxx: failed to lower C++ AST to MLIR" << std::endl; + return false; + } + mlir::OpPrintingFlags flags; if (cli.opt_g) { flags.enableDebugInfo(true, true); diff --git a/src/mlir/cxx/mlir/CMakeLists.txt b/src/mlir/cxx/mlir/CMakeLists.txt index 388a62af..03aff80a 100644 --- a/src/mlir/cxx/mlir/CMakeLists.txt +++ b/src/mlir/cxx/mlir/CMakeLists.txt @@ -13,6 +13,7 @@ set(SOURCES codegen_units.cc convert_type.cc cxx_dialect.cc + cxx_dialect_conversions.cc ) add_library(cxx-mlir ${SOURCES}) @@ -31,6 +32,8 @@ target_link_libraries(cxx-mlir PUBLIC MLIRFuncDialect MLIRControlFlowDialect MLIRSCFDialect + MLIRPass + MLIRTransforms ) target_compile_definitions(cxx-mlir PUBLIC CXX_WITH_MLIR) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index dbfd0e5b..2557519b 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -75,13 +75,24 @@ def Cxx_ArrayType : Cxx_Type<"Array", "array"> { let assemblyFormat = "`<` $elementType `,` $size `>`"; } -def Cxx_ClassType : Cxx_Type<"Class", "class"> { +def Cxx_ClassType : Cxx_Type<"Class", "class", [MutableType]> { + + let storageClass = "ClassTypeStorage"; + let genStorageClass = 0; + + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; + let parameters = (ins StringRefParameter<"class name", [{ "" }]>:$name, OptionalArrayRefParameter<"mlir::Type">:$body ); - let assemblyFormat = "`<` $name `(` $body `)` `>`"; + let extraClassDeclaration = [{ + static auto getNamed(MLIRContext *context, StringRef name) -> ClassType; + auto setBody(ArrayRef types) -> LogicalResult; + }]; + } // ops @@ -97,8 +108,7 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type, CArg<"ArrayRef", "{}">:$attrs)>]; - let extraClassDeclaration = - [{ + let extraClassDeclaration = [{ auto getArgumentTypes() -> ArrayRef { return getFunctionType().getInputs(); } auto getResultTypes() -> ArrayRef { return getFunctionType().getResults(); } auto getCallableRegion() -> Region* { return &getBody(); } diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 49655632..3498bdc9 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -274,7 +274,7 @@ class Codegen { TranslationUnit* unit_ = nullptr; mlir::Block* exitBlock_ = nullptr; mlir::cxx::AllocaOp exitValue_; - std::unordered_map classNames_; + std::unordered_map classNames_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/convert_type.cc b/src/mlir/cxx/mlir/convert_type.cc index 75f58df0..f01b6212 100644 --- a/src/mlir/cxx/mlir/convert_type.cc +++ b/src/mlir/cxx/mlir/convert_type.cc @@ -268,7 +268,7 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { if (auto it = gen.classNames_.find(classSymbol); it != gen.classNames_.end()) { - return mlir::cxx::ClassType::get(ctx, it->second, {}); + return it->second; } auto name = to_string(classSymbol->name()); @@ -277,7 +277,9 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { name = std::format("$class_{}", loc.index()); } - gen.classNames_[classSymbol] = name; + mlir::cxx::ClassType classType = mlir::cxx::ClassType::getNamed(ctx, name); + + gen.classNames_[classSymbol] = classType; // todo: layout of parent classes, anonymous nested fields, etc. @@ -291,7 +293,7 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { memberTypes.push_back(memberType); } - auto classType = mlir::cxx::ClassType::get(ctx, name, memberTypes); + classType.setBody(memberTypes); return classType; } diff --git a/src/mlir/cxx/mlir/cxx_dialect.cc b/src/mlir/cxx/mlir/cxx_dialect.cc index d6174697..0bb035e2 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.cc +++ b/src/mlir/cxx/mlir/cxx_dialect.cc @@ -32,6 +32,43 @@ namespace mlir::cxx { +struct detail::ClassTypeStorage : public TypeStorage { + public: + using KeyTy = StringRef; + + explicit ClassTypeStorage(const KeyTy &key) : name_(key) {} + + auto getName() -> StringRef const { return name_; } + auto getBody() const -> ArrayRef { return body_; } + + auto operator==(const KeyTy &key) const -> bool { return name_ == key; }; + + static auto hashKey(const KeyTy &key) -> llvm::hash_code { + return llvm::hash_value(key); + } + + static ClassTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + ClassTypeStorage(allocator.copyInto(key)); + } + + auto mutate(TypeStorageAllocator &allocator, ArrayRef body) + -> LogicalResult { + if (isInitialized_) return success(body == getBody()); + + isInitialized_ = true; + body_ = allocator.copyInto(body); + + return success(); + } + + private: + StringRef name_; + ArrayRef body_; + bool isInitialized_ = false; +}; + namespace { struct CxxGenerateAliases : public OpAsmDialectInterface { @@ -39,29 +76,29 @@ struct CxxGenerateAliases : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; auto getAlias(Type type, raw_ostream &os) const -> AliasResult override { - if (auto intType = mlir::dyn_cast(type)) { + if (auto intType = dyn_cast(type)) { os << 'i' << intType.getWidth() << (intType.getIsSigned() ? 's' : 'u'); return AliasResult::FinalAlias; } - if (auto floatType = mlir::dyn_cast(type)) { + if (auto floatType = dyn_cast(type)) { os << 'f' << floatType.getWidth(); return AliasResult::FinalAlias; } - if (auto classType = mlir::dyn_cast(type)) { + if (auto classType = dyn_cast(type)) { if (!classType.getBody().empty()) { os << "class_" << classType.getName(); return AliasResult::FinalAlias; } } - if (mlir::isa(type)) { + if (isa(type)) { os << "void"; return AliasResult::FinalAlias; } - if (mlir::isa(type)) { + if (isa(type)) { os << "bool"; return AliasResult::FinalAlias; } @@ -85,32 +122,73 @@ void CxxDialect::initialize() { addInterface(); } -void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - llvm::StringRef name, mlir::FunctionType type, - llvm::ArrayRef attrs) { +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs) { buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); } -void FuncOp::print(mlir::OpAsmPrinter &p) { - mlir::function_interface_impl::printFunctionOp( +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } -auto FuncOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) - -> mlir::ParseResult { +auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult { auto funcTypeBuilder = - [](mlir::Builder &builder, llvm::ArrayRef argTypes, - llvm::ArrayRef results, - mlir::function_interface_impl::VariadicFlag, + [](Builder &builder, llvm::ArrayRef argTypes, + ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; - return mlir::function_interface_impl::parseFunctionOp( + return function_interface_impl::parseFunctionOp( parser, result, false, getFunctionTypeAttrName(result.name), funcTypeBuilder, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } +auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType { + return Base::get(context, name); +} + +auto ClassType::setBody(llvm::ArrayRef body) -> LogicalResult { + Base::mutate(body); +} + +void ClassType::print(AsmPrinter &p) const { + FailureOr cyclicPrint; + + p << "<"; + cyclicPrint = p.tryStartCyclicPrint(*this); + + p << '"'; + llvm::printEscapedString(getName(), p.getStream()); + p << '"'; + + if (failed(cyclicPrint)) { + p << '>'; + return; + } + + p << ", "; + + p << '('; + llvm::interleaveComma(getBody(), p.getStream(), + [&](Type subtype) { p << subtype; }); + p << ')'; + + p << '>'; +} + +auto ClassType::parse(AsmParser &parser) -> Type { + // todo: implement parsing for ClassType + return {}; +} + +auto ClassType::getName() const -> StringRef { return getImpl()->getName(); } + +auto ClassType::getBody() const -> ArrayRef { + return getImpl()->getBody(); +} + } // namespace mlir::cxx #include diff --git a/src/mlir/cxx/mlir/cxx_dialect.h b/src/mlir/cxx/mlir/cxx_dialect.h index df01cea7..eec85e4b 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.h +++ b/src/mlir/cxx/mlir/cxx_dialect.h @@ -38,6 +38,12 @@ #pragma GCC diagnostic pop #endif +namespace mlir::cxx::detail { + +struct ClassTypeStorage; + +} + #include #define GET_TYPEDEF_CLASSES diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc new file mode 100644 index 00000000..368ae621 --- /dev/null +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Roberto Raggi +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +// cxx +#include + +// mlir +#include +#include +#include +#include + +namespace mlir { + +namespace { + +class FuncOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto func = rewriter.create(op.getLoc(), op.getName(), + op.getFunctionType()); + + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); + } +}; + +class ReturnOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + rewriter.replaceOpWithNewOp(op, + adaptor.getOperands()); + return success(); + } +}; + +class CxxToFuncLoweringPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CxxToFuncLoweringPass) + + auto getArgument() const -> StringRef override { return "cxx-to-func"; } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() final; +}; + +} // namespace + +void CxxToFuncLoweringPass::runOnOperation() { + ConversionTarget target(getContext()); + + // good + target.addLegalDialect(); + + // illegal + target.addIllegalOp(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + + auto module = getOperation(); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace mlir + +auto cxx::createLowerToFuncPass() -> std::unique_ptr { + return std::make_unique(); +} + +auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult { + mlir::PassManager pm(module->getName()); + + pm.addPass(cxx::createLowerToFuncPass()); + + if (mlir::failed(pm.run(module))) { + return mlir::failure(); + } + + return mlir::success(); +} diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.h b/src/mlir/cxx/mlir/cxx_dialect_conversions.h new file mode 100644 index 00000000..2618d9cb --- /dev/null +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.h @@ -0,0 +1,32 @@ +// Copyright (c) 2025 Roberto Raggi +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include +#include + +namespace cxx { + +[[nodiscard]] auto createLowerToFuncPass() -> std::unique_ptr; + +[[nodiscard]] auto lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult; + +} // namespace cxx