diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 4eeecd5e..27a311ce 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -154,6 +154,12 @@ def Cxx_StoreOp : Cxx_Op<"store"> { let hasVerifier = 1; } +def Cxx_SubscriptOp : Cxx_Op<"subscript"> { + let arguments = (ins Cxx_PointerType:$base, AnyType:$index); + + let results = (outs Cxx_PointerType:$result); +} + def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [ Pure ]> { diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 739a00ea..04b90d0f 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -40,6 +40,8 @@ namespace cxx { struct Codegen::DeclarationVisitor { Codegen& gen; + void allocateLocals(ScopedSymbol* block); + auto operator()(SimpleDeclarationAST* ast) -> DeclarationResult; auto operator()(AsmDeclarationAST* ast) -> DeclarationResult; auto operator()(NamespaceAliasDefinitionAST* ast) -> DeclarationResult; @@ -140,6 +142,26 @@ auto Codegen::lambdaSpecifier(LambdaSpecifierAST* ast) return {}; } +void Codegen::DeclarationVisitor::allocateLocals(ScopedSymbol* block) { + for (auto symbol : block->scope()->symbols()) { + if (auto nestedBlock = symbol_cast(symbol)) { + allocateLocals(nestedBlock); + continue; + } + + if (auto var = symbol_cast(symbol)) { + if (var->isStatic()) continue; + + auto local = gen.findOrCreateLocal(var); + if (!local.has_value()) { + gen.unit_->error(var->location(), + std::format("cannot allocate local variable '{}'", + to_string(var->name()))); + } + } + } +} + auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast) -> DeclarationResult { if (!gen.function_) { @@ -377,6 +399,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) } } + allocateLocals(functionSymbol); + // generate code for the function body auto functionBodyResult = gen.functionBody(ast->functionBody); diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 339d374b..9e0039ec 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -453,13 +453,15 @@ auto Codegen::ExpressionVisitor::operator()(VaArgExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast) -> ExpressionResult { - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false auto baseExpressionResult = gen.expression(ast->baseExpression); auto indexExpressionResult = gen.expression(ast->indexExpression); -#endif + + auto loc = gen.getLocation(ast->firstSourceLocation()); + + auto resultType = gen.convertType(control()->add_pointer(ast->type)); + + auto op = gen.builder_.create( + loc, resultType, baseExpressionResult.value, indexExpressionResult.value); return {op}; } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 8763e67c..3e9fa0b8 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -236,6 +236,43 @@ class BoolConstantOpLowering : public OpConversionPattern { } }; +class SubscriptOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::SubscriptOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto ptrType = dyn_cast_or_null(op.getBase().getType()); + + if (!ptrType) { + return rewriter.notifyMatchFailure( + op, "failed to convert subscript operation type"); + } + + auto arrayType = dyn_cast_or_null(ptrType.getElementType()); + if (!arrayType) { + return rewriter.notifyMatchFailure( + op, "expected base type of subscript to be an array type"); + } + + SmallVector indices; + + indices.push_back(adaptor.getIndex()); + + auto resultType = LLVM::LLVMPointerType::get(context); + auto elementType = typeConverter->convertType(ptrType.getElementType()); + + rewriter.replaceOpWithNewOp(op, resultType, elementType, + adaptor.getBase(), indices); + + return success(); + } +}; + class IntConstantOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1093,6 +1130,13 @@ void CxxToLLVMLoweringPass::runOnOperation() { return LLVM::LLVMPointerType::get(type.getContext()); }); + typeConverter.addConversion([&](cxx::ArrayType type) -> Type { + auto elementType = typeConverter.convertType(type.getElementType()); + auto size = type.getSize(); + + return LLVM::LLVMArrayType::get(elementType, size); + }); + DenseMap convertedClassTypes; typeConverter.addConversion([&](cxx::ClassType type) -> Type { if (auto it = convertedClassTypes.find(type); @@ -1140,8 +1184,8 @@ void CxxToLLVMLoweringPass::runOnOperation() { typeConverter, context); // memory operations - patterns.insert( - typeConverter, context); + patterns.insert(typeConverter, context); // cast operations patterns @@ -1207,6 +1251,7 @@ auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult { pm.addPass(cxx::createLowerToLLVMPass()); pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); if (failed(pm.run(module))) { module.print(llvm::errs());