Skip to content

Commit d12d8ed

Browse files
committed
[Heavy] Add ApplyActionOp
1 parent 5a40459 commit d12d8ed

File tree

7 files changed

+285
-79
lines changed

7 files changed

+285
-79
lines changed

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Nbdl_TypeBase<string name, string type_mnemonic,
4949
let assemblyFormat = "`<` $cpp_typename `>`";
5050
}
5151

52+
// DEPRECATED (I think)
5253
def Nbdl_Opaque : TypeDef<Nbdl_Dialect, "Opaque", []> {
5354
let mnemonic = "opaque";
5455
let description = [{
@@ -70,29 +71,12 @@ def Nbdl_Store : TypeDef<Nbdl_Dialect, "Store", []> {
7071
}];
7172
}
7273

73-
/* FIXME Types should only be based on corresponding concepts in nbdl.
74-
def Nbdl_Struct : Nbdl_TypeBase<"Struct", "struct"> {
75-
let summary = "C++ semiregular aggregate type";
76-
let description = [{
77-
Represent C++ struct types whose members
78-
are accessed by their name.
79-
}];
80-
}
81-
82-
def Nbdl_State : TypeDef<Nbdl_Dialect, "State", []> {
83-
let mnemonic = "state";
84-
let description = [{
85-
A type satisfying the nbdl::State concept
86-
}];
87-
}
88-
8974
def Nbdl_Variant : TypeDef<Nbdl_Dialect, "Variant", []> {
9075
let mnemonic = "variant";
9176
let description = [{
9277
A type satisfying the nbdl::Variant concept
9378
}];
9479
}
95-
*/
9680

9781
def Nbdl_Tag : TypeDef<Nbdl_Dialect, "Tag", []> {
9882
let mnemonic = "tag";
@@ -112,10 +96,10 @@ def Nbdl_Symbol : TypeDef<Nbdl_Dialect, "Symbol"> {
11296

11397
// Keys are optional, but we are using variadic arguments
11498
// so we resort to using this sum type.
115-
def Nbdl_KeyType : AnyTypeOf<[Nbdl_Opaque, Nbdl_Tag,
99+
def Nbdl_KeyType : AnyTypeOf<[Nbdl_Store, Nbdl_Tag,
116100
Nbdl_Symbol, Nbdl_Unit]>;
117101

118-
def Nbdl_Type : AnyTypeOf<[Nbdl_Unit, Nbdl_Empty, Nbdl_Opaque,
102+
def Nbdl_Type : AnyTypeOf<[Nbdl_Unit, Nbdl_Empty,
119103
Nbdl_Store, Nbdl_Tag]>;
120104

121105
class Nbdl_Op<string mnemonic, list<Trait> traits = []> :
@@ -141,9 +125,12 @@ def Nbdl_GetOp : Nbdl_Op<"get", []> {
141125
let results = (outs Nbdl_Type:$result);
142126
}
143127

144-
def Nbdl_VisitOp : Nbdl_Op<"visit", [Terminator]> {
128+
def Nbdl_ResolveOp : Nbdl_Op<"resolve", [Terminator]> {
145129
let description = [{
146-
Visit matched results with a function object.
130+
Finalize resolving stores by invoking the callback
131+
provided in the definition of a visit function.
132+
This is similar to nbdl.visit except it does not return
133+
a value and terminates the block.
147134
}];
148135

149136
let arguments = (ins Nbdl_Type:$fn, Variadic<Nbdl_Type>:$args);
@@ -202,7 +189,7 @@ def Nbdl_MatchIfOp : Nbdl_Op<"match_if", [Terminator]> {
202189
like an if/else statement.
203190
}];
204191

205-
let arguments = (ins Nbdl_Type:$input, Nbdl_Opaque:$pred);
192+
let arguments = (ins Nbdl_Type:$input, Nbdl_Type:$pred);
206193
let regions = (region SizedRegion<1>:$thenRegion,
207194
SizedRegion<1>:$elseRegion);
208195
}
@@ -322,13 +309,24 @@ def Nbdl_UnitOp : Nbdl_Op<"unit", []> {
322309
let results = (outs Nbdl_Unit:$result);
323310
}
324311

325-
def Nbdl_ApplyOp : Nbdl_Op<"apply", []> {
312+
def Nbdl_VisitOp : Nbdl_Op<"visit", []> {
326313
let description = [{
327-
Apply a function.
314+
Call a function with resolved stores as arguments
315+
returning a result. This is analogous to std::visit
316+
except only one function is generated.
328317
}];
329-
let arguments = (ins Nbdl_Opaque:$fn,
318+
let arguments = (ins Nbdl_Type:$fn,
330319
Variadic<Nbdl_Type>:$args);
331-
let results = (outs Nbdl_Opaque:$result);
320+
let results = (outs Nbdl_Type:$result);
321+
}
322+
323+
def Nbdl_ApplyActionOp : Nbdl_Op<"apply_action", [Terminator]> {
324+
let description = [{
325+
Create a call to nbdl::apply_action which is dispatched
326+
by the type of the LHS in C++.
327+
}];
328+
let arguments = (ins Nbdl_Type:$lhs, Variadic<Nbdl_Type>:$args);
329+
let results = (outs);
332330
}
333331

334332
def Nbdl_ConsumerOp : Nbdl_Op<"consumer", []> {
@@ -345,4 +343,38 @@ def Nbdl_ConsumerOp : Nbdl_Op<"consumer", []> {
345343
let results = (outs Nbdl_Type:$result);
346344
}
347345

346+
def Nbdl_ConstOp : Nbdl_Op<"const", []> {
347+
let description = [{
348+
Cast an object to const likely using std::as_const.
349+
350+
Notes:
351+
We don't modify the type in IR land but rely on C++
352+
to manage underlying types. This is analogous to an
353+
immutable borrow I think.
354+
We generally only allow mutation to occur to the operands
355+
of nbdl.apply_action or nbdl.resolve.
356+
}];
357+
358+
let arguments = (ins Nbdl_Type:$arg);
359+
let results = (outs Nbdl_Type:$result);
360+
}
361+
362+
def Nbdl_ScopeOp : Nbdl_Op<"scope", []> {
363+
let description = [{
364+
Terminating operations in a function may mutate a store
365+
thus invalidating any matched store with that object as
366+
a node in its root path. We can nest intermediate operations
367+
in a nbdl.scope region to prevent subsequent uses of such
368+
objects.
369+
}];
370+
371+
let arguments = (ins);
372+
let results = (outs);
373+
let regions = (region SizedRegion<1>:$body);
374+
375+
let builders = [
376+
OpBuilder<(ins "std::unique_ptr<::mlir::Region>":$body)>
377+
];
378+
}
379+
348380
#endif

heavy/lib/Nbdl.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ translate_cpp(heavy::LexerWriterFnRef FnRef, mlir::Operation* Op);
2929
namespace heavy::nbdl_bind_var {
3030
heavy::ContextLocal current_nbdl_module;
3131
heavy::ExternFunction translate_cpp;
32+
heavy::ExternFunction close_previous_scope;
3233
heavy::ExternFunction build_match_params_impl;
3334
heavy::ExternFunction build_overload_impl;
3435
heavy::ExternFunction build_match_if_impl;
@@ -95,7 +96,7 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
9596
InputTypes.push_back(Builder.getType<nbdl_gen::StoreType>());
9697

9798
// Push the visitor fn argument.
98-
InputTypes.push_back(Builder.getType<nbdl_gen::OpaqueType>());
99+
InputTypes.push_back(Builder.getType<nbdl_gen::StoreType>());
99100

100101
mlir::FunctionType FT = Builder.getFunctionType(InputTypes,
101102
/*ResultTypes*/{});
@@ -113,6 +114,9 @@ void build_match_params_impl(Context& C, ValueRefs Args) {
113114
BlockArgs.push_back(V);
114115
}
115116

117+
C.PushCont([FuncOp](Context& C, ValueRefs) mutable {
118+
C.Cont(FuncOp.getOperation());
119+
}, CaptureList{});
116120
C.Apply(Callback, BlockArgs);
117121
}, CaptureList{Callback});
118122

@@ -156,7 +160,7 @@ void build_overload_impl(Context& C, ValueRefs Args) {
156160
mlir::Region& Body = OverloadOp.getBody();
157161
mlir::Location MLoc = OverloadOp.getLoc();
158162
mlir::Type Type = mlir::OpBuilder(OverloadOp)
159-
.getType<nbdl_gen::OpaqueType>();
163+
.getType<nbdl_gen::StoreType>();
160164
mlir::BlockArgument BlockArg = Body.addArgument(Type, MLoc);
161165
heavy::Value V = C.CreateAny(mlir::Value(BlockArg));
162166
C.Apply(Callback, V);
@@ -203,6 +207,8 @@ void build_match_if_impl(Context& C, ValueRefs Args) {
203207

204208
// Translate a nbdl dialect operation to C++.
205209
// (translate-cpp op port)
210+
// The parameter `op` may be an mlir::Operation* or a StringLike
211+
// which will be used to look up the name in the module.
206212
// Currently the "port" has to be a tagged llvm::raw_ostream.
207213
void translate_cpp(Context& C, ValueRefs Args) {
208214
if (Args.size() != 2 && Args.size() != 1)
@@ -249,6 +255,35 @@ void translate_cpp(Context& C, ValueRefs Args) {
249255
C.Cont();
250256
}
251257

258+
// If the current block has a terminator, wrap the
259+
// entire block in a nbdl.scope. This supports the
260+
// convention that only terminators may perform an
261+
// operation that may invalidate child stores.
262+
void close_previous_scope(Context& C, ValueRefs Args) {
263+
if (Args.size() != 0)
264+
return C.RaiseError("invalid arity");
265+
mlir::OpBuilder* Builder = mlir_helper::getCurrentBuilder(C);
266+
if (!Builder)
267+
return; // error is already raised by getCurrentBuilder
268+
mlir::Block* Block = Builder->getBlock();
269+
assert(Block && isa<mlir::func::FuncOp>(Block->getParentOp())
270+
&& "expecting func insertion point");
271+
if (Block->empty() || !Block->back().hasTrait<mlir::OpTrait::IsTerminator>())
272+
return C.Cont();
273+
274+
mlir::Location Loc = Block->back().getLoc();
275+
276+
// Create new Region for ScopeOp.
277+
auto ScopeBody = std::make_unique<mlir::Region>();
278+
mlir::Block& NewBlock = ScopeBody->emplaceBlock();
279+
while (!Block->empty())
280+
Block->front().moveBefore(&NewBlock, NewBlock.end());
281+
mlir::Operation* ScopeOp = Builder->create<nbdl_gen::ScopeOp>(Loc, std::move(ScopeBody));
282+
Builder->setInsertionPointAfter(ScopeOp);
283+
284+
C.Cont();
285+
}
286+
252287
void build_context_impl(Context& C, ValueRefs Args) {
253288
if (Args.size() != 3)
254289
return C.RaiseError("invalid arity");
@@ -282,9 +317,9 @@ void build_context_impl(Context& C, ValueRefs Args) {
282317
mlir::Block& EntryBlock = ContextOp.getBody().emplaceBlock();
283318

284319
// Create the arguments.
285-
auto OpaqueTy = Builder.getType<nbdl_gen::OpaqueType>();
320+
auto StoreTy = Builder.getType<nbdl_gen::StoreType>();
286321
for (int32_t i = 0; i < NumParams; i++)
287-
EntryBlock.addArgument(OpaqueTy, MLoc);
322+
EntryBlock.addArgument(StoreTy, MLoc);
288323

289324
heavy::Value Thunk = C.CreateLambda([ContextOp](Context& C,
290325
ValueRefs) mutable {
@@ -324,6 +359,8 @@ void HEAVY_NBDL_INIT(heavy::Context& C) {
324359

325360
heavy::nbdl_bind_var::current_nbdl_module.init(C, ModuleOp.getOperation());
326361
heavy::nbdl_bind_var::translate_cpp = heavy::nbdl_bind::translate_cpp;
362+
heavy::nbdl_bind_var::close_previous_scope
363+
= heavy::nbdl_bind::close_previous_scope;
327364
heavy::nbdl_bind_var::build_match_params_impl
328365
= heavy::nbdl_bind::build_match_params_impl;
329366
heavy::nbdl_bind_var::build_overload_impl
@@ -341,6 +378,8 @@ void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
341378
heavy::initModuleNames(C, HEAVY_NBDL_LIB_STR, {
342379
{"current-nbdl-module", heavy::nbdl_bind_var::current_nbdl_module.get(C)},
343380
{"translate-cpp", heavy::nbdl_bind_var::translate_cpp},
381+
{"close-previous-scope",
382+
heavy::nbdl_bind_var::close_previous_scope},
344383
{"%build-match-params", heavy::nbdl_bind_var::build_match_params_impl},
345384
{"%build-overload", heavy::nbdl_bind_var::build_overload_impl},
346385
{"%build-match-if", heavy::nbdl_bind_var::build_match_if_impl},

heavy/lib/Nbdl/NbdlDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ void nbdl_gen::NbdlDialect::initialize() {
4141
>();
4242

4343
}
44+
45+
void nbdl_gen::ScopeOp::build(::mlir::OpBuilder&,
46+
::mlir::OperationState& odsState,
47+
std::unique_ptr<::mlir::Region> body) {
48+
odsState.addRegion(std::move(body));
49+
}

0 commit comments

Comments
 (0)