diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 0d317992..578df640 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -113,6 +113,15 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { let hasCustomAssemblyFormat = 1; } +def Cxx_GlobalOp : Cxx_Op<"global"> { + let arguments = (ins + TypeAttrOf:$global_type + , UnitAttr:$constant + , SymbolNameAttr:$sym_name + , OptionalAttr:$value + ); +} + def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { let arguments = (ins Variadic:$input); @@ -121,8 +130,6 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> { let extraClassDeclaration = [{ bool hasOperand() { return getNumOperands() != 0; } }]; - - let hasVerifier = 0; } def Cxx_CallOp : Cxx_Op<"call"> { @@ -160,6 +167,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> { let results = (outs Cxx_PointerType:$result); } +def Cxx_AddressOfOp : Cxx_Op<"addressof"> { + let arguments = (ins FlatSymbolRefAttr:$sym_name); + + let results = (outs AnyType:$result); +} + def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [ Pure ]> { diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index b901edb2..a050510e 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -56,6 +56,15 @@ auto Codegen::newBlock() -> mlir::Block* { return newBlock; } +auto Codegen::newUniqueSymbolName(std::string_view prefix) -> std::string { + auto& uniqueName = uniqueSymbolNames_[prefix]; + if (uniqueName == 0) { + uniqueName = 1; + return std::format("{}{}", prefix, uniqueName); + } + return std::format("{}{}", prefix, ++uniqueName); +} + void Codegen::branch(mlir::Location loc, mlir::Block* block, mlir::ValueRange operands) { if (currentBlockMightHaveTerminator()) return; diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 8702fee0..3800c7aa 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -269,6 +269,10 @@ class Codegen { -> std::optional; [[nodiscard]] auto newBlock() -> mlir::Block*; + + [[nodiscard]] auto newUniqueSymbolName(std::string_view prefix) + -> std::string; + void branch(mlir::Location loc, mlir::Block* block, mlir::ValueRange operands = {}); @@ -315,6 +319,8 @@ class Codegen { std::unordered_map classNames_; std::unordered_map locals_; std::unordered_map funcOps_; + std::unordered_map uniqueSymbolNames_; + std::unordered_map stringLiterals_; Loop loop_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 6f782b8d..aaa8b57f 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -254,8 +254,31 @@ auto Codegen::ExpressionVisitor::operator()(NullptrLiteralExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(StringLiteralExpressionAST* ast) -> ExpressionResult { + auto loc = gen.getLocation(ast->literalLoc); + auto type = gen.convertType(ast->type); + auto resultType = mlir::cxx::PointerType::get(type.getContext(), type); + + auto it = gen.stringLiterals_.find(ast->literal); + if (it == gen.stringLiterals_.end()) { + // todo: clean up + std::string str(ast->literal->stringValue()); + str.push_back('\0'); + + auto initializer = gen.builder_.getStringAttr(str); + + // todo: generate unique name for the global + auto name = gen.builder_.getStringAttr(gen.newUniqueSymbolName(".str")); + + auto x = mlir::OpBuilder(gen.module_->getContext()); + x.setInsertionPointToEnd(gen.module_.getBody()); + x.create(loc, type, true, name, initializer); + + it = gen.stringLiterals_.insert_or_assign(ast->literal, name).first; + } + auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); + gen.builder_.create(loc, resultType, it->second); + return {op}; } @@ -1024,6 +1047,11 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast) return {op}; } + case ImplicitCastKind::kQualificationConversion: { + auto expressionResult = gen.expression(ast->expression); + return expressionResult; + } + default: break; diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 0af89bce..61512b80 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -100,6 +100,25 @@ class FuncOpLowering : public OpConversionPattern { } }; +class GlobalOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::GlobalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + + auto elementType = getTypeConverter()->convertType(op.getGlobalType()); + + rewriter.replaceOpWithNewOp( + op, elementType, op.getConstant(), LLVM::linkage::Linkage::Private, + op.getSymName(), adaptor.getValue().value()); + + return success(); + } +}; + class ReturnOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -145,6 +164,28 @@ class CallOpLowering : public OpConversionPattern { } }; +class AddressOfOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::AddressOfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert address of type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSymName()); + + return success(); + } +}; + class AllocaOpLowering : public OpConversionPattern { public: AllocaOpLowering(const TypeConverter &typeConverter, @@ -461,7 +502,6 @@ class ArrayToPointerOpLowering SmallVector indices; - indices.push_back(0); indices.push_back(0); auto resultType = LLVM::LLVMPointerType::get(context); @@ -1254,8 +1294,8 @@ void CxxToLLVMLoweringPass::runOnOperation() { RewritePatternSet patterns(context); // function operations - patterns.insert( - typeConverter, context); + patterns.insert(typeConverter, context); // memory operations DataLayout dataLayout{module};