From d1538c48f772dd19182d64ddfec969e432e811c6 Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Mon, 4 Aug 2025 21:29:58 +0200 Subject: [PATCH] Add MLIR ops and conversions for floating point operations --- src/mlir/cxx/mlir/CxxOps.td | 82 ++++++++ src/mlir/cxx/mlir/codegen_expressions.cc | 148 +++++++++++++- src/mlir/cxx/mlir/cxx_dialect_conversions.cc | 196 +++++++++++++++++++ src/parser/cxx/type_checker.cc | 3 + 4 files changed, 424 insertions(+), 5 deletions(-) diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 9a6dd7a3..4eeecd5e 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -280,6 +280,88 @@ def Cxx_GreaterEqualOp : Cxx_Op<"ge"> { let results = (outs Cxx_BoolType:$result); } +// +// float operations +// + +def Cxx_AddFOp : Cxx_Op<"addf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_SubFOp : Cxx_Op<"subf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_MulFOp : Cxx_Op<"mulf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_DivFOp : Cxx_Op<"divf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_FloatingPointCastOp : Cxx_Op<"floating_point_cast"> { + let arguments = (ins Cxx_FloatType:$value); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_FloatToIntOp : Cxx_Op<"float_to_int"> { + let arguments = (ins Cxx_FloatType:$value); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_IntToFloatOp : Cxx_Op<"int_to_float"> { + let arguments = (ins Cxx_IntegerType:$value); + + let results = (outs Cxx_FloatType:$result); +} + +def Cxx_LessThanFOp : Cxx_Op<"ltf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_LessEqualFOp : Cxx_Op<"lef"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_GreaterThanFOp : Cxx_Op<"gtf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_GreaterEqualFOp : Cxx_Op<"gef"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_EqualFOp : Cxx_Op<"eqf"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + +def Cxx_NotEqualFOp : Cxx_Op<"nef"> { + let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs); + + let results = (outs Cxx_BoolType:$result); +} + // // control flow ops // diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 854d69dc..96275d92 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -743,12 +744,46 @@ auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast) auto resultType = gen.convertType(ast->type); auto loc = gen.getLocation(ast->opLoc); - auto zero = - gen.builder_.create(loc, resultType, 0); - auto op = gen.builder_.create(loc, resultType, zero, - expressionResult.value); - return {op}; + if (control()->is_integral_or_unscoped_enum(ast->type)) { + auto zero = + gen.builder_.create(loc, resultType, 0); + auto op = gen.builder_.create( + loc, resultType, zero, expressionResult.value); + + return {op}; + } + + if (control()->is_floating_point(ast->type)) { + resultType.dump(); + + mlir::FloatAttr value; + switch (ast->type->kind()) { + case TypeKind::kFloat: + value = gen.builder_.getF32FloatAttr(0); + break; + case TypeKind::kDouble: + value = gen.builder_.getF64FloatAttr(0); + break; + case TypeKind::kLongDouble: + value = gen.builder_.getF64FloatAttr(0); + break; + default: + // Handle other float types if necessary + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), + "unsupported float type"); + return {op}; + } + + auto zero = gen.builder_.create( + loc, resultType, value); + auto op = gen.builder_.create( + loc, resultType, zero, expressionResult.value); + + return {op}; + } + + break; } case TokenKind::T_TILDE: { @@ -943,6 +978,39 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast) return {op}; } + case ImplicitCastKind::kFloatingPointPromotion: + case ImplicitCastKind::kFloatingPointConversion: { + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + // generate a floating point cast + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + + return {op}; + } + + case ImplicitCastKind::kFloatingIntegralConversion: { + auto expressionResult = gen.expression(ast->expression); + auto resultType = gen.convertType(ast->type); + + if (control()->is_floating_point(ast->type)) { + // If the result type is a floating point, we can use a specialized cast + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + return {op}; + } + + if (control()->is_integral(ast->type)) { + // If the expression type is an integral, we can use a specialized cast + auto op = gen.builder_.create( + loc, resultType, expressionResult.value); + return {op}; + } + + break; + } + default: break; @@ -986,6 +1054,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -997,6 +1072,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1008,6 +1090,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1019,6 +1108,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1063,6 +1159,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1074,6 +1177,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1085,6 +1195,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1096,6 +1213,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1107,6 +1231,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } @@ -1118,6 +1249,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) return {op}; } + if (control()->is_floating_point(ast->type)) { + auto op = gen.builder_.create( + loc, resultType, leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + break; } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index df9581e8..8763e67c 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -783,6 +783,193 @@ class GreaterEqualOpLowering : public OpConversionPattern { } }; +// +// floating point operations +// + +class AddFOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert addf operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class SubFOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::SubFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert subf operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class MulFOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::MulFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert mulf operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class DivFOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert divf operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class FloatingPointCastOpLowering + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::FloatingPointCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert floating point cast type"); + } + + const auto sourceType = dyn_cast(op.getValue().getType()); + const auto targetType = dyn_cast(op.getType()); + + if (sourceType.getWidth() == targetType.getWidth()) { + // no conversion needed, just replace the op with the value + rewriter.replaceOp(op, adaptor.getValue()); + return success(); + } + + if (sourceType.getWidth() < targetType.getWidth()) { + // extension + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + return success(); + } + + // truncation + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + +class IntToFloatOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::IntToFloatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert int to float type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + +class FloatToIntOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::FloatToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert float to int type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + +// +// control flow operations +// + class CondBranchOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -979,6 +1166,15 @@ void CxxToLLVMLoweringPass::runOnOperation() { LessEqualOpLowering, GreaterThanOpLowering, GreaterEqualOpLowering>(typeConverter, context); + // floating point operations + patterns + .insert( + typeConverter, context); + + // floating point cast operations + patterns.insert(typeConverter, context); + // control flow operations patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); diff --git a/src/parser/cxx/type_checker.cc b/src/parser/cxx/type_checker.cc index 19661448..4423b12c 100644 --- a/src/parser/cxx/type_checker.cc +++ b/src/parser/cxx/type_checker.cc @@ -1564,6 +1564,7 @@ auto TypeChecker::Visitor::floating_point_conversion( ExpressionAST*& expr, const Type* destinationType) -> bool { if (!is_prvalue(expr)) return false; + if (control()->is_same(expr->type, destinationType)) return true; if (!control()->is_floating_point(expr->type)) return false; if (!control()->is_floating_point(destinationType)) return false; @@ -1986,12 +1987,14 @@ auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, if (expr->type->kind() == TypeKind::kLongDouble || other->type->kind() == TypeKind::kLongDouble) { (void)floating_point_conversion(expr, control()->getLongDoubleType()); + (void)floating_point_conversion(other, control()->getLongDoubleType()); return control()->getLongDoubleType(); } if (expr->type->kind() == TypeKind::kDouble || other->type->kind() == TypeKind::kDouble) { (void)floating_point_conversion(expr, control()->getDoubleType()); + (void)floating_point_conversion(other, control()->getDoubleType()); return control()->getDoubleType(); }