@@ -29,6 +29,7 @@ translate_cpp(heavy::LexerWriterFnRef FnRef, mlir::Operation* Op);
2929namespace heavy ::nbdl_bind_var {
3030heavy::ContextLocal current_nbdl_module;
3131heavy::ExternFunction translate_cpp;
32+ heavy::ExternFunction close_previous_scope;
3233heavy::ExternFunction build_match_params_impl;
3334heavy::ExternFunction build_overload_impl;
3435heavy::ExternFunction build_match_if_impl;
@@ -95,7 +96,7 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
9596 InputTypes.push_back (Builder.getType <nbdl_gen::StoreType>());
9697
9798 // Push the visitor fn argument.
98- InputTypes.push_back (Builder.getType <nbdl_gen::OpaqueType >());
99+ InputTypes.push_back (Builder.getType <nbdl_gen::StoreType >());
99100
100101 mlir::FunctionType FT = Builder.getFunctionType (InputTypes,
101102 /* ResultTypes*/ {});
@@ -113,6 +114,9 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
113114 BlockArgs.push_back (V);
114115 }
115116
117+ C.PushCont ([FuncOp](Context& C, ValueRefs) mutable {
118+ C.Cont (FuncOp.getOperation ());
119+ }, CaptureList{});
116120 C.Apply (Callback, BlockArgs);
117121 }, CaptureList{Callback});
118122
@@ -156,7 +160,7 @@ void build_overload_impl(Context& C, ValueRefs Args) {
156160 mlir::Region& Body = OverloadOp.getBody ();
157161 mlir::Location MLoc = OverloadOp.getLoc ();
158162 mlir::Type Type = mlir::OpBuilder (OverloadOp)
159- .getType <nbdl_gen::OpaqueType >();
163+ .getType <nbdl_gen::StoreType >();
160164 mlir::BlockArgument BlockArg = Body.addArgument (Type, MLoc);
161165 heavy::Value V = C.CreateAny (mlir::Value (BlockArg));
162166 C.Apply (Callback, V);
@@ -203,6 +207,8 @@ void build_match_if_impl(Context& C, ValueRefs Args) {
203207
204208// Translate a nbdl dialect operation to C++.
205209// (translate-cpp op port)
210+ // The parameter `op` may be an mlir::Operation* or a StringLike
211+ // which will be used to look up the name in the module.
206212// Currently the "port" has to be a tagged llvm::raw_ostream.
207213void translate_cpp (Context& C, ValueRefs Args) {
208214 if (Args.size () != 2 && Args.size () != 1 )
@@ -249,6 +255,35 @@ void translate_cpp(Context& C, ValueRefs Args) {
249255 C.Cont ();
250256}
251257
258+ // If the current block has a terminator, wrap the
259+ // entire block in a nbdl.scope. This supports the
260+ // convention that only terminators may perform an
261+ // operation that may invalidate child stores.
262+ void close_previous_scope (Context& C, ValueRefs Args) {
263+ if (Args.size () != 0 )
264+ return C.RaiseError (" invalid arity" );
265+ mlir::OpBuilder* Builder = mlir_helper::getCurrentBuilder (C);
266+ if (!Builder)
267+ return ; // error is already raised by getCurrentBuilder
268+ mlir::Block* Block = Builder->getBlock ();
269+ assert (Block && isa<mlir::func::FuncOp>(Block->getParentOp ())
270+ && " expecting func insertion point" );
271+ if (Block->empty () || !Block->back ().hasTrait <mlir::OpTrait::IsTerminator>())
272+ return C.Cont ();
273+
274+ mlir::Location Loc = Block->back ().getLoc ();
275+
276+ // Create new Region for ScopeOp.
277+ auto ScopeBody = std::make_unique<mlir::Region>();
278+ mlir::Block& NewBlock = ScopeBody->emplaceBlock ();
279+ while (!Block->empty ())
280+ Block->front ().moveBefore (&NewBlock, NewBlock.end ());
281+ mlir::Operation* ScopeOp = Builder->create <nbdl_gen::ScopeOp>(Loc, std::move (ScopeBody));
282+ Builder->setInsertionPointAfter (ScopeOp);
283+
284+ C.Cont ();
285+ }
286+
252287void build_context_impl (Context& C, ValueRefs Args) {
253288 if (Args.size () != 3 )
254289 return C.RaiseError (" invalid arity" );
@@ -282,9 +317,9 @@ void build_context_impl(Context& C, ValueRefs Args) {
282317 mlir::Block& EntryBlock = ContextOp.getBody ().emplaceBlock ();
283318
284319 // Create the arguments.
285- auto OpaqueTy = Builder.getType <nbdl_gen::OpaqueType >();
320+ auto StoreTy = Builder.getType <nbdl_gen::StoreType >();
286321 for (int32_t i = 0 ; i < NumParams; i++)
287- EntryBlock.addArgument (OpaqueTy , MLoc);
322+ EntryBlock.addArgument (StoreTy , MLoc);
288323
289324 heavy::Value Thunk = C.CreateLambda ([ContextOp](Context& C,
290325 ValueRefs) mutable {
@@ -324,6 +359,8 @@ void HEAVY_NBDL_INIT(heavy::Context& C) {
324359
325360 heavy::nbdl_bind_var::current_nbdl_module.init (C, ModuleOp.getOperation ());
326361 heavy::nbdl_bind_var::translate_cpp = heavy::nbdl_bind::translate_cpp;
362+ heavy::nbdl_bind_var::close_previous_scope
363+ = heavy::nbdl_bind::close_previous_scope;
327364 heavy::nbdl_bind_var::build_match_params_impl
328365 = heavy::nbdl_bind::build_match_params_impl;
329366 heavy::nbdl_bind_var::build_overload_impl
@@ -341,6 +378,8 @@ void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
341378 heavy::initModuleNames (C, HEAVY_NBDL_LIB_STR, {
342379 {" current-nbdl-module" , heavy::nbdl_bind_var::current_nbdl_module.get (C)},
343380 {" translate-cpp" , heavy::nbdl_bind_var::translate_cpp},
381+ {" close-previous-scope" ,
382+ heavy::nbdl_bind_var::close_previous_scope},
344383 {" %build-match-params" , heavy::nbdl_bind_var::build_match_params_impl},
345384 {" %build-overload" , heavy::nbdl_bind_var::build_overload_impl},
346385 {" %build-match-if" , heavy::nbdl_bind_var::build_match_if_impl},
0 commit comments