Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,46 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
let results = (outs Cxx_FloatType:$result);
}

def Cxx_IntegralCastOp : Cxx_Op<"cast.integral"> {
def Cxx_IntToBoolOp : Cxx_Op<"int_to_bool"> {
let arguments = (ins AnyType:$value);

let results = (outs AnyType:$result);
let results = (outs Cxx_BoolType:$result);
}

def Cxx_BoolToIntOp : Cxx_Op<"bool_to_int"> {
let arguments = (ins Cxx_BoolType:$value);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_IntegralCastOp : Cxx_Op<"integral_cast"> {
let arguments = (ins Cxx_IntegerType:$value);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_NotOp : Cxx_Op<"not"> {
let arguments = (ins Cxx_BoolType:$value);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_AddIOp : Cxx_Op<"addi"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_SubIOp : Cxx_Op<"subi"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_MulIOp : Cxx_Op<"muli"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

//
Expand Down
93 changes: 79 additions & 14 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
// cxx
#include <cxx/ast.h>
#include <cxx/ast_interpreter.h>
#include <cxx/control.h>
#include <cxx/literals.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
Expand All @@ -34,6 +35,12 @@ struct Codegen::ExpressionVisitor {
Codegen& gen;
ExpressionFormat format = ExpressionFormat::kValue;

[[nodiscard]] auto control() const -> Control* { return gen.control(); }

[[nodiscard]] auto is_bool(const Type* type) const -> bool {
return type_cast<BoolType>(control()->remove_cv(type));
}

auto operator()(GeneratedLiteralExpressionAST* ast) -> ExpressionResult;
auto operator()(CharLiteralExpressionAST* ast) -> ExpressionResult;
auto operator()(BoolLiteralExpressionAST* ast) -> ExpressionResult;
Expand Down Expand Up @@ -633,13 +640,26 @@ auto Codegen::ExpressionVisitor::operator()(LabelAddressExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast)
-> ExpressionResult {
switch (ast->op) {
case cxx::TokenKind::T_EXCLAIM: {
if (type_cast<BoolType>(control()->remove_cv(ast->type))) {
auto loc = gen.getLocation(ast->opLoc);
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);
auto op = gen.builder_.create<mlir::cxx::NotOp>(loc, resultType,
expressionResult.value);
return {op};
}
break;
}

default:
break;
} // switch

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto expressionResult = gen.expression(ast->expression);
#endif

return {op};
}

Expand Down Expand Up @@ -779,10 +799,24 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast)

case ImplicitCastKind::kIntegralConversion:
case ImplicitCastKind::kIntegralPromotion: {
// generate a cast
auto expressionResult = gen.expression(ast->expression);
auto resultType = gen.convertType(ast->type);

if (is_bool(ast->type)) {
// If the result type is a boolean, we can use a specialized cast
auto op = gen.builder_.create<mlir::cxx::IntToBoolOp>(
loc, resultType, expressionResult.value);
return {op};
}

if (is_bool(ast->expression->type)) {
// If the expression type is a boolean, we can use a specialized cast
auto op = gen.builder_.create<mlir::cxx::BoolToIntOp>(
loc, resultType, expressionResult.value);
return {op};
}

// generate an integral cast
auto op = gen.builder_.create<mlir::cxx::IntegralCastOp>(
loc, resultType, expressionResult.value);

Expand Down Expand Up @@ -818,20 +852,51 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
return gen.expression(ast->rightExpression, format);
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto loc = gen.getLocation(ast->opLoc);
auto leftExpressionResult = gen.expression(ast->leftExpression);
auto rightExpressionResult = gen.expression(ast->rightExpression);
auto resultType = gen.convertType(ast->type);

auto loc = gen.getLocation(ast->opLoc);
switch (ast->op) {
case TokenKind::T_PLUS: {
if (control()->is_integral(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::AddIOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

auto operation = Token::spell(ast->op);
break;
}

auto op = gen.builder_.create<mlir::cxx::BinOp>(
loc, operation, leftExpressionResult.value, rightExpressionResult.value);
#endif
case TokenKind::T_MINUS: {
if (control()->is_integral(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::SubIOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

case TokenKind::T_STAR: {
if (control()->is_integral(ast->type)) {
auto op = gen.builder_.create<mlir::cxx::MulIOp>(
loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

break;
}

default:
break;
} // switch

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

return {op};
}
Expand Down
156 changes: 151 additions & 5 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,149 @@ class IntegralCastOpLowering : public OpConversionPattern<cxx::IntegralCastOp> {
}
};

class IntToBoolOpLowering : public OpConversionPattern<cxx::IntToBoolOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::IntToBoolOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op,
"failed to convert int to bool type");
}

auto c0 = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), adaptor.getValue().getType(), 0);

rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, resultType, LLVM::ICmpPredicate::ne, adaptor.getValue(), c0);

return success();
}
};

class BoolToIntOpLowering : public OpConversionPattern<cxx::BoolToIntOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::BoolToIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op,
"failed to convert bool to int type");
}

rewriter.replaceOpWithNewOp<LLVM::ZExtOp>(op, resultType,
adaptor.getValue());

return success();
}
};

class NotOpLowering : public OpConversionPattern<cxx::NotOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::NotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op, "failed to convert not operation");
}

auto c1 =
rewriter.create<LLVM::ConstantOp>(op.getLoc(), rewriter.getI1Type(), 1);

rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, resultType, adaptor.getValue(),
c1);

return success();
}
};

class AddIOpLowering : public OpConversionPattern<cxx::AddIOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::AddIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert addi operation type");
}

rewriter.replaceOpWithNewOp<LLVM::AddOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

class SubIOpLowering : public OpConversionPattern<cxx::SubIOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::SubIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert subi operation type");
}

rewriter.replaceOpWithNewOp<LLVM::SubOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

class MulIOpLowering : public OpConversionPattern<cxx::MulIOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::MulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert muli operation type");
}

rewriter.replaceOpWithNewOp<LLVM::MulOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

class CxxToLLVMLoweringPass
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
public:
Expand Down Expand Up @@ -325,7 +468,7 @@ void CxxToLLVMLoweringPass::runOnOperation() {

typeConverter.addConversion([](cxx::BoolType type) {
// todo: i8/i32 for data and i1 for control flow
return IntegerType::get(type.getContext(), 8);
return IntegerType::get(type.getContext(), 1);
});

typeConverter.addConversion([](cxx::IntegerType type) {
Expand Down Expand Up @@ -383,10 +526,13 @@ void CxxToLLVMLoweringPass::runOnOperation() {
target.addIllegalDialect<cxx::CxxDialect>();

RewritePatternSet patterns(context);
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering,
IntegralCastOpLowering>(typeConverter, context);
patterns
.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering,
IntToBoolOpLowering, BoolToIntOpLowering, IntegralCastOpLowering,
NotOpLowering, AddIOpLowering, SubIOpLowering, MulIOpLowering>(
typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
typeConverter);
Expand Down
Loading