Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

def Cxx_Dialect : Dialect {
let name = "cxx";
Expand Down Expand Up @@ -148,23 +149,23 @@ def Cxx_StoreOp : Cxx_Op<"store"> {
def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
Pure
]> {
let arguments = (ins BoolAttr:$value);
let arguments = (ins BoolProp:$value);

let results = (outs Cxx_BoolType:$result);
}

def Cxx_IntConstantOp : Cxx_Op<"constant.int", [
Pure
]> {
let arguments = (ins I64Attr:$value);
let arguments = (ins I64Prop:$value);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
Pure
]> {
let arguments = (ins F64Attr:$value);
let arguments = (ins TypedAttrInterface:$value);

let results = (outs Cxx_FloatType:$result);
}
Expand All @@ -174,7 +175,7 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [
//

def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
let arguments = (ins StrAttr:$message);
let arguments = (ins StringProp:$message);
let results = (outs Cxx_ExprType:$result);
let assemblyFormat = "$message attr-dict `:` type($result)";
let builders =
Expand All @@ -185,7 +186,7 @@ def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> {
}

def Cxx_TodoStmtOp : Cxx_Op<"todo.stmt"> {
let arguments = (ins StrAttr:$message);
let arguments = (ins StringProp:$message);
let results = (outs);
let assemblyFormat = "$message attr-dict";
}
30 changes: 24 additions & 6 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
// cxx
#include <cxx/ast.h>
#include <cxx/literals.h>
#include <cxx/types.h>

namespace cxx {

Expand Down Expand Up @@ -133,8 +134,7 @@ auto Codegen::ExpressionVisitor::operator()(CharLiteralExpressionAST* ast)
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getI64IntegerAttr(ast->literal->charValue());

auto value = std::int64_t(ast->literal->charValue());
auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);

return {op};
Expand All @@ -145,9 +145,9 @@ auto Codegen::ExpressionVisitor::operator()(BoolLiteralExpressionAST* ast)
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getBoolAttr(ast->isTrue);

auto op = gen.builder_.create<mlir::cxx::BoolConstantOp>(loc, type, value);
auto op =
gen.builder_.create<mlir::cxx::BoolConstantOp>(loc, type, ast->isTrue);

return {op};
}
Expand All @@ -157,7 +157,7 @@ auto Codegen::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast)
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getI64IntegerAttr(ast->literal->integerValue());
auto value = ast->literal->integerValue();

auto op = gen.builder_.create<mlir::cxx::IntConstantOp>(loc, type, value);

Expand All @@ -169,7 +169,25 @@ auto Codegen::ExpressionVisitor::operator()(FloatLiteralExpressionAST* ast)
auto loc = gen.getLocation(ast->literalLoc);

auto type = gen.convertType(ast->type);
auto value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());

mlir::TypedAttr value;

switch (ast->type->kind()) {
case TypeKind::kFloat:
value = gen.builder_.getF32FloatAttr(ast->literal->floatValue());
break;
case TypeKind::kDouble:
value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());
break;
case TypeKind::kLongDouble:
value = gen.builder_.getF64FloatAttr(ast->literal->floatValue());
break;
default:
// Handle other float types if necessary
auto op = gen.emitTodoExpr(ast->firstSourceLocation(),
"unsupported float type");
return {op};
}

auto op = gen.builder_.create<mlir::cxx::FloatConstantOp>(loc, type, value);

Expand Down
6 changes: 4 additions & 2 deletions src/mlir/cxx/mlir/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,14 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {
}

auto Codegen::ConvertType::operator()(const EnumType* type) -> mlir::Type {
return getExprType();
if (type->underlyingType()) return gen.convertType(type->underlyingType());
return gen.builder_.getType<mlir::cxx::IntegerType>(32, true);
}

auto Codegen::ConvertType::operator()(const ScopedEnumType* type)
-> mlir::Type {
return getExprType();
if (type->underlyingType()) return gen.convertType(type->underlyingType());
return gen.builder_.getType<mlir::cxx::IntegerType>(32, true);
}

auto Codegen::ConvertType::operator()(const MemberObjectPointerType* type)
Expand Down
102 changes: 98 additions & 4 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,29 @@ class StoreOpLowering : public OpConversionPattern<cxx::StoreOp> {
}
};

class BoolConstantOpLowering : public OpConversionPattern<cxx::BoolConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::BoolConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert boolean constant type");
}

rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
adaptor.getValue());

return success();
}
};

class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand All @@ -186,8 +209,32 @@ class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
op, "failed to convert integer constant type");
}

auto valueAttr = adaptor.getValueAttr();
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType, valueAttr);
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
adaptor.getValue());

return success();
}
};

class FloatConstantOpLowering
: public OpConversionPattern<cxx::FloatConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::FloatConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert float constant type");
}

rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, resultType,
adaptor.getValue());

return success();
}
Expand Down Expand Up @@ -218,14 +265,60 @@ void CxxToLLVMLoweringPass::runOnOperation() {

// set up the type converter
LLVMTypeConverter typeConverter{context};

typeConverter.addConversion([](cxx::BoolType type) {
// todo: i8/i32 for data and i1 for control flow
return IntegerType::get(type.getContext(), 8);
});

typeConverter.addConversion([](cxx::IntegerType type) {
return IntegerType::get(type.getContext(), type.getWidth());
});

typeConverter.addConversion([](cxx::FloatType type) -> Type {
auto width = type.getWidth();
switch (width) {
case 16:
return Float16Type::get(type.getContext());
case 32:
return Float32Type::get(type.getContext());
case 64:
return Float64Type::get(type.getContext());
default:
return {};
} // switch
});

typeConverter.addConversion([](cxx::PointerType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});

DenseMap<cxx::ClassType, Type> convertedClassTypes;
typeConverter.addConversion([&](cxx::ClassType type) -> Type {
if (auto it = convertedClassTypes.find(type);
it != convertedClassTypes.end()) {
return it->second;
}

auto structType =
LLVM::LLVMStructType::getIdentified(type.getContext(), type.getName());

convertedClassTypes[type] = structType;

SmallVector<Type> fieldTypes;
bool isPacked = false;

for (auto field : type.getBody()) {
auto convertedFieldType = typeConverter.convertType(field);
// todo: check if the field type was converted successfully
fieldTypes.push_back(convertedFieldType);
}

structType.setBody(fieldTypes, isPacked);

return structType;
});

// set up the conversion patterns
ConversionTarget target(*context);

Expand All @@ -234,8 +327,9 @@ void CxxToLLVMLoweringPass::runOnOperation() {

RewritePatternSet patterns(context);
patterns.insert<FuncOpLowering, ReturnOpLowering, AllocaOpLowering,
LoadOpLowering, StoreOpLowering, IntConstantOpLowering>(
typeConverter, context);
LoadOpLowering, StoreOpLowering, BoolConstantOpLowering,
IntConstantOpLowering, FloatConstantOpLowering>(typeConverter,
context);

populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,
typeConverter);
Expand Down
5 changes: 5 additions & 0 deletions src/parser/cxx/decl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ struct GetDeclaratorType {
if (auto params = ast->parameterDeclarationClause) {
for (auto it = params->parameterDeclarationList; it; it = it->next) {
auto paramType = it->value->type;

if (control()->is_void(paramType)) {
continue;
}

parameterTypes.push_back(paramType);
}

Expand Down