Skip to content

Commit 09bd519

Browse files
committed
[Heavy] Allow default cases and noops in OverloadOp
1 parent 9d65abb commit 09bd519

File tree

4 files changed

+47
-28
lines changed

4 files changed

+47
-28
lines changed

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,14 @@ def Nbdl_MatchOp : Nbdl_Op<"match", [Terminator, NoTerminator]> {
178178

179179
def Nbdl_OverloadOp : Nbdl_Op<"overload", []> {
180180
let description = [{
181-
In the region of a MatchOp we specify a function overload of the
182-
given typename using OverloadOp.
183-
The body of the overload takes an argument that is the matched
184-
object and the rest of the arguments are the captures.
181+
Visit the body when the input matches a specific typename
182+
not including any qualifiers or when the type is an empty string.
183+
If the body is not empty, it takes an argument that is the matched
184+
object. An empty body is a noop.
185185
}];
186186
let arguments = (ins StrAttr:$type);
187187
let results = (outs);
188-
let regions = (region SizedRegion<1>:$body);
188+
let regions = (region AnyRegion:$body);
189189
}
190190

191191
def Nbdl_NoOp : Nbdl_Op<"noop", [Terminator]> {

heavy/lib/Nbdl.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ heavy::ExternFunction build_match_params_impl;
3333
heavy::ExternFunction build_overload_impl;
3434
heavy::ExternFunction build_match_if_impl;
3535
heavy::ExternFunction build_context_impl;
36+
heavy::ExternFunction build_match_op_impl;
3637
}
3738

3839
namespace {
@@ -56,9 +57,10 @@ std::optional<mlir::OpBuilder> getModuleBuilder(heavy::Context& C) {
5657
namespace heavy::nbdl_bind {
5758
// Create a function and call the thunk with a new builder
5859
// to insert into the function body.
59-
// _num_params_ does not include the store parameter.
60-
// _callback_ takes _num_params_ + 1 arguments which are the block arguments.
61-
// (%build_match_params _name_ _num_params_ _callback_)
60+
// _num_store_params_ N
61+
// _callback_ takes _num_store_params_ + 1 arguments which are the block arguments
62+
// with formals like (store1 store2 ... storeN fn)
63+
// (%build_match_params _name_ _num_store_params_ _callback_)
6264
void build_match_params_impl(Context& C, ValueRefs Args) {
6365
if (Args.size() != 3)
6466
return C.RaiseError("invalid arity");
@@ -80,18 +82,21 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
8082
: int32_t(-1);
8183

8284
if (NumParams < 0)
83-
return C.RaiseError("expecting positive integer for num_params");
85+
return C.RaiseError("expecting positive integer for num_store_params");
8486
if (!NameSym || Name.empty())
8587
return C.RaiseError("expecting function name (symbol literal)");
8688

8789
mlir::Location MLoc = mlir::OpaqueLoc::get(Loc.getOpaqueEncoding(),
8890
Builder.getContext());
8991

9092
// Create the function type.
91-
llvm::SmallVector<mlir::Type, 8> InputTypes{
92-
Builder.getType<nbdl_gen::StoreType>()};
93+
llvm::SmallVector<mlir::Type, 8> InputTypes;
9394
for (unsigned i = 0; i < static_cast<uint32_t>(NumParams); i++)
94-
InputTypes.push_back(Builder.getType<nbdl_gen::OpaqueType>());
95+
InputTypes.push_back(Builder.getType<nbdl_gen::StoreType>());
96+
97+
// Push the visitor fn argument.
98+
InputTypes.push_back(Builder.getType<nbdl_gen::OpaqueType>());
99+
95100
mlir::FunctionType FT = Builder.getFunctionType(InputTypes,
96101
/*ResultTypes*/{});
97102

@@ -116,10 +121,11 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
116121
mlir_helper::with_builder_impl(C, Builder, Thunk);
117122
}
118123

119-
// _callback_ takes a single argument is the block argument.
124+
// _callback_ takes a single block argument if specified.
120125
// (%build_overload _loc_ _typename_ _callback_)
126+
// (%build_overload _loc_ _typename_)
121127
void build_overload_impl(Context& C, ValueRefs Args) {
122-
if (Args.size() != 3)
128+
if (Args.size() != 2 && Args.size() != 3)
123129
return C.RaiseError("invalid arity");
124130

125131
mlir::OpBuilder* Builder = mlir_helper::getCurrentBuilder(C);
@@ -128,15 +134,19 @@ void build_overload_impl(Context& C, ValueRefs Args) {
128134

129135
// Create the operation with an entry block with a single argument.
130136
heavy::SourceLocation Loc = Args[0].getSourceLocation();
131-
llvm::StringRef Typename = Args[1].getStringRef();
132-
heavy::Value Callback = Args[2];
137+
heavy::Value TypenameArg = Args[1];
138+
heavy::Value Callback = Args.size() == 3 ? Args[2] : nullptr;
133139

134-
if (Typename.empty())
135-
return C.RaiseError("expecting nonempty string-like object for typename");
140+
if (!isa<String>(TypenameArg) && !isa<Symbol>(TypenameArg))
141+
return C.RaiseError("expecting string-like object for typename");
136142

143+
llvm::StringRef Typename = TypenameArg.getStringRef();
137144
mlir::Location MLoc = mlir::OpaqueLoc::get(Loc.getOpaqueEncoding(),
138145
Builder->getContext());
139146
auto OverloadOp = Builder->create<nbdl_gen::OverloadOp>(MLoc, Typename);
147+
if (!Callback)
148+
return C.Cont();
149+
140150
OverloadOp.getBody().emplaceBlock();
141151
assert(OverloadOp.getBody().hasOneBlock() && "expecting a single block");
142152

@@ -293,6 +303,12 @@ void build_context_impl(Context& C, ValueRefs Args) {
293303
Builder = mlir::OpBuilder(ContextOp.getBody());
294304
mlir_helper::with_builder_impl(C, Builder, Thunk);
295305
}
306+
307+
308+
// (%build-match-op name storeval keyval
309+
void build_match_op_impl(Context& C, ValueRefs Args) {
310+
C.Cont();
311+
}
296312
} // end namespace heavy::nbdl_bind
297313

298314
extern "C" {
@@ -316,6 +332,8 @@ void HEAVY_NBDL_INIT(heavy::Context& C) {
316332
= heavy::nbdl_bind::build_match_if_impl;
317333
heavy::nbdl_bind_var::build_context_impl
318334
= heavy::nbdl_bind::build_context_impl;
335+
heavy::nbdl_bind_var::build_match_op_impl
336+
= heavy::nbdl_bind::build_match_op_impl;
319337
}
320338

321339
void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
@@ -327,6 +345,7 @@ void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
327345
{"%build-overload", heavy::nbdl_bind_var::build_overload_impl},
328346
{"%build-match-if", heavy::nbdl_bind_var::build_match_if_impl},
329347
{"%build-context", heavy::nbdl_bind_var::build_context_impl},
348+
{"%build-match-op", heavy::nbdl_bind_var::build_match_op_impl},
330349
});
331350
}
332351
}

heavy/lib/Nbdl/NbdlWriter.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <heavy/Source.h>
1616
#include <heavy/Value.h>
1717
#include <llvm/ADT/ScopedHashTable.h>
18+
#include <llvm/ADT/ScopeExit.h>
1819
#include <llvm/ADT/Twine.h>
1920
#include <llvm/Support/Casting.h>
2021
#include <mlir/IR/Value.h>
@@ -268,21 +269,24 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
268269
Flush();
269270
heavy::SourceLocation PrevLoc = CurLoc;
270271
SetLoc(Op->getLoc());
272+
271273
if (CheckError()) return;
274+
auto ScopeExit = llvm::make_scope_exit([this, PrevLoc] {
275+
Flush();
276+
CurLoc = PrevLoc;
277+
});
278+
272279
if (isa<ApplyOp>(Op)) return Visit(cast<ApplyOp>(Op));
273280
else if (isa<GetOp>(Op)) return Visit(cast<GetOp>(Op));
274281
else if (isa<VisitOp>(Op)) return Visit(cast<VisitOp>(Op));
275282
else if (isa<MatchOp>(Op)) return Visit(cast<MatchOp>(Op));
276283
else if (isa<OverloadOp>(Op)) return Visit(cast<OverloadOp>(Op));
277284
else if (isa<MatchIfOp>(Op)) return Visit(cast<MatchIfOp>(Op));
278-
else if (isa<NoOp>(Op)) return Visit(cast<NoOp>(Op));
279285
else if (isa<FuncOp>(Op)) return Visit(cast<FuncOp>(Op));
280286
else if (isa<MemberNameOp>(Op)) return Visit(cast<MemberNameOp>(Op));
281287
else if (isa<ConstexprOp, LiteralOp>(Op)) return;
282288
else
283289
SetError("unhandled operation", Op);
284-
Flush();
285-
CurLoc = PrevLoc;
286290
}
287291

288292
void VisitRegion(mlir::Region& R) {
@@ -370,6 +374,8 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
370374

371375
void Visit(OverloadOp Op) {
372376
mlir::Region& Body = Op.getBody();
377+
if (Body.empty())
378+
return;
373379
OS << "[&]";
374380
// Write parameters.
375381
OS << '(';
@@ -423,10 +429,6 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
423429
// We could implement in MatchOp, but it is a very
424430
// unlikely use case.
425431
}
426-
427-
void Visit(NoOp Op) {
428-
// Do nothing.
429-
}
430432
};
431433

432434
class ContextWriter : public NbdlWriter<ContextWriter> {

heavy/test/Nbdl/match_params.scm

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@
8787
(operands some-fn input)
8888
(result-types !nbdl.opaque))))
8989
(create-op "nbdl.visit" (operands fn some-fn-result))))))))
90-
(%build-overload 'loc "auto&&"
91-
(lambda (anything)
92-
(create-op "nbdl.noop")))
90+
(%build-overload 'loc "")
9391
))))
9492

9593
; CHECK: #op{module @nbdl_gen_module {

0 commit comments

Comments
 (0)