Skip to content

Commit 68f981c

Browse files
committed
Add GlobalOp and AddressOfOp with corresponding lowering patterns
Signed-off-by: Roberto Raggi <[email protected]>
1 parent 22bd99d commit 68f981c

File tree

5 files changed

+102
-6
lines changed

5 files changed

+102
-6
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
113113
let hasCustomAssemblyFormat = 1;
114114
}
115115

116+
def Cxx_GlobalOp : Cxx_Op<"global"> {
117+
let arguments = (ins
118+
TypeAttrOf<AnyType>:$global_type
119+
, UnitAttr:$constant
120+
, SymbolNameAttr:$sym_name
121+
, OptionalAttr<AnyAttr>:$value
122+
);
123+
}
124+
116125
def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
117126
let arguments = (ins Variadic<AnyType>:$input);
118127

@@ -121,8 +130,6 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {
121130
let extraClassDeclaration = [{
122131
bool hasOperand() { return getNumOperands() != 0; }
123132
}];
124-
125-
let hasVerifier = 0;
126133
}
127134

128135
def Cxx_CallOp : Cxx_Op<"call"> {
@@ -160,6 +167,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> {
160167
let results = (outs Cxx_PointerType:$result);
161168
}
162169

170+
def Cxx_AddressOfOp : Cxx_Op<"addressof"> {
171+
let arguments = (ins FlatSymbolRefAttr:$sym_name);
172+
173+
let results = (outs AnyType:$result);
174+
}
175+
163176
def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
164177
Pure
165178
]> {

src/mlir/cxx/mlir/codegen.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ auto Codegen::newBlock() -> mlir::Block* {
5656
return newBlock;
5757
}
5858

59+
auto Codegen::newUniqueSymbolName(std::string_view prefix) -> std::string {
60+
auto& uniqueName = uniqueSymbolNames_[prefix];
61+
if (uniqueName == 0) {
62+
uniqueName = 1;
63+
return std::format("{}{}", prefix, uniqueName);
64+
}
65+
return std::format("{}{}", prefix, ++uniqueName);
66+
}
67+
5968
void Codegen::branch(mlir::Location loc, mlir::Block* block,
6069
mlir::ValueRange operands) {
6170
if (currentBlockMightHaveTerminator()) return;

src/mlir/cxx/mlir/codegen.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,10 @@ class Codegen {
269269
-> std::optional<mlir::Value>;
270270

271271
[[nodiscard]] auto newBlock() -> mlir::Block*;
272+
273+
[[nodiscard]] auto newUniqueSymbolName(std::string_view prefix)
274+
-> std::string;
275+
272276
void branch(mlir::Location loc, mlir::Block* block,
273277
mlir::ValueRange operands = {});
274278

@@ -315,6 +319,8 @@ class Codegen {
315319
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
316320
std::unordered_map<Symbol*, mlir::Value> locals_;
317321
std::unordered_map<FunctionSymbol*, mlir::cxx::FuncOp> funcOps_;
322+
std::unordered_map<std::string_view, int> uniqueSymbolNames_;
323+
std::unordered_map<const StringLiteral*, mlir::StringAttr> stringLiterals_;
318324
Loop loop_;
319325
int count_ = 0;
320326
};

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,31 @@ auto Codegen::ExpressionVisitor::operator()(NullptrLiteralExpressionAST* ast)
254254

255255
auto Codegen::ExpressionVisitor::operator()(StringLiteralExpressionAST* ast)
256256
-> ExpressionResult {
257+
auto loc = gen.getLocation(ast->literalLoc);
258+
auto type = gen.convertType(ast->type);
259+
auto resultType = mlir::cxx::PointerType::get(type.getContext(), type);
260+
261+
auto it = gen.stringLiterals_.find(ast->literal);
262+
if (it == gen.stringLiterals_.end()) {
263+
// todo: clean up
264+
std::string str(ast->literal->stringValue());
265+
str.push_back('\0');
266+
267+
auto initializer = gen.builder_.getStringAttr(str);
268+
269+
// todo: generate unique name for the global
270+
auto name = gen.builder_.getStringAttr(gen.newUniqueSymbolName(".str"));
271+
272+
auto x = mlir::OpBuilder(gen.module_->getContext());
273+
x.setInsertionPointToEnd(gen.module_.getBody());
274+
x.create<mlir::cxx::GlobalOp>(loc, type, true, name, initializer);
275+
276+
it = gen.stringLiterals_.insert_or_assign(ast->literal, name).first;
277+
}
278+
257279
auto op =
258-
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
280+
gen.builder_.create<mlir::cxx::AddressOfOp>(loc, resultType, it->second);
281+
259282
return {op};
260283
}
261284

@@ -1024,6 +1047,11 @@ auto Codegen::ExpressionVisitor::operator()(ImplicitCastExpressionAST* ast)
10241047
return {op};
10251048
}
10261049

1050+
case ImplicitCastKind::kQualificationConversion: {
1051+
auto expressionResult = gen.expression(ast->expression);
1052+
return expressionResult;
1053+
}
1054+
10271055
default:
10281056
break;
10291057

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,25 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
100100
}
101101
};
102102

103+
class GlobalOpLowering : public OpConversionPattern<cxx::GlobalOp> {
104+
public:
105+
using OpConversionPattern::OpConversionPattern;
106+
107+
auto matchAndRewrite(cxx::GlobalOp op, OpAdaptor adaptor,
108+
ConversionPatternRewriter &rewriter) const
109+
-> LogicalResult override {
110+
auto typeConverter = getTypeConverter();
111+
112+
auto elementType = getTypeConverter()->convertType(op.getGlobalType());
113+
114+
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
115+
op, elementType, op.getConstant(), LLVM::linkage::Linkage::Private,
116+
op.getSymName(), adaptor.getValue().value());
117+
118+
return success();
119+
}
120+
};
121+
103122
class ReturnOpLowering : public OpConversionPattern<cxx::ReturnOp> {
104123
public:
105124
using OpConversionPattern::OpConversionPattern;
@@ -145,6 +164,28 @@ class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
145164
}
146165
};
147166

167+
class AddressOfOpLowering : public OpConversionPattern<cxx::AddressOfOp> {
168+
public:
169+
using OpConversionPattern::OpConversionPattern;
170+
171+
auto matchAndRewrite(cxx::AddressOfOp op, OpAdaptor adaptor,
172+
ConversionPatternRewriter &rewriter) const
173+
-> LogicalResult override {
174+
auto typeConverter = getTypeConverter();
175+
176+
auto resultType = typeConverter->convertType(op.getType());
177+
if (!resultType) {
178+
return rewriter.notifyMatchFailure(op,
179+
"failed to convert address of type");
180+
}
181+
182+
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, resultType,
183+
adaptor.getSymName());
184+
185+
return success();
186+
}
187+
};
188+
148189
class AllocaOpLowering : public OpConversionPattern<cxx::AllocaOp> {
149190
public:
150191
AllocaOpLowering(const TypeConverter &typeConverter,
@@ -461,7 +502,6 @@ class ArrayToPointerOpLowering
461502

462503
SmallVector<LLVM::GEPArg> indices;
463504

464-
indices.push_back(0);
465505
indices.push_back(0);
466506

467507
auto resultType = LLVM::LLVMPointerType::get(context);
@@ -1254,8 +1294,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
12541294
RewritePatternSet patterns(context);
12551295

12561296
// function operations
1257-
patterns.insert<FuncOpLowering, ReturnOpLowering, CallOpLowering>(
1258-
typeConverter, context);
1297+
patterns.insert<FuncOpLowering, GlobalOpLowering, ReturnOpLowering,
1298+
CallOpLowering, AddressOfOpLowering>(typeConverter, context);
12591299

12601300
// memory operations
12611301
DataLayout dataLayout{module};

0 commit comments

Comments
 (0)