@@ -33,6 +33,7 @@ heavy::ExternFunction build_match_params_impl;
3333heavy::ExternFunction build_overload_impl;
3434heavy::ExternFunction build_match_if_impl;
3535heavy::ExternFunction build_context_impl;
36+ heavy::ExternFunction build_match_op_impl;
3637}
3738
3839namespace {
@@ -56,9 +57,10 @@ std::optional<mlir::OpBuilder> getModuleBuilder(heavy::Context& C) {
5657namespace heavy ::nbdl_bind {
5758// Create a function and call the thunk with a new builder
5859// to insert into the function body.
59- // _num_params_ does not include the store parameter.
60- // _callback_ takes _num_params_ + 1 arguments which are the block arguments.
61- // (%build_match_params _name_ _num_params_ _callback_)
60+ // _num_store_params_ N
61+ // _callback_ takes _num_store_params_ + 1 arguments which are the block arguments
62+ // with formals like (store1 store2 ... storeN fn)
63+ // (%build_match_params _name_ _num_store_params_ _callback_)
6264void build_match_params_impl (Context& C, ValueRefs Args) {
6365 if (Args.size () != 3 )
6466 return C.RaiseError (" invalid arity" );
@@ -80,18 +82,21 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
8082 : int32_t (-1 );
8183
8284 if (NumParams < 0 )
83- return C.RaiseError (" expecting positive integer for num_params " );
85+ return C.RaiseError (" expecting positive integer for num_store_params " );
8486 if (!NameSym || Name.empty ())
8587 return C.RaiseError (" expecting function name (symbol literal)" );
8688
8789 mlir::Location MLoc = mlir::OpaqueLoc::get (Loc.getOpaqueEncoding (),
8890 Builder.getContext ());
8991
9092 // Create the function type.
91- llvm::SmallVector<mlir::Type, 8 > InputTypes{
92- Builder.getType <nbdl_gen::StoreType>()};
93+ llvm::SmallVector<mlir::Type, 8 > InputTypes;
9394 for (unsigned i = 0 ; i < static_cast <uint32_t >(NumParams); i++)
94- InputTypes.push_back (Builder.getType <nbdl_gen::OpaqueType>());
95+ InputTypes.push_back (Builder.getType <nbdl_gen::StoreType>());
96+
97+ // Push the visitor fn argument.
98+ InputTypes.push_back (Builder.getType <nbdl_gen::OpaqueType>());
99+
95100 mlir::FunctionType FT = Builder.getFunctionType (InputTypes,
96101 /* ResultTypes*/ {});
97102
@@ -116,10 +121,11 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
116121 mlir_helper::with_builder_impl (C, Builder, Thunk);
117122}
118123
119- // _callback_ takes a single argument is the block argument.
124+ // _callback_ takes a single block argument if specified .
120125// (%build_overload _loc_ _typename_ _callback_)
126+ // (%build_overload _loc_ _typename_)
121127void build_overload_impl (Context& C, ValueRefs Args) {
122- if (Args.size () != 3 )
128+ if (Args.size () != 2 && Args. size () != 3 )
123129 return C.RaiseError (" invalid arity" );
124130
125131 mlir::OpBuilder* Builder = mlir_helper::getCurrentBuilder (C);
@@ -128,15 +134,19 @@ void build_overload_impl(Context& C, ValueRefs Args) {
128134
129135 // Create the operation with an entry block with a single argument.
130136 heavy::SourceLocation Loc = Args[0 ].getSourceLocation ();
131- llvm::StringRef Typename = Args[1 ]. getStringRef () ;
132- heavy::Value Callback = Args[2 ];
137+ heavy::Value TypenameArg = Args[1 ];
138+ heavy::Value Callback = Args. size () == 3 ? Args [2 ] : nullptr ;
133139
134- if (Typename. empty ( ))
135- return C.RaiseError (" expecting nonempty string-like object for typename" );
140+ if (!isa<String>(TypenameArg) && !isa<Symbol>(TypenameArg ))
141+ return C.RaiseError (" expecting string-like object for typename" );
136142
143+ llvm::StringRef Typename = TypenameArg.getStringRef ();
137144 mlir::Location MLoc = mlir::OpaqueLoc::get (Loc.getOpaqueEncoding (),
138145 Builder->getContext ());
139146 auto OverloadOp = Builder->create <nbdl_gen::OverloadOp>(MLoc, Typename);
147+ if (!Callback)
148+ return C.Cont ();
149+
140150 OverloadOp.getBody ().emplaceBlock ();
141151 assert (OverloadOp.getBody ().hasOneBlock () && " expecting a single block" );
142152
@@ -293,6 +303,12 @@ void build_context_impl(Context& C, ValueRefs Args) {
293303 Builder = mlir::OpBuilder (ContextOp.getBody ());
294304 mlir_helper::with_builder_impl (C, Builder, Thunk);
295305}
306+
307+
308+ // (%build-match-op name storeval keyval
309+ void build_match_op_impl (Context& C, ValueRefs Args) {
310+ C.Cont ();
311+ }
296312} // end namespace heavy::nbdl_bind
297313
298314extern " C" {
@@ -316,6 +332,8 @@ void HEAVY_NBDL_INIT(heavy::Context& C) {
316332 = heavy::nbdl_bind::build_match_if_impl;
317333 heavy::nbdl_bind_var::build_context_impl
318334 = heavy::nbdl_bind::build_context_impl;
335+ heavy::nbdl_bind_var::build_match_op_impl
336+ = heavy::nbdl_bind::build_match_op_impl;
319337}
320338
321339void HEAVY_NBDL_LOAD_MODULE (heavy::Context& C) {
@@ -327,6 +345,7 @@ void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
327345 {" %build-overload" , heavy::nbdl_bind_var::build_overload_impl},
328346 {" %build-match-if" , heavy::nbdl_bind_var::build_match_if_impl},
329347 {" %build-context" , heavy::nbdl_bind_var::build_context_impl},
348+ {" %build-match-op" , heavy::nbdl_bind_var::build_match_op_impl},
330349 });
331350}
332351}
0 commit comments