Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> {
let results = (outs Cxx_PointerType:$result);
}

def Cxx_MemberOp : Cxx_Op<"member"> {
let arguments = (ins Cxx_PointerType:$base, I32Prop:$member_index);

let results = (outs Cxx_PointerType:$result);
}

def Cxx_AddressOfOp : Cxx_Op<"addressof"> {
let arguments = (ins FlatSymbolRefAttr:$sym_name);

Expand Down
33 changes: 27 additions & 6 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cxx/control.h>
#include <cxx/literals.h>
#include <cxx/memory_layout.h>
#include <cxx/scope.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
Expand Down Expand Up @@ -601,15 +602,35 @@ auto Codegen::ExpressionVisitor::operator()(SpliceMemberExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(MemberExpressionAST* ast)
-> ExpressionResult {
if (auto field = symbol_cast<FieldSymbol>(ast->symbol);
field && !field->isStatic()) {
// todo: introduce ClassLayout to avoid linear searches and support c++
// class layout
int fieldIndex = 0;
auto classSymbol = symbol_cast<ClassSymbol>(field->enclosingSymbol());
for (auto member : classSymbol->scope()->symbols()) {
auto f = symbol_cast<FieldSymbol>(member);
if (!f) continue;
if (f->isStatic()) continue;
if (member == field) break;
++fieldIndex;
}

auto baseExpressionResult = gen.expression(ast->baseExpression);

auto loc = gen.getLocation(ast->unqualifiedId->firstSourceLocation());

auto resultType = gen.convertType(control()->add_pointer(ast->type));

auto op = gen.builder_.create<mlir::cxx::MemberOp>(
loc, resultType, baseExpressionResult.value, fieldIndex);

return {op};
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto baseExpressionResult = gen.expression(ast->baseExpression);
auto nestedNameSpecifierResult = gen.nestedNameSpecifier(ast->nestedNameSpecifier);
auto unqualifiedIdResult = gen(ast->unqualifiedId);
#endif

return {op};
}

Expand Down
98 changes: 68 additions & 30 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,12 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
-> LogicalResult override {
auto typeConverter = getTypeConverter();

auto funcType = op.getFunctionType();

SmallVector<Type> argumentTypes;
for (auto argType : funcType.getInputs()) {
auto convertedType = typeConverter->convertType(argType);
if (!convertedType) {
return rewriter.notifyMatchFailure(
op, "failed to convert function argument type");
}
argumentTypes.push_back(convertedType);
if (failed(convertFunctionTyype(op, rewriter))) {
return rewriter.notifyMatchFailure(op, "failed to convert function type");
}

SmallVector<Type> resultTypes;
for (auto resultType : funcType.getResults()) {
auto convertedType = typeConverter->convertType(resultType);
if (!convertedType) {
return rewriter.notifyMatchFailure(
op, "failed to convert function result type");
}

resultTypes.push_back(convertedType);
}

const auto returnType = resultTypes.empty()
? LLVM::LLVMVoidType::get(getContext())
: resultTypes.front();

const auto isVarArg = funcType.getVariadic();

auto llvmFuncType =
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);
auto funcType = op.getFunctionType();
auto llvmFuncType = typeConverter->convertType(funcType);

auto func = rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), op.getSymName(),
llvmFuncType);
Expand All @@ -98,6 +73,29 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {

return success();
}

auto convertFunctionTyype(cxx::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const
-> LogicalResult {
auto type = funcOp.getFunctionType();
const auto &typeConverter = *getTypeConverter();

TypeConverter::SignatureConversion result(type.getInputs().size());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
typeConverter, &result)))
return failure();

auto newType = cxx::FunctionType::get(rewriter.getContext(),
result.getConvertedTypes(),
newResults, type.getVariadic());

rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });

return success();
}
};

class GlobalOpLowering : public OpConversionPattern<cxx::GlobalOp> {
Expand Down Expand Up @@ -334,6 +332,45 @@ class SubscriptOpLowering : public OpConversionPattern<cxx::SubscriptOp> {
const DataLayout &dataLayout_;
};

class MemberOpLowering : public OpConversionPattern<cxx::MemberOp> {
public:
MemberOpLowering(const TypeConverter &typeConverter,
const DataLayout &dataLayout, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern<cxx::MemberOp>(typeConverter, context, benefit),
dataLayout_(dataLayout) {}

auto matchAndRewrite(cxx::MemberOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto pointerType = cast<cxx::PointerType>(op.getBase().getType());
auto classType = dyn_cast<cxx::ClassType>(pointerType.getElementType());

auto resultType = typeConverter->convertType(op.getResult().getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert member result type");
}

auto elementType = typeConverter->convertType(classType);

SmallVector<LLVM::GEPArg> indices;
indices.push_back(0);
indices.push_back(adaptor.getMemberIndex());

rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, resultType, elementType,
adaptor.getBase(), indices);

return success();
}

private:
const DataLayout &dataLayout_;
};

class BoolConstantOpLowering : public OpConversionPattern<cxx::BoolConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -1329,7 +1366,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
DataLayout dataLayout{module};

patterns.insert<AllocaOpLowering, LoadOpLowering, StoreOpLowering,
SubscriptOpLowering>(typeConverter, dataLayout, context);
SubscriptOpLowering, MemberOpLowering>(typeConverter,
dataLayout, context);

// cast operations
patterns.insert<IntToBoolOpLowering, BoolToIntOpLowering,
Expand Down
Loading