diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 048d1208..290384c0 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -212,6 +212,14 @@ def Cxx_MulIOp : Cxx_Op<"muli"> { let results = (outs Cxx_IntegerType:$result); } +def Cxx_LabelOp : Cxx_Op<"label"> { + let arguments = (ins StringProp:$name); +} + +def Cxx_GotoOp : Cxx_Op<"goto"> { + let arguments = (ins StringProp:$label); +} + def CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> { let arguments = (ins Cxx_BoolType:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands); diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index 6d8993c1..7b910861 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -41,7 +41,9 @@ auto Codegen::control() const -> Control* { return unit_->control(); } auto Codegen::currentBlockMightHaveTerminator() -> bool { auto block = builder_.getInsertionBlock(); - if (!block) return true; + if (!block) { + cxx_runtime_error("current block is null"); + } return block->mightHaveTerminator(); } diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index 502372a4..a4df8cac 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -22,9 +22,12 @@ #include #include +#include #include #include #include + +// mlir #include #include #include diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index 0b1492b4..8e30fb8f 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -369,7 +369,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast) exitValue = gen.builder_.create(exitValueLoc, ptrType); } - // restore state + // function state std::swap(gen.function_, func); std::swap(gen.exitBlock_, exitBlock); std::swap(gen.exitValue_, exitValue); diff --git a/src/mlir/cxx/mlir/codegen_statements.cc b/src/mlir/cxx/mlir/codegen_statements.cc index fb765315..f40f2f28 100644 --- a/src/mlir/cxx/mlir/codegen_statements.cc +++ b/src/mlir/cxx/mlir/codegen_statements.cc @@ -23,6 +23,7 @@ // cxx #include #include +#include // mlir #include @@ -91,7 +92,12 @@ auto Codegen::handler(HandlerAST* ast) -> HandlerResult { } void Codegen::StatementVisitor::operator()(LabeledStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + auto targetBlock = gen.newBlock(); + gen.branch(gen.getLocation(ast->firstSourceLocation()), targetBlock); + gen.builder_.setInsertionPointToEnd(targetBlock); + + gen.builder_.create( + gen.getLocation(ast->firstSourceLocation()), ast->identifier->name()); } void Codegen::StatementVisitor::operator()(CaseStatementAST* ast) { @@ -291,7 +297,19 @@ void Codegen::StatementVisitor::operator()(CoroutineReturnStatementAST* ast) { } void Codegen::StatementVisitor::operator()(GotoStatementAST* ast) { - (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + if (ast->isIndirect) { + (void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind())); + return; + } + + gen.builder_.create( + gen.getLocation(ast->firstSourceLocation()), mlir::ValueRange{}, + ast->identifier->name()); + + auto nextBlock = gen.newBlock(); + gen.branch(gen.getLocation(ast->firstSourceLocation()), nextBlock); + + gen.builder_.setInsertionPointToEnd(nextBlock); } void Codegen::StatementVisitor::operator()(DeclarationStatementAST* ast) { diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index ca21c7cf..5d3d8e86 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -459,6 +459,53 @@ class CondBranchOpLowering : public OpConversionPattern { } }; +struct LabelConverter { + DenseMap labels; +}; + +class GotoOpLowering : public OpConversionPattern { + public: + GotoOpLowering(const TypeConverter &typeConverter, + const LabelConverter &labelConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + labelConverter_(labelConverter) {} + + auto matchAndRewrite(cxx::GotoOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto context = getContext(); + + auto targetBlock = labelConverter_.labels.lookup(op.getLabel()); + + if (auto nextOp = ++op->getIterator(); isa(&*nextOp)) { + rewriter.eraseOp(&*nextOp); + } + + rewriter.replaceOpWithNewOp(op, targetBlock); + + return success(); + } + + private: + const LabelConverter &labelConverter_; +}; + +class LabelOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::LabelOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + -> LogicalResult override { + auto context = getContext(); + + rewriter.eraseOp(op); + + return success(); + } +}; + class CxxToLLVMLoweringPass : public PassWrapper> { public: @@ -553,6 +600,17 @@ void CxxToLLVMLoweringPass::runOnOperation() { SubIOpLowering, MulIOpLowering, CondBranchOpLowering>( typeConverter, context); + LabelConverter labelConverter; + + module.walk([&](Operation *op) { + if (auto labelOp = dyn_cast(op)) { + labelConverter.labels[labelOp.getName()] = labelOp->getBlock(); + } + }); + + patterns.insert(typeConverter, context); + patterns.insert(typeConverter, labelConverter, context); + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);