Skip to content

Commit 31a40c1

Browse files
committed
[Heavy] Add %build-context
1 parent 951310a commit 31a40c1

File tree

3 files changed

+92
-14
lines changed

3 files changed

+92
-14
lines changed

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def Nbdl_Dialect : Dialect {
1818
let useDefaultTypePrinterParser = 1;
1919
}
2020

21-
// FIXME Maybe we should just use MLIR's NoneType for this.
2221
def Nbdl_Empty : TypeDef<Nbdl_Dialect, "Empty", []> {
2322
let mnemonic = "empty";
2423
let description = [{
@@ -114,7 +113,7 @@ def Nbdl_Symbol : TypeDef<Nbdl_Dialect, "Symbol"> {
114113
// Keys are optional, but we are using variadic arguments
115114
// so we resort to using this sum type.
116115
def Nbdl_KeyType : AnyTypeOf<[Nbdl_Opaque, Nbdl_Tag,
117-
Nbdl_Symbol, NoneType]>;
116+
Nbdl_Symbol, Nbdl_Empty]>;
118117

119118
def Nbdl_Type : AnyTypeOf<[Nbdl_Void, Nbdl_Empty, Nbdl_Opaque,
120119
Nbdl_Store, Nbdl_Tag]>;
@@ -236,7 +235,7 @@ def Nbdl_StoreComposeOp : Nbdl_Op<"store_compose", []> {
236235
}];
237236

238237
let arguments = (ins
239-
TypeAttr:$key,
238+
Nbdl_KeyType:$key,
240239
Nbdl_Type:$lhs,
241240
Nbdl_Type:$rhs);
242241
let results = (outs Nbdl_Type:$result);
@@ -250,7 +249,7 @@ def Nbdl_VariantOp : Nbdl_Op<"variant", []> {
250249
let results = (outs Nbdl_Type:$result);
251250
}
252251

253-
def Nbdl_CreateStoreOp : Nbdl_Op<"create_store", [Symbol, IsolatedFromAbove]> {
252+
def Nbdl_ContextOp : Nbdl_Op<"context", [Symbol, IsolatedFromAbove]> {
254253
let description = [{
255254
Define a store object, and expose its interface (type) to the user.
256255

@@ -270,9 +269,16 @@ def Nbdl_CreateStoreOp : Nbdl_Op<"create_store", [Symbol, IsolatedFromAbove]> {
270269
}];
271270

272271
let regions = (region AnyRegion:$body);
273-
let arguments = (ins
274-
FlatSymbolRefAttr:$name,
275-
OptionalAttr<DictArrayAttr>:$cnstrArgs);
272+
let arguments = (ins StrAttr:$sym_name);
273+
let results = (outs);
274+
}
275+
276+
def Nbdl_ContOp : Nbdl_Op<"cont", [Terminator]> {
277+
let description = [{
278+
The terminator we need.
279+
}];
280+
281+
let arguments = (ins Nbdl_Type:$arg);
276282
let results = (outs);
277283
}
278284

@@ -289,6 +295,15 @@ def Nbdl_StoreOp : Nbdl_Op<"store", []> {
289295
let results = (outs Nbdl_Type:$result);
290296
}
291297

298+
def Nbdl_EmptyOp : Nbdl_Op<"empty", []> {
299+
let description = [{
300+
An operation that produces an object of the empty type.
301+
}];
302+
303+
let arguments = (ins);
304+
let results = (outs Nbdl_Empty:$result);
305+
}
306+
292307
def Nbdl_ApplyOp : Nbdl_Op<"apply", []> {
293308
let description = [{
294309
Apply a function.

heavy/lib/Nbdl.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ heavy::ExternFunction translate_cpp;
3232
heavy::ExternFunction build_match_params_impl;
3333
heavy::ExternFunction build_overload_impl;
3434
heavy::ExternFunction build_match_if_impl;
35+
heavy::ExternFunction build_context_impl;
3536
}
3637

3738
namespace {
@@ -226,7 +227,63 @@ void translate_cpp(Context& C, ValueRefs Args) {
226227
}
227228
C.Cont();
228229
}
230+
231+
void build_context_impl(Context& C, ValueRefs Args) {
232+
if (Args.size() != 3)
233+
return C.RaiseError("invalid arity");
234+
235+
std::optional<mlir::OpBuilder> BuilderOpt = getModuleBuilder(C);
236+
if (!BuilderOpt)
237+
return;
238+
mlir::OpBuilder Builder = BuilderOpt.value();
239+
240+
// This part is exactly like build_match_params_impl
241+
heavy::SourceLocation Loc = Args[0].getSourceLocation();
242+
// Require a heavy::Symbol so it has a source location.
243+
heavy::Symbol* NameSym = dyn_cast<heavy::Symbol>(Args[0]);
244+
llvm::StringRef Name = NameSym->getStringRef();
245+
heavy::Value NumParamsVal = Args[1];
246+
heavy::Value Callback = Args[2];
247+
248+
int32_t NumParams = isa<heavy::Int>(NumParamsVal)
249+
? int32_t(cast<heavy::Int>(NumParamsVal))
250+
: int32_t(-1);
251+
252+
if (NumParams < 0)
253+
return C.RaiseError("expecting positive integer for num_params");
254+
if (!NameSym || Name.empty())
255+
return C.RaiseError("expecting function name (symbol literal)");
256+
257+
mlir::Location MLoc = mlir::OpaqueLoc::get(Loc.getOpaqueEncoding(),
258+
Builder.getContext());
259+
// Create the ContextOp.
260+
auto ContextOp = Builder.create<nbdl_gen::ContextOp>(MLoc, Name);
261+
mlir::Block& EntryBlock = ContextOp.getBody().emplaceBlock();
262+
263+
// Create the arguments.
264+
auto OpaqueTy = Builder.getType<nbdl_gen::OpaqueType>();
265+
for (int32_t i = 0; i < NumParams; i++)
266+
EntryBlock.addArgument(OpaqueTy, MLoc);
267+
268+
heavy::Value Thunk = C.CreateLambda([ContextOp](Context& C,
269+
ValueRefs) mutable {
270+
heavy::Value Callback = C.getCapture(0);
271+
llvm::SmallVector<heavy::Value, 8> BlockArgs;
272+
assert(!ContextOp.getBody().empty() && "should have entry block");
273+
for (mlir::Value MVal : ContextOp.getBody().getArguments()) {
274+
heavy::Value V = mlir_helper::createTagged(C,
275+
mlir_helper::kind::mlir_value, MVal);
276+
BlockArgs.push_back(V);
277+
}
278+
279+
C.Apply(Callback, BlockArgs);
280+
}, CaptureList{Callback});
281+
282+
// Call the thunk with a Builder at the entry point.
283+
Builder = mlir::OpBuilder(ContextOp.getBody());
284+
mlir_helper::with_builder_impl(C, Builder, Thunk);
229285
}
286+
} // end namespace heavy::nbdl_bind
230287

231288
extern "C" {
232289
// initialize the module for run-time independent of the compiler
@@ -247,6 +304,8 @@ void HEAVY_NBDL_INIT(heavy::Context& C) {
247304
= heavy::nbdl_bind::build_overload_impl;
248305
heavy::nbdl_bind_var::build_match_if_impl
249306
= heavy::nbdl_bind::build_match_if_impl;
307+
heavy::nbdl_bind_var::build_context_impl
308+
= heavy::nbdl_bind::build_context_impl;
250309
}
251310

252311
void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
@@ -257,6 +316,7 @@ void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
257316
{"%build-match-params", heavy::nbdl_bind_var::build_match_params_impl},
258317
{"%build-overload", heavy::nbdl_bind_var::build_overload_impl},
259318
{"%build-match-if", heavy::nbdl_bind_var::build_match_if_impl},
319+
{"%build-context", heavy::nbdl_bind_var::build_context_impl},
260320
});
261321
}
262322
}

heavy/lib/NbdlWriter.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class NbdlWriter {
4545
heavy::SourceLocationEncoding* ErrLoc;
4646
llvm::raw_ostream& OS;
4747

48-
// Track number of members for CreateStoreOp
48+
// Track number of members of context
4949
// to generate anonymous identifiers if needed.
5050
unsigned CurrentMemberCount = 0;
5151
unsigned CurrentAnonVarCount = 0;
@@ -134,7 +134,7 @@ class NbdlWriter {
134134
if (CheckError())
135135
return;
136136

137-
if (isa<CreateStoreOp>(Op)) return Visit(cast<CreateStoreOp>(Op));
137+
if (isa<ContextOp>(Op)) return Visit(cast<ContextOp>(Op));
138138
else if (isa<StoreOp>(Op)) return Visit(cast<StoreOp>(Op));
139139
else if (isa<ApplyOp>(Op)) return Visit(cast<ApplyOp>(Op));
140140
else if (isa<GetOp>(Op)) return Visit(cast<GetOp>(Op));
@@ -161,7 +161,7 @@ class NbdlWriter {
161161
Visit(&Op);
162162
}
163163

164-
void Visit(CreateStoreOp Op) {
164+
void Visit(ContextOp Op) {
165165
// Skip externally defined stores.
166166
if (Op.isExternal())
167167
return;
@@ -197,7 +197,7 @@ class NbdlWriter {
197197
#if 0
198198
if (isTopLevel()) {
199199
// Add the RHS as a member.
200-
assert(isa<nbdl_gen::CreateStoreOp>(Op.getParentOp()) &&
200+
assert(isa<nbdl_gen::ContextOp>(Op.getParentOp()) &&
201201
"should be in context of creating a store");
202202

203203
// Temporarily store anonymous member name if needed.
@@ -385,8 +385,8 @@ class NbdlWriter {
385385
************************************/
386386

387387
void VisitType(mlir::Operation* Op) {
388-
if (isa<CreateStoreOp>(Op))
389-
return VisitType(cast<CreateStoreOp>(Op));
388+
if (isa<ContextOp>(Op))
389+
return VisitType(cast<ContextOp>(Op));
390390
else if (isa<StoreOp>(Op)) return VisitType(cast<StoreOp>(Op));
391391
else if (isa<VariantOp>(Op)) return VisitType(cast<VariantOp>(Op));
392392
else if (isa<StoreComposeOp>(Op))
@@ -396,7 +396,7 @@ class NbdlWriter {
396396
SetError("unhandled operation (VisitType)", Op);
397397
}
398398

399-
void VisitType(CreateStoreOp Op) {
399+
void VisitType(ContextOp Op) {
400400
OS << Op.getName();
401401
}
402402

@@ -415,6 +415,8 @@ class NbdlWriter {
415415
}
416416

417417
void VisitType(StoreComposeOp Op) {
418+
llvm_unreachable("TODO");
419+
#if 0
418420
OS << "nbdl::store_composite<";
419421

420422
// Key
@@ -427,6 +429,7 @@ class NbdlWriter {
427429
OS << ", ";
428430
VisitType(Op.getLoc(), Op.getLhs());
429431
OS << ">";
432+
#endif
430433
}
431434

432435
void VisitType(ConstexprOp Op) {

0 commit comments

Comments
 (0)