diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 50162bdf..dbfd0e5b 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -57,12 +57,33 @@ def Cxx_IntegerType : Cxx_Type<"Integer", "int"> { let assemblyFormat = "`<` $width `,` $isSigned `>`"; } +def Cxx_FloatType : Cxx_Type<"Float", "float"> { + let parameters = (ins "unsigned":$width); + + let assemblyFormat = "`<` $width `>`"; +} + def Cxx_PointerType : Cxx_Type<"Pointer", "ptr"> { let parameters = (ins "Type":$elementType); let assemblyFormat = "`<` $elementType `>`"; } +def Cxx_ArrayType : Cxx_Type<"Array", "array"> { + let parameters = (ins "Type":$elementType, "unsigned":$size); + + let assemblyFormat = "`<` $elementType `,` $size `>`"; +} + +def Cxx_ClassType : Cxx_Type<"Class", "class"> { + let parameters = (ins + StringRefParameter<"class name", [{ "" }]>:$name, + OptionalArrayRefParameter<"mlir::Type">:$body + ); + + let assemblyFormat = "`<` $name `(` $body `)` `>`"; +} + // ops def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> { diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 6521857a..49655632 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -23,11 +23,14 @@ #include #include #include +#include #include #include #include #include +#include + namespace mlir::func { class FuncOp; } @@ -271,6 +274,7 @@ class Codegen { TranslationUnit* unit_ = nullptr; mlir::Block* exitBlock_ = nullptr; mlir::cxx::AllocaOp exitValue_; + std::unordered_map classNames_; int count_ = 0; }; diff --git a/src/mlir/cxx/mlir/convert_type.cc b/src/mlir/cxx/mlir/convert_type.cc index 9a6f6b11..75f58df0 100644 --- a/src/mlir/cxx/mlir/convert_type.cc +++ b/src/mlir/cxx/mlir/convert_type.cc @@ -26,10 +26,13 @@ #include #include #include +#include #include #include #include +#include + namespace cxx { struct Codegen::ConvertType { @@ -40,6 +43,7 @@ struct Codegen::ConvertType { auto getExprType() const -> mlir::Type; auto getIntType(const Type* type, bool isSigned) -> mlir::Type; + auto getFloatType(const Type* type) -> mlir::Type; auto operator()(const VoidType* type) -> mlir::Type; auto operator()(const NullptrType* type) -> mlir::Type; @@ -106,6 +110,11 @@ auto Codegen::ConvertType::getIntType(const Type* type, bool isSigned) return gen.builder_.getType(width, isSigned); } +auto Codegen::ConvertType::getFloatType(const Type* type) -> mlir::Type { + return gen.builder_.getType( + memoryLayout()->sizeOf(type).value() * 8); +} + auto Codegen::ConvertType::operator()(const VoidType* type) -> mlir::Type { return gen.builder_.getType(); } @@ -205,16 +214,16 @@ auto Codegen::ConvertType::operator()(const WideCharType* type) -> mlir::Type { } auto Codegen::ConvertType::operator()(const FloatType* type) -> mlir::Type { - return gen.builder_.getF32Type(); + return getFloatType(type); } auto Codegen::ConvertType::operator()(const DoubleType* type) -> mlir::Type { - return gen.builder_.getF64Type(); + return getFloatType(type); } auto Codegen::ConvertType::operator()(const LongDoubleType* type) -> mlir::Type { - return getExprType(); + return getFloatType(type); } auto Codegen::ConvertType::operator()(const QualType* type) -> mlir::Type { @@ -223,12 +232,14 @@ auto Codegen::ConvertType::operator()(const QualType* type) -> mlir::Type { auto Codegen::ConvertType::operator()(const BoundedArrayType* type) -> mlir::Type { - return getExprType(); + auto elementType = gen.convertType(type->elementType()); + return gen.builder_.getType(elementType, type->size()); } auto Codegen::ConvertType::operator()(const UnboundedArrayType* type) -> mlir::Type { - return getExprType(); + auto elementType = gen.convertType(type->elementType()); + return gen.builder_.getType(elementType); } auto Codegen::ConvertType::operator()(const PointerType* type) -> mlir::Type { @@ -251,7 +262,38 @@ auto Codegen::ConvertType::operator()(const FunctionType* type) -> mlir::Type { } auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type { - return getExprType(); + auto classSymbol = type->symbol(); + + auto ctx = gen.builder_.getContext(); + + if (auto it = gen.classNames_.find(classSymbol); + it != gen.classNames_.end()) { + return mlir::cxx::ClassType::get(ctx, it->second, {}); + } + + auto name = to_string(classSymbol->name()); + if (name.empty()) { + auto loc = type->symbol()->location(); + name = std::format("$class_{}", loc.index()); + } + + gen.classNames_[classSymbol] = name; + + // todo: layout of parent classes, anonymous nested fields, etc. + + std::vector memberTypes; + + for (auto member : classSymbol->scope()->symbols()) { + auto field = symbol_cast(member); + if (!field) continue; + if (field->isStatic()) continue; + auto memberType = gen.convertType(member->type()); + memberTypes.push_back(memberType); + } + + auto classType = mlir::cxx::ClassType::get(ctx, name, memberTypes); + + return classType; } auto Codegen::ConvertType::operator()(const EnumType* type) -> mlir::Type { diff --git a/src/mlir/cxx/mlir/cxx_dialect.cc b/src/mlir/cxx/mlir/cxx_dialect.cc index 19e92dcb..d6174697 100644 --- a/src/mlir/cxx/mlir/cxx_dialect.cc +++ b/src/mlir/cxx/mlir/cxx_dialect.cc @@ -44,6 +44,18 @@ struct CxxGenerateAliases : public OpAsmDialectInterface { return AliasResult::FinalAlias; } + if (auto floatType = mlir::dyn_cast(type)) { + os << 'f' << floatType.getWidth(); + return AliasResult::FinalAlias; + } + + if (auto classType = mlir::dyn_cast(type)) { + if (!classType.getBody().empty()) { + os << "class_" << classType.getName(); + return AliasResult::FinalAlias; + } + } + if (mlir::isa(type)) { os << "void"; return AliasResult::FinalAlias;