Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/frontend/cxx/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#ifdef CXX_WITH_MLIR
#include <cxx/mlir/codegen.h>
#include <cxx/mlir/cxx_dialect.h>
#include <cxx/mlir/cxx_dialect_conversions.h>
#endif

#include <format>
Expand Down Expand Up @@ -376,6 +377,11 @@ auto runOnFile(const CLI& cli, const std::string& fileName) -> bool {

auto ir = codegen(unit.ast());

if (failed(lowerToMLIR(ir.module))) {
std::cerr << "cxx: failed to lower C++ AST to MLIR" << std::endl;
return false;
}

mlir::OpPrintingFlags flags;
if (cli.opt_g) {
flags.enableDebugInfo(true, true);
Expand Down
3 changes: 3 additions & 0 deletions src/mlir/cxx/mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(SOURCES
codegen_units.cc
convert_type.cc
cxx_dialect.cc
cxx_dialect_conversions.cc
)

add_library(cxx-mlir ${SOURCES})
Expand All @@ -31,6 +32,8 @@ target_link_libraries(cxx-mlir PUBLIC
MLIRFuncDialect
MLIRControlFlowDialect
MLIRSCFDialect
MLIRPass
MLIRTransforms
)

target_compile_definitions(cxx-mlir PUBLIC CXX_WITH_MLIR)
Expand Down
18 changes: 14 additions & 4 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,24 @@ def Cxx_ArrayType : Cxx_Type<"Array", "array"> {
let assemblyFormat = "`<` $elementType `,` $size `>`";
}

def Cxx_ClassType : Cxx_Type<"Class", "class"> {
def Cxx_ClassType : Cxx_Type<"Class", "class", [MutableType]> {

let storageClass = "ClassTypeStorage";
let genStorageClass = 0;

let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;

let parameters = (ins
StringRefParameter<"class name", [{ "" }]>:$name,
OptionalArrayRefParameter<"mlir::Type">:$body
);

let assemblyFormat = "`<` $name `(` $body `)` `>`";
let extraClassDeclaration = [{
static auto getNamed(MLIRContext *context, StringRef name) -> ClassType;
auto setBody(ArrayRef<Type> types) -> LogicalResult;
}];

}

// ops
Expand All @@ -97,8 +108,7 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];

let extraClassDeclaration =
[{
let extraClassDeclaration = [{
auto getArgumentTypes() -> ArrayRef<Type> { return getFunctionType().getInputs(); }
auto getResultTypes() -> ArrayRef<Type> { return getFunctionType().getResults(); }
auto getCallableRegion() -> Region* { return &getBody(); }
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Codegen {
TranslationUnit* unit_ = nullptr;
mlir::Block* exitBlock_ = nullptr;
mlir::cxx::AllocaOp exitValue_;
std::unordered_map<ClassSymbol*, std::string> classNames_;
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
int count_ = 0;
};

Expand Down
8 changes: 5 additions & 3 deletions src/mlir/cxx/mlir/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {

if (auto it = gen.classNames_.find(classSymbol);
it != gen.classNames_.end()) {
return mlir::cxx::ClassType::get(ctx, it->second, {});
return it->second;
}

auto name = to_string(classSymbol->name());
Expand All @@ -277,7 +277,9 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {
name = std::format("$class_{}", loc.index());
}

gen.classNames_[classSymbol] = name;
mlir::cxx::ClassType classType = mlir::cxx::ClassType::getNamed(ctx, name);

gen.classNames_[classSymbol] = classType;

// todo: layout of parent classes, anonymous nested fields, etc.

Expand All @@ -291,7 +293,7 @@ auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {
memberTypes.push_back(memberType);
}

auto classType = mlir::cxx::ClassType::get(ctx, name, memberTypes);
classType.setBody(memberTypes);

return classType;
}
Expand Down
110 changes: 94 additions & 16 deletions src/mlir/cxx/mlir/cxx_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,73 @@

namespace mlir::cxx {

struct detail::ClassTypeStorage : public TypeStorage {
public:
using KeyTy = StringRef;

explicit ClassTypeStorage(const KeyTy &key) : name_(key) {}

auto getName() -> StringRef const { return name_; }
auto getBody() const -> ArrayRef<Type> { return body_; }

auto operator==(const KeyTy &key) const -> bool { return name_ == key; };

static auto hashKey(const KeyTy &key) -> llvm::hash_code {
return llvm::hash_value(key);
}

static ClassTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<ClassTypeStorage>())
ClassTypeStorage(allocator.copyInto(key));
}

auto mutate(TypeStorageAllocator &allocator, ArrayRef<Type> body)
-> LogicalResult {
if (isInitialized_) return success(body == getBody());

isInitialized_ = true;
body_ = allocator.copyInto(body);

return success();
}

private:
StringRef name_;
ArrayRef<Type> body_;
bool isInitialized_ = false;
};

namespace {

struct CxxGenerateAliases : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;

auto getAlias(Type type, raw_ostream &os) const -> AliasResult override {
if (auto intType = mlir::dyn_cast<mlir::cxx::IntegerType>(type)) {
if (auto intType = dyn_cast<IntegerType>(type)) {
os << 'i' << intType.getWidth() << (intType.getIsSigned() ? 's' : 'u');
return AliasResult::FinalAlias;
}

if (auto floatType = mlir::dyn_cast<mlir::cxx::FloatType>(type)) {
if (auto floatType = dyn_cast<FloatType>(type)) {
os << 'f' << floatType.getWidth();
return AliasResult::FinalAlias;
}

if (auto classType = mlir::dyn_cast<mlir::cxx::ClassType>(type)) {
if (auto classType = dyn_cast<ClassType>(type)) {
if (!classType.getBody().empty()) {
os << "class_" << classType.getName();
return AliasResult::FinalAlias;
}
}

if (mlir::isa<VoidType>(type)) {
if (isa<VoidType>(type)) {
os << "void";
return AliasResult::FinalAlias;
}

if (mlir::isa<BoolType>(type)) {
if (isa<BoolType>(type)) {
os << "bool";
return AliasResult::FinalAlias;
}
Expand All @@ -85,32 +122,73 @@ void CxxDialect::initialize() {
addInterface<CxxGenerateAliases>();
}

void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs) {
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

void FuncOp::print(mlir::OpAsmPrinter &p) {
mlir::function_interface_impl::printFunctionOp(
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}

auto FuncOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result)
-> mlir::ParseResult {
auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult {
auto funcTypeBuilder =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
[](Builder &builder, llvm::ArrayRef<Type> argTypes,
ArrayRef<Type> results, function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return mlir::function_interface_impl::parseFunctionOp(
return function_interface_impl::parseFunctionOp(
parser, result, false, getFunctionTypeAttrName(result.name),
funcTypeBuilder, getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
}

auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType {
return Base::get(context, name);
}

auto ClassType::setBody(llvm::ArrayRef<Type> body) -> LogicalResult {
Base::mutate(body);
}

void ClassType::print(AsmPrinter &p) const {
FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;

p << "<";
cyclicPrint = p.tryStartCyclicPrint(*this);

p << '"';
llvm::printEscapedString(getName(), p.getStream());
p << '"';

if (failed(cyclicPrint)) {
p << '>';
return;
}

p << ", ";

p << '(';
llvm::interleaveComma(getBody(), p.getStream(),
[&](Type subtype) { p << subtype; });
p << ')';

p << '>';
}

auto ClassType::parse(AsmParser &parser) -> Type {
// todo: implement parsing for ClassType
return {};
}

auto ClassType::getName() const -> StringRef { return getImpl()->getName(); }

auto ClassType::getBody() const -> ArrayRef<Type> {
return getImpl()->getBody();
}

} // namespace mlir::cxx

#include <cxx/mlir/CxxOpsDialect.cpp.inc>
Expand Down
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/cxx_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
#pragma GCC diagnostic pop
#endif

namespace mlir::cxx::detail {

struct ClassTypeStorage;

}

#include <cxx/mlir/CxxOpsDialect.h.inc>

#define GET_TYPEDEF_CLASSES
Expand Down
Loading