diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 290384c0..0aa6ca87 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -189,9 +189,9 @@ def Cxx_IntegralCastOp : Cxx_Op<"integral_cast"> { } def Cxx_NotOp : Cxx_Op<"not"> { - let arguments = (ins Cxx_BoolType:$value); + let arguments = (ins AnyType:$value); - let results = (outs Cxx_BoolType:$result); + let results = (outs AnyType:$result); } def Cxx_AddIOp : Cxx_Op<"addi"> { @@ -212,6 +212,70 @@ def Cxx_MulIOp : Cxx_Op<"muli"> { let results = (outs Cxx_IntegerType:$result); } +def Cxx_DivIOp : Cxx_Op<"divi"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_ModIOp : Cxx_Op<"mod"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_ShiftLeftOp : Cxx_Op<"shl"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_ShiftRightOp : Cxx_Op<"shr"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_EqualOp : Cxx_Op<"eq"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_NotEqualOp : Cxx_Op<"ne"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_LessThanOp : Cxx_Op<"lt"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_LessEqualOp : Cxx_Op<"le"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_GreaterThanOp : Cxx_Op<"gt"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_GreaterEqualOp : Cxx_Op<"ge"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +// +// control flow ops +// + def Cxx_LabelOp : Cxx_Op<"label"> { let arguments = (ins StringProp:$name); } diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index 7b910861..ddc18434 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -61,13 +61,13 @@ void Codegen::branch(mlir::Location loc, mlir::Block* block, } auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { - auto var = symbol_cast(symbol); - if (!var) return std::nullopt; - - if (auto local = locals_.find(var); local != locals_.end()) { + if (auto local = locals_.find(symbol); local != locals_.end()) { return local->second; } + auto var = symbol_cast(symbol); + if (!var) return std::nullopt; + auto type = convertType(var->type()); auto ptrType = builder_.getType(type); diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 8e30fb8f..ad9748af 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -375,6 +376,31 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) std::swap(gen.exitValue_, exitValue); std::swap(gen.locals_, locals); + FunctionParametersSymbol* params = nullptr; + for (auto member : ast->symbol->scope()->symbols()) { + params = symbol_cast(member); + if (!params) continue; + + int argc = 0; + auto args = entryBlock->getArguments(); + for (auto param : params->scope()->symbols()) { + auto arg = symbol_cast(param); + if (!arg) continue; + + auto type = gen.convertType(arg->type()); + auto ptrType = gen.builder_.getType(type); + + auto loc = gen.getLocation(arg->location()); + auto allocaOp = gen.builder_.create(loc, ptrType); + + auto value = args[argc]; + ++argc; + gen.builder_.create(loc, value, allocaOp); + + gen.locals_.emplace(arg, allocaOp); + } + } + // generate code for the function body auto functionBodyResult = gen.functionBody(ast->functionBody); diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index ca92b005..49dc7c41 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -678,7 +678,7 @@ auto Codegen::ExpressionVisitor::operator()(LabelAddressExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast) -> ExpressionResult { switch (ast->op) { - case cxx::TokenKind::T_EXCLAIM: { + case TokenKind::T_EXCLAIM: { if (type_cast(control()->remove_cv(ast->type))) { auto loc = gen.getLocation(ast->opLoc); auto expressionResult = gen.expression(ast->expression); @@ -690,6 +690,38 @@ auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast) break; } + case TokenKind::T_PLUS: { + // unary plus, no-op + auto expressionResult = gen.expression(ast->expression); + return expressionResult; + } + + case TokenKind::T_MINUS: { + // unary minus + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + auto loc = gen.getLocation(ast->opLoc); + auto zero = + gen.builder_.create(loc, resultType, 0); + auto op = gen.builder_.create(loc, resultType, zero, + expressionResult.value); + + return {op}; + } + + case TokenKind::T_TILDE: { + // unary bitwise not + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + auto loc = gen.getLocation(ast->opLoc); + auto op = gen.builder_.create(loc, resultType, + expressionResult.value); + + return {op}; + } + default: break; } // switch @@ -928,6 +960,116 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) break; } + case TokenKind::T_SLASH: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_PERCENT: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_LESS_LESS: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_GREATER_GREATER: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_EQUAL_EQUAL: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_EXCLAIM_EQUAL: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_LESS: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_LESS_EQUAL: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_GREATER: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_GREATER_EQUAL: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + default: break; } // switch diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 5d3d8e86..0637a5c4 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -31,6 +31,7 @@ #include #include #include +#include namespace mlir { @@ -361,8 +362,8 @@ class NotOpLowering : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "failed to convert not operation"); } - auto c1 = - rewriter.create(op.getLoc(), rewriter.getI1Type(), 1); + auto c1 = rewriter.create( + op.getLoc(), adaptor.getValue().getType(), -1); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getValue(), c1); @@ -440,6 +441,311 @@ class MulIOpLowering : public OpConversionPattern { } }; +class DivIOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::DivIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert divi operation type"); + } + + bool isSigned = true; + + if (auto intType = dyn_cast(op.getType())) { + isSigned = intType.getIsSigned(); + } + + if (isSigned) { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } else { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } + + return success(); + } +}; + +class ModIOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::ModIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert modi operation type"); + } + + bool isSigned = true; + + if (auto intType = dyn_cast(op.getType())) { + isSigned = intType.getIsSigned(); + } + + if (isSigned) { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } else { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } + + return success(); + } +}; + +class ShiftLeftOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::ShiftLeftOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert shift left operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class ShiftRightOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::ShiftRightOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert shift right operation type"); + } + + bool isSigned = true; + + if (auto intType = dyn_cast(op.getType())) { + isSigned = intType.getIsSigned(); + } + + if (isSigned) { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } else { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getLhs(), adaptor.getRhs()); + } + + return success(); + } +}; + +class EqualOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::EqualOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert equal operation type"); + } + + rewriter.replaceOpWithNewOp( + op, resultType, LLVM::ICmpPredicate::eq, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class NotEquaOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::NotEqualOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert not equal operation type"); + } + + rewriter.replaceOpWithNewOp( + op, resultType, LLVM::ICmpPredicate::ne, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class LessThanOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::LessThanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert less than operation type"); + } + + auto predicate = LLVM::ICmpPredicate::slt; + + if (auto intType = dyn_cast(op.getLhs().getType())) { + if (intType.getIsSigned()) { + predicate = LLVM::ICmpPredicate::slt; + } else { + predicate = LLVM::ICmpPredicate::ult; + } + } + + rewriter.replaceOpWithNewOp( + op, resultType, predicate, adaptor.getLhs(), adaptor.getRhs()); + + return success(); + } +}; + +class LessEqualOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::LessEqualOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert less equal operation type"); + } + + auto predicate = LLVM::ICmpPredicate::sle; + + if (auto intType = dyn_cast(op.getLhs().getType())) { + if (intType.getIsSigned()) { + predicate = LLVM::ICmpPredicate::sle; + } else { + predicate = LLVM::ICmpPredicate::ule; + } + } + + rewriter.replaceOpWithNewOp( + op, resultType, predicate, adaptor.getLhs(), adaptor.getRhs()); + + return success(); + } +}; + +class GreaterThanOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::GreaterThanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert greater than operation type"); + } + + auto predicate = LLVM::ICmpPredicate::sgt; + + if (auto intType = dyn_cast(op.getLhs().getType())) { + if (intType.getIsSigned()) { + predicate = LLVM::ICmpPredicate::sgt; + } else { + predicate = LLVM::ICmpPredicate::ugt; + } + } + + rewriter.replaceOpWithNewOp( + op, resultType, predicate, adaptor.getLhs(), adaptor.getRhs()); + + return success(); + } +}; + +class GreaterEqualOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::GreaterEqualOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert greater equal operation type"); + } + + auto predicate = LLVM::ICmpPredicate::sge; + + if (auto intType = dyn_cast(op.getLhs().getType())) { + if (intType.getIsSigned()) { + predicate = LLVM::ICmpPredicate::sge; + } else { + predicate = LLVM::ICmpPredicate::uge; + } + } + + rewriter.replaceOpWithNewOp( + op, resultType, predicate, adaptor.getLhs(), adaptor.getRhs()); + + return success(); + } +}; + class CondBranchOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -588,18 +894,6 @@ void CxxToLLVMLoweringPass::runOnOperation() { // set up the conversion patterns ConversionTarget target(*context); - target.addLegalDialect(); - target.addIllegalDialect(); - - RewritePatternSet patterns(context); - patterns.insert( - typeConverter, context); - LabelConverter labelConverter; module.walk([&](Operation *op) { @@ -608,6 +902,43 @@ void CxxToLLVMLoweringPass::runOnOperation() { } }); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(context); + + // function operations + patterns.insert(typeConverter, context); + + // memory operations + patterns.insert( + typeConverter, context); + + // cast operations + patterns + .insert( + typeConverter, context); + + // constant operations + patterns.insert(typeConverter, context); + + // unary operations + patterns.insert(typeConverter, context); + + // binary operations + patterns + .insert( + typeConverter, context); + + // comparison operations + patterns.insert(typeConverter, context); + + // control flow operations + patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); patterns.insert(typeConverter, labelConverter, context); @@ -637,6 +968,7 @@ auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult { #endif pm.addPass(cxx::createLowerToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); if (failed(pm.run(module))) { module.print(llvm::errs()); diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index ffaf6a0b..46ea9882 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -6048,6 +6048,11 @@ auto Parser::parse_parameter_declaration_clause( auto Parser::parse_parameter_declaration_list( ParameterDeclarationClauseAST* ast) -> bool { + if (lookat(TokenKind::T_VOID, TokenKind::T_RPAREN)) { + consumeToken(); + return true; + } + auto it = &ast->parameterDeclarationList; auto _ = Binder::ScopeGuard{&binder_};