Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
let hasCustomAssemblyFormat = 1;
}

def Cxx_GlobalOp : Cxx_Op<"global"> {
let arguments = (ins
TypeAttrOf<AnyType>:$global_type
, UnitAttr:$constant
, SymbolNameAttr:$sym_name
, OptionalAttr<AnyAttr>:$value
);
}

def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
let arguments = (ins Variadic<AnyType>:$input);

Expand All @@ -121,8 +130,6 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
let extraClassDeclaration = [{
bool hasOperand() { return getNumOperands() != 0; }
}];

let hasVerifier = 0;
}

def Cxx_CallOp : Cxx_Op<"call"> {
Expand Down Expand Up @@ -160,6 +167,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> {
let results = (outs Cxx_PointerType:$result);
}

def Cxx_AddressOfOp : Cxx_Op<"addressof"> {
let arguments = (ins FlatSymbolRefAttr:$sym_name);

let results = (outs AnyType:$result);
}

def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
Pure
]> {
Expand Down
9 changes: 9 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ auto Codegen::newBlock() -> mlir::Block* {
return newBlock;
}

auto Codegen::newUniqueSymbolName(std::string_view prefix) -> std::string {
auto& uniqueName = uniqueSymbolNames_[prefix];
if (uniqueName == 0) {
uniqueName = 1;
return std::format("{}{}", prefix, uniqueName);
}
return std::format("{}{}", prefix, ++uniqueName);
}

void Codegen::branch(mlir::Location loc, mlir::Block* block,
mlir::ValueRange operands) {
if (currentBlockMightHaveTerminator()) return;
Expand Down
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ class Codegen {
-> std::optional<mlir::Value>;

[[nodiscard]] auto newBlock() -> mlir::Block*;

[[nodiscard]] auto newUniqueSymbolName(std::string_view prefix)
-> std::string;

void branch(mlir::Location loc, mlir::Block* block,
mlir::ValueRange operands = {});

Expand Down Expand Up @@ -315,6 +319,8 @@ class Codegen {
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
std::unordered_map<Symbol*, mlir::Value> locals_;
std::unordered_map<FunctionSymbol*, mlir::cxx::FuncOp> funcOps_;
std::unordered_map<std::string_view, int> uniqueSymbolNames_;
std::unordered_map<const StringLiteral*, mlir::StringAttr> stringLiterals_;
Loop loop_;
int count_ = 0;
};
Expand Down
30 changes: 29 additions & 1 deletion src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,31 @@ auto Codegen::ExpressionVisitor::operator()(NullptrLiteralExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(StringLiteralExpressionAST* ast)
-> ExpressionResult {
auto loc = gen.getLocation(ast->literalLoc);
auto type = gen.convertType(ast->type);
auto resultType = mlir::cxx::PointerType::get(type.getContext(), type);

auto it = gen.stringLiterals_.find(ast->literal);
if (it == gen.stringLiterals_.end()) {
// todo: clean up
std::string str(ast->literal->stringValue());
str.push_back('\0');

auto initializer = gen.builder_.getStringAttr(str);

// todo: generate unique name for the global
auto name = gen.builder_.getStringAttr(gen.newUniqueSymbolName(".str"));

auto x = mlir::OpBuilder(gen.module_->getContext());
x.setInsertionPointToEnd(gen.module_.getBody());
x.create<mlir::cxx::GlobalOp>(loc, type, true, name, initializer);

it = gen.stringLiterals_.insert_or_assign(ast->literal, name).first;
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
gen.builder_.create<mlir::cxx::AddressOfOp>(loc, resultType, it->second);

return {op};
}

Expand Down Expand Up @@ -1024,6 +1047,11 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast)
return {op};
}

case ImplicitCastKind::kQualificationConversion: {
auto expressionResult = gen.expression(ast->expression);
return expressionResult;
}

default:
break;

Expand Down
46 changes: 43 additions & 3 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,25 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
}
};

class GlobalOpLowering : public OpConversionPattern<cxx::GlobalOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::GlobalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();

auto elementType = getTypeConverter()->convertType(op.getGlobalType());

rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
op, elementType, op.getConstant(), LLVM::linkage::Linkage::Private,
op.getSymName(), adaptor.getValue().value());

return success();
}
};

class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -145,6 +164,28 @@ class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
}
};

class AddressOfOpLowering : public OpConversionPattern<cxx::AddressOfOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::AddressOfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op,
"failed to convert address of type");
}

rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, resultType,
adaptor.getSymName());

return success();
}
};

class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
public:
AllocaOpLowering(const TypeConverter &typeConverter,
Expand Down Expand Up @@ -461,7 +502,6 @@ class ArrayToPointerOpLowering

SmallVector<LLVM::GEPArg> indices;

indices.push_back(0);
indices.push_back(0);

auto resultType = LLVM::LLVMPointerType::get(context);
Expand Down Expand Up @@ -1254,8 +1294,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
RewritePatternSet patterns(context);

// function operations
patterns.insert<FuncOpLowering, ReturnOpLowering, CallOpLowering>(
typeConverter, context);
patterns.insert<FuncOpLowering, GlobalOpLowering, ReturnOpLowering,
CallOpLowering, AddressOfOpLowering>(typeConverter, context);

// memory operations
DataLayout dataLayout{module};
Expand Down
Loading