Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ def Cxx_StoreOp : Cxx_Op<"store"> {
let hasVerifier = 1;
}

def Cxx_SubscriptOp : Cxx_Op<"subscript"> {
let arguments = (ins Cxx_PointerType:$base, AnyType:$index);

let results = (outs Cxx_PointerType:$result);
}

def Cxx_BoolConstantOp : Cxx_Op<"constant.bool", [
Pure
]> {
Expand Down
24 changes: 24 additions & 0 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace cxx {
struct Codegen::DeclarationVisitor {
Codegen& gen;

void allocateLocals(ScopedSymbol* block);

auto operator()(SimpleDeclarationAST* ast) -> DeclarationResult;
auto operator()(AsmDeclarationAST* ast) -> DeclarationResult;
auto operator()(NamespaceAliasDefinitionAST* ast) -> DeclarationResult;
Expand Down Expand Up @@ -140,6 +142,26 @@ auto Codegen::lambdaSpecifier(LambdaSpecifierAST* ast)
return {};
}

void Codegen::DeclarationVisitor::allocateLocals(ScopedSymbol* block) {
for (auto symbol : block->scope()->symbols()) {
if (auto nestedBlock = symbol_cast<BlockSymbol>(symbol)) {
allocateLocals(nestedBlock);
continue;
}

if (auto var = symbol_cast<VariableSymbol>(symbol)) {
if (var->isStatic()) continue;

auto local = gen.findOrCreateLocal(var);
if (!local.has_value()) {
gen.unit_->error(var->location(),
std::format("cannot allocate local variable '{}'",
to_string(var->name())));
}
}
}
}

auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast)
-> DeclarationResult {
if (!gen.function_) {
Expand Down Expand Up @@ -377,6 +399,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
}
}

allocateLocals(functionSymbol);

// generate code for the function body
auto functionBodyResult = gen.functionBody(ast->functionBody);

Expand Down
12 changes: 7 additions & 5 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,15 @@ auto Codegen::ExpressionVisitor::operator()(VaArgExpressionAST* ast)

auto Codegen::ExpressionVisitor::operator()(SubscriptExpressionAST* ast)
-> ExpressionResult {
auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

#if false
auto baseExpressionResult = gen.expression(ast->baseExpression);
auto indexExpressionResult = gen.expression(ast->indexExpression);
#endif

auto loc = gen.getLocation(ast->firstSourceLocation());

auto resultType = gen.convertType(control()->add_pointer(ast->type));

auto op = gen.builder_.create<mlir::cxx::SubscriptOp>(
loc, resultType, baseExpressionResult.value, indexExpressionResult.value);

return {op};
}
Expand Down
49 changes: 47 additions & 2 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,43 @@ class BoolConstantOpLowering : public OpConversionPattern<cxx::BoolConstantOp> {
}
};

class SubscriptOpLowering : public OpConversionPattern<cxx::SubscriptOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::SubscriptOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto ptrType = dyn_cast_or_null<cxx::PointerType>(op.getBase().getType());

if (!ptrType) {
return rewriter.notifyMatchFailure(
op, "failed to convert subscript operation type");
}

auto arrayType = dyn_cast_or_null<cxx::ArrayType>(ptrType.getElementType());
if (!arrayType) {
return rewriter.notifyMatchFailure(
op, "expected base type of subscript to be an array type");
}

SmallVector<Value> indices;

indices.push_back(adaptor.getIndex());

auto resultType = LLVM::LLVMPointerType::get(context);
auto elementType = typeConverter->convertType(ptrType.getElementType());

rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, resultType, elementType,
adaptor.getBase(), indices);

return success();
}
};

class IntConstantOpLowering : public OpConversionPattern<cxx::IntConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -1093,6 +1130,13 @@ void CxxToLLVMLoweringPass::runOnOperation() {
return LLVM::LLVMPointerType::get(type.getContext());
});

typeConverter.addConversion([&](cxx::ArrayType type) -> Type {
auto elementType = typeConverter.convertType(type.getElementType());
auto size = type.getSize();

return LLVM::LLVMArrayType::get(elementType, size);
});

DenseMap<cxx::ClassType, Type> convertedClassTypes;
typeConverter.addConversion([&](cxx::ClassType type) -> Type {
if (auto it = convertedClassTypes.find(type);
Expand Down Expand Up @@ -1140,8 +1184,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
typeConverter, context);

// memory operations
patterns.insert<AllocaOpLowering, LoadOpLowering, StoreOpLowering>(
typeConverter, context);
patterns.insert<AllocaOpLowering, LoadOpLowering, StoreOpLowering,
SubscriptOpLowering>(typeConverter, context);

// cast operations
patterns
Expand Down Expand Up @@ -1207,6 +1251,7 @@ auto cxx::lowerToMLIR(mlir::ModuleOp module) -> mlir::LogicalResult {

pm.addPass(cxx::createLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());

if (failed(pm.run(module))) {
module.print(llvm::errs());
Expand Down
Loading