diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 8e8a37bd..c5aed44c 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -23,6 +23,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" def Cxx_Dialect : Dialect { let name = "cxx"; @@ -148,7 +149,7 @@ def Cxx_StoreOp : Cxx_Op<"store"> { def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [ Pure ]> { - let arguments = (ins BoolAttr:$value); + let arguments = (ins BoolProp:$value); let results = (outs Cxx_BoolType:$result); } @@ -156,7 +157,7 @@ def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [ def Cxx_IntConstantOp : Cxx_Op<"constant.int", [ Pure ]> { - let arguments = (ins I64Attr:$value); + let arguments = (ins I64Prop:$value); let results = (outs Cxx_IntegerType:$result); } @@ -164,7 +165,7 @@ def Cxx_IntConstantOp : Cxx_Op<"constant.int", [ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [ Pure ]> { - let arguments = (ins F64Attr:$value); + let arguments = (ins TypedAttrInterface:$value); let results = (outs Cxx_FloatType:$result); } @@ -174,7 +175,7 @@ def Cxx_FloatConstantOp : Cxx_Op<"constant.float", [ // def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> { - let arguments = (ins StrAttr:$message); + let arguments = (ins StringProp:$message); let results = (outs Cxx_ExprType:$result); let assemblyFormat = "$message attr-dict `:` type($result)"; let builders = @@ -185,7 +186,7 @@ def Cxx_TodoExprOp : Cxx_Op<"todo.expr"> { } def Cxx_TodoStmtOp : Cxx_Op<"todo.stmt"> { - let arguments = (ins StrAttr:$message); + let arguments = (ins StringProp:$message); let results = (outs); let assemblyFormat = "$message attr-dict"; } \ No newline at end of file diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 4c69018c..054cf55c 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -23,6 +23,7 @@ // cxx #include #include +#include namespace cxx { @@ -133,8 +134,7 @@ auto Codegen::ExpressionVisitor::operator()(CharLiteralExpressionAST* ast) auto loc = gen.getLocation(ast->literalLoc); auto type = gen.convertType(ast->type); - auto value = gen.builder_.getI64IntegerAttr(ast->literal->charValue()); - + auto value = std::int64_t(ast->literal->charValue()); auto op = gen.builder_.create(loc, type, value); return {op}; @@ -145,9 +145,9 @@ auto Codegen::ExpressionVisitor::operator()(BoolLiteralExpressionAST* ast) auto loc = gen.getLocation(ast->literalLoc); auto type = gen.convertType(ast->type); - auto value = gen.builder_.getBoolAttr(ast->isTrue); - auto op = gen.builder_.create(loc, type, value); + auto op = + gen.builder_.create(loc, type, ast->isTrue); return {op}; } @@ -157,7 +157,7 @@ auto Codegen::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast) auto loc = gen.getLocation(ast->literalLoc); auto type = gen.convertType(ast->type); - auto value = gen.builder_.getI64IntegerAttr(ast->literal->integerValue()); + auto value = ast->literal->integerValue(); auto op = gen.builder_.create(loc, type, value); @@ -169,7 +169,25 @@ auto Codegen::ExpressionVisitor::operator()(FloatLiteralExpressionAST* ast) auto loc = gen.getLocation(ast->literalLoc); auto type = gen.convertType(ast->type); - auto value = gen.builder_.getF64FloatAttr(ast->literal->floatValue()); + + mlir::TypedAttr value; + + switch (ast->type->kind()) { + case TypeKind::kFloat: + value = gen.builder_.getF32FloatAttr(ast->literal->floatValue()); + break; + case TypeKind::kDouble: + value = gen.builder_.getF64FloatAttr(ast->literal->floatValue()); + break; + case TypeKind::kLongDouble: + value = gen.builder_.getF64FloatAttr(ast->literal->floatValue()); + break; + default: + // Handle other float types if necessary + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), + "unsupported float type"); + return {op}; + } auto op = gen.builder_.create(loc, type, value); diff --git a/src/mlir/cxx/mlir/convert_type.cc b/src/mlir/cxx/mlir/convert_type.cc index f01b6212..b4537334 100644 --- a/src/mlir/cxx/mlir/convert_type.cc +++ b/src/mlir/cxx/mlir/convert_type.cc @@ -299,12 +299,14 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { } auto Codegen::ConvertType::operator()(const EnumType* type) -> mlir::Type { - return getExprType(); + if (type->underlyingType()) return gen.convertType(type->underlyingType()); + return gen.builder_.getType(32, true); } auto Codegen::ConvertType::operator()(const ScopedEnumType* type) -> mlir::Type { - return getExprType(); + if (type->underlyingType()) return gen.convertType(type->underlyingType()); + return gen.builder_.getType(32, true); } auto Codegen::ConvertType::operator()(const MemberObjectPointerType* type) diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index 14d67dc2..18c3a882 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -170,6 +170,29 @@ class StoreOpLowering : public OpConversionPattern { } }; +class BoolConstantOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::BoolConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert boolean constant type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + class IntConstantOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -186,8 +209,32 @@ class IntConstantOpLowering : public OpConversionPattern { op, "failed to convert integer constant type"); } - auto valueAttr = adaptor.getValueAttr(); - rewriter.replaceOpWithNewOp(op, resultType, valueAttr); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + + return success(); + } +}; + +class FloatConstantOpLowering + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::FloatConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert float constant type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); return success(); } @@ -218,14 +265,60 @@ void CxxToLLVMLoweringPass::runOnOperation() { // set up the type converter LLVMTypeConverter typeConverter{context}; + + typeConverter.addConversion([](cxx::BoolType type) { + // todo: i8/i32 for data and i1 for control flow + return IntegerType::get(type.getContext(), 8); + }); + typeConverter.addConversion([](cxx::IntegerType type) { return IntegerType::get(type.getContext(), type.getWidth()); }); + typeConverter.addConversion([](cxx::FloatType type) -> Type { + auto width = type.getWidth(); + switch (width) { + case 16: + return Float16Type::get(type.getContext()); + case 32: + return Float32Type::get(type.getContext()); + case 64: + return Float64Type::get(type.getContext()); + default: + return {}; + } // switch + }); + typeConverter.addConversion([](cxx::PointerType type) { return LLVM::LLVMPointerType::get(type.getContext()); }); + DenseMap convertedClassTypes; + typeConverter.addConversion([&](cxx::ClassType type) -> Type { + if (auto it = convertedClassTypes.find(type); + it != convertedClassTypes.end()) { + return it->second; + } + + auto structType = + LLVM::LLVMStructType::getIdentified(type.getContext(), type.getName()); + + convertedClassTypes[type] = structType; + + SmallVector fieldTypes; + bool isPacked = false; + + for (auto field : type.getBody()) { + auto convertedFieldType = typeConverter.convertType(field); + // todo: check if the field type was converted successfully + fieldTypes.push_back(convertedFieldType); + } + + structType.setBody(fieldTypes, isPacked); + + return structType; + }); + // set up the conversion patterns ConversionTarget target(*context); @@ -234,8 +327,9 @@ void CxxToLLVMLoweringPass::runOnOperation() { RewritePatternSet patterns(context); patterns.insert( - typeConverter, context); + LoadOpLowering, StoreOpLowering, BoolConstantOpLowering, + IntConstantOpLowering, FloatConstantOpLowering>(typeConverter, + context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); diff --git a/src/parser/cxx/decl.cc b/src/parser/cxx/decl.cc index ac9ce1e3..a20904bc 100644 --- a/src/parser/cxx/decl.cc +++ b/src/parser/cxx/decl.cc @@ -215,6 +215,11 @@ struct GetDeclaratorType { if (auto params = ast->parameterDeclarationClause) { for (auto it = params->parameterDeclarationList; it; it = it->next) { auto paramType = it->value->type; + + if (control()->is_void(paramType)) { + continue; + } + parameterTypes.push_back(paramType); }