Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,88 @@ def Cxx_GreaterEqualOp : Cxx_Op<"ge"> {
let results = (outs Cxx_BoolType:$result);
}

//
// float operations
//

def Cxx_AddFOp : Cxx_Op<"addf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_SubFOp : Cxx_Op<"subf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_MulFOp : Cxx_Op<"mulf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_DivFOp : Cxx_Op<"divf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_FloatingPointCastOp : Cxx_Op<"floating_point_cast"> {
let arguments = (ins Cxx_FloatType:$value);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_FloatToIntOp : Cxx_Op<"float_to_int"> {
let arguments = (ins Cxx_FloatType:$value);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_IntToFloatOp : Cxx_Op<"int_to_float"> {
let arguments = (ins Cxx_IntegerType:$value);

let results = (outs Cxx_FloatType:$result);
}

def Cxx_LessThanFOp : Cxx_Op<"ltf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_LessEqualFOp : Cxx_Op<"lef"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_GreaterThanFOp : Cxx_Op<"gtf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_GreaterEqualFOp : Cxx_Op<"gef"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_EqualFOp : Cxx_Op<"eqf"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_NotEqualFOp : Cxx_Op<"nef"> {
let arguments = (ins Cxx_FloatType:$lhs, Cxx_FloatType:$rhs);

let results = (outs Cxx_BoolType:$result);
}

//
// control flow ops
//
Expand Down
148 changes: 143 additions & 5 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cxx/ast_interpreter.h>
#include <cxx/control.h>
#include <cxx/literals.h>
#include <cxx/memory_layout.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
Expand Down Expand Up @@ -743,12 +744,46 @@ auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast)
auto resultType = gen.convertType(ast->type);

auto loc = gen.getLocation(ast->opLoc);
auto zero =
gen.builder_.create<mlir::cxx::IntConstantOp>(loc, resultType, 0);
auto op = gen.builder_.create<mlir::cxx::SubIOp>(loc, resultType, zero,
expressionResult.value);

return {op};
if (control()->is_integral_or_unscoped_enum(ast->type)) {
auto zero =
gen.builder_.create<mlir::cxx::IntConstantOp>(loc, resultType, 0);
auto op = gen.builder_.create<mlir::cxx::SubIOp>(
loc, resultType, zero, expressionResult.value);

return {op};
}

if (control()->is_floating_point(ast->type)) {
resultType.dump();

mlir::FloatAttr value;
switch (ast->type->kind()) {
case TypeKind::kFloat:
value = gen.builder_.getF32FloatAttr(0);
break;
case TypeKind::kDouble:
value = gen.builder_.getF64FloatAttr(0);
break;
case TypeKind::kLongDouble:
value = gen.builder_.getF64FloatAttr(0);
break;
default:
// Handle other float types if necessary
auto op = gen.emitTodoExpr(ast->firstSourceLocation(),
"unsupported float type");
return {op};
}

auto zero = gen.builder_.create<mlir::cxx::FloatConstantOp>(
loc, resultType, value);
auto op = gen.builder_.create<mlir::cxx::SubFOp>(
loc, resultType, zero, expressionResult.value);

return {op};
}

break;
}

case TokenKind::T_TILDE: {
Expand Down Expand Up @@ -943,6 +978,39 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast)
return {op};
}

case ImplicitCastKind::kFloatingPointPromotion:
case ImplicitCastKind::kFloatingPointConversion: {
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);

// generate a floating point cast
auto op = gen.builder_.create<mlir::cxx::FloatingPointCastOp>(
loc, resultType, expressionResult.value);

return {op};
}

case ImplicitCastKind::kFloatingIntegralConversion: {
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);

if (control()->is_floating_point(ast->type)) {
// If the result type is a floating point, we can use a specialized cast
auto op = gen.builder_.create<mlir::cxx::IntToFloatOp>(
loc, resultType, expressionResult.value);
return {op};
}

if (control()->is_integral(ast->type)) {
// If the expression type is an integral, we can use a specialized cast
auto op = gen.builder_.create<mlir::cxx::FloatToIntOp>(
loc, resultType, expressionResult.value);
return {op};
}

break;
}

default:
break;

Expand Down Expand Up @@ -986,6 +1054,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::AddFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -997,6 +1072,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::SubFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1008,6 +1090,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::MulFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1019,6 +1108,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::DivFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand Down Expand Up @@ -1063,6 +1159,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::EqualFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1074,6 +1177,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::NotEqualFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1085,6 +1195,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::LessThanFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1096,6 +1213,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::LessEqualFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1107,6 +1231,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::GreaterThanFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand All @@ -1118,6 +1249,13 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return {op};
}

if (control()->is_floating_point(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::GreaterEqualFOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

Expand Down
Loading
Loading