From 7161945b0e99e3295ccb902afcc8572422df8e27 Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Sun, 3 Aug 2025 10:03:17 +0200 Subject: [PATCH] Add AddI, SubI, MulI and bool conversions with corresponding lowering patterns Signed-off-by: Roberto Raggi --- src/mlir/cxx/mlir/CxxOps.td | 40 ++++- src/mlir/cxx/mlir/codegen_expressions.cc | 93 +++++++++-- src/mlir/cxx/mlir/cxx_dialect_conversions.cc | 156 ++++++++++++++++++- 3 files changed, 268 insertions(+), 21 deletions(-) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index d0e03cbe..64f181d2 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -170,10 +170,46 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [ let results = (outs Cxx_FloatType:$result); } -def Cxx_IntegralCastOp : Cxx_Op<"cast.integral"> { +def Cxx_IntToBoolOp : Cxx_Op<"int_to_bool"> { let arguments = (ins AnyType:$value); - let results = (outs AnyType:$result); + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_BoolToIntOp : Cxx_Op<"bool_to_int"> { + let arguments = (ins Cxx_BoolType:$value); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_IntegralCastOp : Cxx_Op<"integral_cast"> { + let arguments = (ins Cxx_IntegerType:$value); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_NotOp : Cxx_Op<"not"> { + let arguments = (ins Cxx_BoolType:$value); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_AddIOp : Cxx_Op<"addi"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_SubIOp : Cxx_Op<"subi"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_MulIOp : Cxx_Op<"muli"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); } // diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 75aaf1ea..5834bd86 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -23,6 +23,7 @@ // cxx #include #include +#include #include #include #include @@ -34,6 +35,12 @@ struct Codegen::ExpressionVisitor { Codegen& gen; ExpressionFormat format = ExpressionFormat::kValue; + [[nodiscard]] auto control() const -> Control* { return gen.control(); } + + [[nodiscard]] auto is_bool(const Type* type) const -> bool { + return type_cast(control()->remove_cv(type)); + } + auto operator()(GeneratedLiteralExpressionAST* ast) -> ExpressionResult; auto operator()(CharLiteralExpressionAST* ast) -> ExpressionResult; auto operator()(BoolLiteralExpressionAST* ast) -> ExpressionResult; @@ -633,13 +640,26 @@ auto Codegen::ExpressionVisitor::operator()(LabelAddressExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast) -> ExpressionResult { + switch (ast->op) { + case cxx::TokenKind::T_EXCLAIM: { + if (type_cast(control()->remove_cv(ast->type))) { + auto loc = gen.getLocation(ast->opLoc); + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + auto op = gen.builder_.create(loc, resultType, + expressionResult.value); + return {op}; + } + break; + } + + default: + break; + } // switch + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); -#if false - auto expressionResult = gen.expression(ast->expression); -#endif - return {op}; } @@ -779,10 +799,24 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast) case ImplicitCastKind::kIntegralConversion: case ImplicitCastKind::kIntegralPromotion: { - // generate a cast auto expressionResult = gen.expression(ast->expression); auto resultType = gen.convertType(ast->type); + if (is_bool(ast->type)) { + // If the result type is a boolean, we can use a specialized cast + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + return {op}; + } + + if (is_bool(ast->expression->type)) { + // If the expression type is a boolean, we can use a specialized cast + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + return {op}; + } + + // generate an integral cast auto op = gen.builder_.create( loc, resultType, expressionResult.value); @@ -818,20 +852,51 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return gen.expression(ast->rightExpression, format); } - auto op = - gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); - -#if false + auto loc = gen.getLocation(ast->opLoc); auto leftExpressionResult = gen.expression(ast->leftExpression); auto rightExpressionResult = gen.expression(ast->rightExpression); + auto resultType = gen.convertType(ast->type); - auto loc = gen.getLocation(ast->opLoc); + switch (ast->op) { + case TokenKind::T_PLUS: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } - auto operation = Token::spell(ast->op); + break; + } - auto op = gen.builder_.create( - loc, operation, leftExpressionResult.value, rightExpressionResult.value); -#endif + case TokenKind::T_MINUS: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + case TokenKind::T_STAR: { + if (control()->is_integral(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + break; + } + + default: + break; + } // switch + + auto op = + gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); return {op}; } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 292b035b..39fa7fba 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -297,6 +297,149 @@ class IntegralCastOpLowering : public OpConversionPattern { } }; +class IntToBoolOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::IntToBoolOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert int to bool type"); + } + + auto c0 = rewriter.create( + op.getLoc(), adaptor.getValue().getType(), 0); + + rewriter.replaceOpWithNewOp( + op, resultType, LLVM::ICmpPredicate::ne, adaptor.getValue(), c0); + + return success(); + } +}; + +class BoolToIntOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::BoolToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert bool to int type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + +class NotOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::NotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, "failed to convert not operation"); + } + + auto c1 = + rewriter.create(op.getLoc(), rewriter.getI1Type(), 1); + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getValue(), + c1); + + return success(); + } +}; + +class AddIOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert addi operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class SubIOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert subi operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class MulIOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert muli operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + class CxxToLLVMLoweringPass : public PassWrapper> { public: @@ -325,7 +468,7 @@ void CxxToLLVMLoweringPass::runOnOperation() { typeConverter.addConversion([](cxx::BoolType type) { // todo: i8/i32 for data and i1 for control flow - return IntegerType::get(type.getContext(), 8); + return IntegerType::get(type.getContext(), 1); }); typeConverter.addConversion([](cxx::IntegerType type) { @@ -383,10 +526,13 @@ void CxxToLLVMLoweringPass::runOnOperation() { target.addIllegalDialect(); RewritePatternSet patterns(context); - patterns.insert(typeConverter, context); + patterns + .insert( + typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);