Skip to content

Commit ba94f5d

Browse files
committed
Add MemberOp and its lowering pattern for member access
1 parent 646b231 commit ba94f5d

File tree

3 files changed

+101
-36
lines changed

3 files changed

+101
-36
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> {
180180
let results = (outs Cxx_PointerType:$result);
181181
}
182182

183+
def Cxx_MemberOp : Cxx_Op<"member"> {
184+
let arguments = (ins Cxx_PointerType:$base, I32Prop:$member_index);
185+
186+
let results = (outs Cxx_PointerType:$result);
187+
}
188+
183189
def Cxx_AddressOfOp : Cxx_Op<"addressof"> {
184190
let arguments = (ins FlatSymbolRefAttr:$sym_name);
185191

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cxx/control.h>
2727
#include <cxx/literals.h>
2828
#include <cxx/memory_layout.h>
29+
#include <cxx/scope.h>
2930
#include <cxx/symbols.h>
3031
#include <cxx/translation_unit.h>
3132
#include <cxx/types.h>
@@ -601,15 +602,35 @@ auto Codegen::ExpressionVisitor::operator()(SpliceMemberExpressionAST* ast)
601602

602603
auto Codegen::ExpressionVisitor::operator()(MemberExpressionAST* ast)
603604
-> ExpressionResult {
605+
if (auto field = symbol_cast<FieldSymbol>(ast->symbol);
606+
field && !field->isStatic()) {
607+
// todo: introduce ClassLayout to avoid linear searches and support c++
608+
// class layout
609+
int fieldIndex = 0;
610+
auto classSymbol = symbol_cast<ClassSymbol>(field->enclosingSymbol());
611+
for (auto member : classSymbol->scope()->symbols()) {
612+
auto f = symbol_cast<FieldSymbol>(member);
613+
if (!f) continue;
614+
if (f->isStatic()) continue;
615+
if (member == field) break;
616+
++fieldIndex;
617+
}
618+
619+
auto baseExpressionResult = gen.expression(ast->baseExpression);
620+
621+
auto loc = gen.getLocation(ast->unqualifiedId->firstSourceLocation());
622+
623+
auto resultType = gen.convertType(control()->add_pointer(ast->type));
624+
625+
auto op = gen.builder_.create<mlir::cxx::MemberOp>(
626+
loc, resultType, baseExpressionResult.value, fieldIndex);
627+
628+
return {op};
629+
}
630+
604631
auto op =
605632
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));
606633

607-
#if false
608-
auto baseExpressionResult = gen.expression(ast->baseExpression);
609-
auto nestedNameSpecifierResult = gen.nestedNameSpecifier(ast->nestedNameSpecifier);
610-
auto unqualifiedIdResult = gen(ast->unqualifiedId);
611-
#endif
612-
613634
return {op};
614635
}
615636

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -53,37 +53,12 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
5353
-> LogicalResult override {
5454
auto typeConverter = getTypeConverter();
5555

56-
auto funcType = op.getFunctionType();
57-
58-
SmallVector<Type> argumentTypes;
59-
for (auto argType : funcType.getInputs()) {
60-
auto convertedType = typeConverter->convertType(argType);
61-
if (!convertedType) {
62-
return rewriter.notifyMatchFailure(
63-
op, "failed to convert function argument type");
64-
}
65-
argumentTypes.push_back(convertedType);
56+
if (failed(convertFunctionTyype(op, rewriter))) {
57+
return rewriter.notifyMatchFailure(op, "failed to convert function type");
6658
}
6759

68-
SmallVector<Type> resultTypes;
69-
for (auto resultType : funcType.getResults()) {
70-
auto convertedType = typeConverter->convertType(resultType);
71-
if (!convertedType) {
72-
return rewriter.notifyMatchFailure(
73-
op, "failed to convert function result type");
74-
}
75-
76-
resultTypes.push_back(convertedType);
77-
}
78-
79-
const auto returnType = resultTypes.empty()
80-
? LLVM::LLVMVoidType::get(getContext())
81-
: resultTypes.front();
82-
83-
const auto isVarArg = funcType.getVariadic();
84-
85-
auto llvmFuncType =
86-
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);
60+
auto funcType = op.getFunctionType();
61+
auto llvmFuncType = typeConverter->convertType(funcType);
8762

8863
auto func = rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), op.getSymName(),
8964
llvmFuncType);
@@ -98,6 +73,29 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
9873

9974
return success();
10075
}
76+
77+
auto convertFunctionTyype(cxx::FuncOp funcOp,
78+
ConversionPatternRewriter &rewriter) const
79+
-> LogicalResult {
80+
auto type = funcOp.getFunctionType();
81+
const auto &typeConverter = *getTypeConverter();
82+
83+
TypeConverter::SignatureConversion result(type.getInputs().size());
84+
SmallVector<Type, 1> newResults;
85+
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
86+
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
87+
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
88+
typeConverter, &result)))
89+
return failure();
90+
91+
auto newType = cxx::FunctionType::get(rewriter.getContext(),
92+
result.getConvertedTypes(),
93+
newResults, type.getVariadic());
94+
95+
rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
96+
97+
return success();
98+
}
10199
};
102100

103101
class GlobalOpLowering : public OpConversionPattern<cxx::GlobalOp> {
@@ -334,6 +332,45 @@ class SubscriptOpLowering : public OpConversionPattern<cxx::SubscriptOp> {
334332
const DataLayout &dataLayout_;
335333
};
336334

335+
class MemberOpLowering : public OpConversionPattern<cxx::MemberOp> {
336+
public:
337+
MemberOpLowering(const TypeConverter &typeConverter,
338+
const DataLayout &dataLayout, MLIRContext *context,
339+
PatternBenefit benefit = 1)
340+
: OpConversionPattern<cxx::MemberOp>(typeConverter, context, benefit),
341+
dataLayout_(dataLayout) {}
342+
343+
auto matchAndRewrite(cxx::MemberOp op, OpAdaptor adaptor,
344+
ConversionPatternRewriter &rewriter) const
345+
-> LogicalResult override {
346+
auto typeConverter = getTypeConverter();
347+
auto context = getContext();
348+
349+
auto pointerType = cast<cxx::PointerType>(op.getBase().getType());
350+
auto classType = dyn_cast<cxx::ClassType>(pointerType.getElementType());
351+
352+
auto resultType = typeConverter->convertType(op.getResult().getType());
353+
if (!resultType) {
354+
return rewriter.notifyMatchFailure(
355+
op, "failed to convert member result type");
356+
}
357+
358+
auto elementType = typeConverter->convertType(classType);
359+
360+
SmallVector<LLVM::GEPArg> indices;
361+
indices.push_back(0);
362+
indices.push_back(adaptor.getMemberIndex());
363+
364+
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, resultType, elementType,
365+
adaptor.getBase(), indices);
366+
367+
return success();
368+
}
369+
370+
private:
371+
const DataLayout &dataLayout_;
372+
};
373+
337374
class BoolConstantOpLowering : public OpConversionPattern<cxx::BoolConstantOp> {
338375
public:
339376
using OpConversionPattern::OpConversionPattern;
@@ -1329,7 +1366,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
13291366
DataLayout dataLayout{module};
13301367

13311368
patterns.insert<AllocaOpLowering, LoadOpLowering, StoreOpLowering,
1332-
SubscriptOpLowering>(typeConverter, dataLayout, context);
1369+
SubscriptOpLowering, MemberOpLowering>(typeConverter,
1370+
dataLayout, context);
13331371

13341372
// cast operations
13351373
patterns.insert<IntToBoolOpLowering, BoolToIntOpLowering,

0 commit comments

Comments
 (0)