diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index bd0846cb..fe5addfb 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -180,6 +180,12 @@ def Cxx_SubscriptOp : Cxx_Op<"subscript"> { let results = (outs Cxx_PointerType:$result); } +def Cxx_MemberOp : Cxx_Op<"member"> { + let arguments = (ins Cxx_PointerType:$base, I32Prop:$member_index); + + let results = (outs Cxx_PointerType:$result); +} + def Cxx_AddressOfOp : Cxx_Op<"addressof"> { let arguments = (ins FlatSymbolRefAttr:$sym_name); diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 7f00e08e..2e2227d3 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -601,15 +602,35 @@ auto Codegen::ExpressionVisitor::operator()(SpliceMemberExpressionAST* ast) auto Codegen::ExpressionVisitor::operator()(MemberExpressionAST* ast) -> ExpressionResult { + if (auto field = symbol_cast(ast->symbol); + field && !field->isStatic()) { + // todo: introduce ClassLayout to avoid linear searches and support c++ + // class layout + int fieldIndex = 0; + auto classSymbol = symbol_cast(field->enclosingSymbol()); + for (auto member : classSymbol->scope()->symbols()) { + auto f = symbol_cast(member); + if (!f) continue; + if (f->isStatic()) continue; + if (member == field) break; + ++fieldIndex; + } + + auto baseExpressionResult = gen.expression(ast->baseExpression); + + auto loc = gen.getLocation(ast->unqualifiedId->firstSourceLocation()); + + auto resultType = gen.convertType(control()->add_pointer(ast->type)); + + auto op = gen.builder_.create( + loc, resultType, baseExpressionResult.value, fieldIndex); + + return {op}; + } + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); -#if false - auto baseExpressionResult = gen.expression(ast->baseExpression); - auto nestedNameSpecifierResult = gen.nestedNameSpecifier(ast->nestedNameSpecifier); - auto unqualifiedIdResult = gen(ast->unqualifiedId); -#endif - return {op}; } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index c738f2e6..e08f15c3 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -53,37 +53,12 @@ class FuncOpLowering : public OpConversionPattern { -> LogicalResult override { auto typeConverter = getTypeConverter(); - auto funcType = op.getFunctionType(); - - SmallVector argumentTypes; - for (auto argType : funcType.getInputs()) { - auto convertedType = typeConverter->convertType(argType); - if (!convertedType) { - return rewriter.notifyMatchFailure( - op, "failed to convert function argument type"); - } - argumentTypes.push_back(convertedType); + if (failed(convertFunctionTyype(op, rewriter))) { + return rewriter.notifyMatchFailure(op, "failed to convert function type"); } - SmallVector resultTypes; - for (auto resultType : funcType.getResults()) { - auto convertedType = typeConverter->convertType(resultType); - if (!convertedType) { - return rewriter.notifyMatchFailure( - op, "failed to convert function result type"); - } - - resultTypes.push_back(convertedType); - } - - const auto returnType = resultTypes.empty() - ? LLVM::LLVMVoidType::get(getContext()) - : resultTypes.front(); - - const auto isVarArg = funcType.getVariadic(); - - auto llvmFuncType = - LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg); + auto funcType = op.getFunctionType(); + auto llvmFuncType = typeConverter->convertType(funcType); auto func = rewriter.create(op.getLoc(), op.getSymName(), llvmFuncType); @@ -98,6 +73,29 @@ class FuncOpLowering : public OpConversionPattern { return success(); } + + auto convertFunctionTyype(cxx::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const + -> LogicalResult { + auto type = funcOp.getFunctionType(); + const auto &typeConverter = *getTypeConverter(); + + TypeConverter::SignatureConversion result(type.getInputs().size()); + SmallVector newResults; + if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter.convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + typeConverter, &result))) + return failure(); + + auto newType = cxx::FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), + newResults, type.getVariadic()); + + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); + + return success(); + } }; class GlobalOpLowering : public OpConversionPattern { @@ -334,6 +332,45 @@ class SubscriptOpLowering : public OpConversionPattern { const DataLayout &dataLayout_; }; +class MemberOpLowering : public OpConversionPattern { + public: + MemberOpLowering(const TypeConverter &typeConverter, + const DataLayout &dataLayout, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + dataLayout_(dataLayout) {} + + auto matchAndRewrite(cxx::MemberOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto pointerType = cast(op.getBase().getType()); + auto classType = dyn_cast(pointerType.getElementType()); + + auto resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert member result type"); + } + + auto elementType = typeConverter->convertType(classType); + + SmallVector indices; + indices.push_back(0); + indices.push_back(adaptor.getMemberIndex()); + + rewriter.replaceOpWithNewOp(op, resultType, elementType, + adaptor.getBase(), indices); + + return success(); + } + + private: + const DataLayout &dataLayout_; +}; + class BoolConstantOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1329,7 +1366,8 @@ void CxxToLLVMLoweringPass::runOnOperation() { DataLayout dataLayout{module}; patterns.insert(typeConverter, dataLayout, context); + SubscriptOpLowering, MemberOpLowering>(typeConverter, + dataLayout, context); // cast operations patterns.insert