Skip to content

Commit 7a48bdf

Browse files
committed
[Heavy] Add DialectRegistry; Fix error propagation in nested lambdas
1 parent 3627031 commit 7a48bdf

File tree

8 files changed

+142
-61
lines changed

8 files changed

+142
-61
lines changed

heavy/include/heavy/Context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
namespace mlir {
3838
class MLIRContext;
39+
class DialectRegistry;
3940
}
4041

4142
namespace heavy {
@@ -110,10 +111,11 @@ class Context : public ContinuationStack<Context>,
110111
Value EnvStack;
111112

112113
public: // Provide access in lib/Mlir bindings.
114+
std::unique_ptr<mlir::DialectRegistry> DialectRegistry;
113115
std::unique_ptr<mlir::MLIRContext> MLIRContext;
114116
private:
115117
SourceLocation Loc = {}; // last known location for errors
116-
Value Err = nullptr;
118+
Value Err = nullptr; // FIXME Remove Context.Err I think.
117119
Value ExceptionHandlers = heavy::Empty();
118120
mlir::Operation* ModuleOp = nullptr;
119121

heavy/include/heavy/OpGen.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
181181

182182
public:
183183
explicit OpGen(heavy::Context& C, heavy::Symbol* ModulePrefix = nullptr);
184+
~OpGen();
184185

185186
heavy::Context& getContext() { return Context; }
186187

@@ -333,6 +334,7 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
333334
void createLoadModule(SourceLocation Loc, Symbol* MangledName);
334335

335336
mlir::Value SetError(heavy::Error* NewErr) {
337+
assert((!Err || Value(NewErr) == Err) && "no squashing errors");
336338
Err = NewErr;
337339
Context.setLoc(Err.getSourceLocation());
338340
if (RunSyncDepth == 0)

heavy/include/nbdl_gen/Dialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#ifndef NBDL_DIALECT_H
14-
#define NBDL_DIALECT_H
13+
#ifndef NBDL_GEN_DIALECT_H
14+
#define NBDL_GEN_DIALECT_H
1515

1616
#include "mlir/Dialect/Func/IR/FuncOps.h"
1717
#include "mlir/IR/BuiltinOps.h"

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 90 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ include "mlir/IR/AttrTypeBase.td"
1212
include "mlir/IR/OpBase.td"
1313

1414
def Nbdl_Dialect : Dialect {
15-
let name = "nbdl_gen";
15+
let name = "nbdl";
1616
let useDefaultTypePrinterParser = 1;
1717
}
1818

@@ -28,9 +28,19 @@ class Nbdl_TypeBase<string name, string type_mnemonic,
2828

2929
def Nbdl_OpaqueType : Nbdl_TypeBase<"OpaqueType", "opaque_type"> {
3030
let description = [{
31-
If a C++ type is denoted without any additional semantic
32-
information, then we call it opaque. This can be useful
33-
for visitor functions, keys etc.
31+
Represent an unknown C++ type.
32+
}];
33+
}
34+
35+
def Nbdl_Struct : Nbdl_TypeBase<"Struct", "struct"> {
36+
let summary = "C++ semiregular aggregate type";
37+
let description = [{
38+
Represent C++ semiregular types whose members
39+
are accessed by their name.
40+
41+
Struct is equivalent to State except that there is no
42+
way to access the members without reflection or generated
43+
code.
3444
}];
3545
}
3646

@@ -59,52 +69,80 @@ def Nbdl_Tag : Nbdl_TypeBase<"Tag", "tag_type"> {
5969
}];
6070
}
6171

62-
def Nbdl_Type : AnyTypeOf<[Nbdl_OpaqueType,
63-
Nbdl_Store,
64-
Nbdl_State,
65-
Nbdl_Tag]>;
72+
def Nbdl_Symbol : Nbdl_TypeBase<"Symbol", "symbol_type"> {
73+
let summary = "Nbdl symbol type";
74+
let description = [{
75+
Represent a string that is a valid C++ identifier.
76+
The intended primary use case is to specify a member of a Struct.
77+
}];
78+
}
6679

6780
// Keys are optional, but we are using variadic arguments
6881
// so we resort to using this sum type.
69-
def Nbdl_Key : AnyTypeOf<[Nbdl_OpaqueType, NoneType]>;
82+
def Nbdl_KeyType : AnyTypeOf<[Nbdl_OpaqueType, Nbdl_Tag,
83+
Nbdl_Symbol, NoneType]>;
84+
85+
def Nbdl_Type : AnyTypeOf<[Nbdl_OpaqueType,
86+
Nbdl_Store,
87+
Nbdl_Struct,
88+
Nbdl_State,
89+
Nbdl_Tag]>;
90+
91+
def Nbdl_TagAttr : AttrDef<Nbdl_Dialect, "tag_attr"> {
92+
let attrName = "nbdl.tag_attr";
93+
let parameters = (ins "::mlir::StringAttr":$cppTypeName);
94+
}
95+
96+
def Nbdl_TypenameAttr : AttrDef<Nbdl_Dialect, "typename_attr"> {
97+
let attrName = "nbdl.typename_attr";
98+
let parameters = (ins "::mlir::StringAttr":$cppTypeName);
99+
}
100+
101+
def Nbdl_SymbolAttr : AttrDef<Nbdl_Dialect, "symbol_attr"> {
102+
let attrName = "nbdl.symbol_attr";
103+
let parameters = (ins "::mlir::StringAttr":$cppIdentifier);
104+
}
105+
106+
def Nbdl_KeyAttr : AnyAttrOf<[Nbdl_TagAttr,
107+
Nbdl_TypenameAttr,
108+
Nbdl_SymbolAttr]>;
70109

71110
class Nbdl_Op<string mnemonic, list<Trait> traits = []> :
72111
Op<Nbdl_Dialect, mnemonic, traits>;
73112

74-
def Nbdl_TagOp : Nbdl_Op<"tag", []> {
75-
let summary = "tag";
113+
// FIXME I think KeyOp is redundant with KeyAttr
114+
/*
115+
def Nbdl_KeyOp : Nbdl_Op<"key", []> {
116+
let summary = "key";
76117
let description = [{
77-
Create an instance of a tag_type.
78-
These are intended primarily for use as
79-
keys and allow avoiding capture of stateless
80-
values.
118+
Create an instance of a stateless key object.
119+
These are intended primarily to allow avoiding
120+
capture of stateless values.
81121
}];
82122

83-
let results = (outs Nbdl_Tag:$result);
123+
let arguments = (ins Nbdl_KeyAttr:$key);
124+
let results = (outs Nbdl_Keytype:$result);
84125
}
126+
*/
85127

86128
def Nbdl_GetOp : Nbdl_Op<"get", []> {
87129
let summary = "get";
88130
let description = [{
89131
}];
90132
let arguments = (ins Nbdl_State:$state,
91-
Nbdl_Key:$key);
133+
Nbdl_KeyType:$key);
92134
let results = (outs Nbdl_OpaqueType:$result);
93135
}
94136

95137
def Nbdl_MatchOp : Nbdl_Op<"match", [Terminator, NoTerminator]> {
96138
let summary = "match";
97139
let description = [{
98-
Given a continuation function, a nbdl::Store and, optionally, a key,
140+
Given a nbdl::Store and, optionally, a key,
99141
`match` visits an element contained within that Store using a
100-
continuation for each of the possible typed alternatives.
142+
continuation for each of the possible specified overloads.
101143

102-
Stores match values of different types, so a region is
103-
used to provide the continuation for each possible
104-
alternative which is checked linearly.
105-
106-
It is an error for the matched object to not have a matching
107-
alternative.
144+
It is an error if there exists an alternative for which there
145+
exists no matching overload.
108146

109147
For an example, a `std::unordered_map<int, std::string>` is a store
110148
that can provide access to a contained element with a key `5`. If the
@@ -113,33 +151,39 @@ def Nbdl_MatchOp : Nbdl_Op<"match", [Terminator, NoTerminator]> {
113151
also provides access to its contained element without a key
114152
(ie it is just unwrapped). This requires a visitor to be overloaded
115153
with every possible alternative.
116-
117-
Each region receives fn, store, and the captures as their
118-
arguments.
119154
}];
120-
let arguments = (ins Nbdl_OpaqueType:$fn,
121-
AnyTypeOf<[Nbdl_Store, Nbdl_State]>:$store,
122-
Nbdl_Key:$key,
123-
Variadic<Nbdl_Type>:$captures);
155+
// FIXME $store should only be a Nbdl_Store right?
156+
let arguments = (ins Nbdl_Store:$store,
157+
Nbdl_KeyType:$key);
124158
let results = (outs);
125159
let regions = (region SizedRegion<1>:$overloads);
126160
}
127161

128162
def Nbdl_OverloadOp : Nbdl_Op<"overload", []> {
129163
let description = [{
130164
In the region of a MatchOp we specify a function overload of the
131-
given Nbdl_OpaqueType using OverloadOp.
132-
Each overload is checked linearly (sequentially) and a valid
133-
overload serves as a continuation.
134-
It is an error if no overload matches the matched alternative.
165+
given Nbdl_TypenameAttr using OverloadOp.
166+
The body of the overload takes an argument that is the matched
167+
object and the rest of the arguments are the captures.
135168
}];
136-
let arguments = (ins Nbdl_Type:$arg,
169+
let arguments = (ins Nbdl_TypenameAttr:$typeName,
137170
Variadic<Nbdl_Type>:$captures);
138171
let results = (outs);
172+
let regions = (region SizedRegion<1>:$body);
173+
}
174+
175+
def Nbdl_MatchIfOp : Nbdl_Op<"match_if", [Terminator]> {
176+
let description = [{
177+
Apply a predicate to an object and branch on the result
178+
like an if/else statement.
179+
}];
180+
181+
let arguments = (ins Nbdl_Type:$obj, Nbdl_OpaqueType:$pred);
182+
let regions = (region SizedRegion<1>:$thenRegion,
183+
SizedRegion<1>:$elseRegion);
139184
}
140185

141186
def Nbdl_ContOp : Nbdl_Op<"cont", [Terminator]> {
142-
let summary = "cont";
143187
let description = [{
144188
ContOp represents the call to an opaque continuation object
145189
passing the resolved object and any additional captures
@@ -151,4 +195,12 @@ def Nbdl_ContOp : Nbdl_Op<"cont", [Terminator]> {
151195
let results = (outs);
152196
}
153197

198+
def Nbdl_ConstexprOp : Nbdl_Op<"constexpr", []> {
199+
let description = [{
200+
The result of a constant expression in C++.
201+
}];
202+
let arguments = (ins StrAttr:$expr);
203+
let results = (outs Nbdl_OpaqueType:$result);
204+
}
205+
154206
#endif

heavy/lib/Context.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "heavy/ValueVisitor.h"
2424
#include "mlir/Bytecode/BytecodeReader.h"
2525
#include "mlir/Bytecode/BytecodeWriter.h"
26+
#include "mlir/IR/DialectRegistry.h"
2627
#include "mlir/IR/MLIRContext.h"
2728
#include "mlir/IR/Verifier.h" // TODO move to OpGen
2829
#include "mlir/Support/FileUtilities.h" // openOutputFile
@@ -72,7 +73,8 @@ Context::Context()
7273
ContextLocalLookup(),
7374
Heap(MiB),
7475
EnvStack(Empty()),
75-
MLIRContext(std::make_unique<mlir::MLIRContext>()),
76+
DialectRegistry(std::make_unique<mlir::DialectRegistry>()),
77+
MLIRContext(std::make_unique<mlir::MLIRContext>(*DialectRegistry)),
7678
OpGen(nullptr)
7779
{
7880
NameForImportVar = HEAVY_IMPORT_VAR;
@@ -1351,21 +1353,16 @@ void Context::WithEnv(std::unique_ptr<heavy::Environment> EnvPtr,
13511353
// This is used for *nested* calls in C++ when finishing
13521354
// the operation on the current C++ call stack is needed.
13531355
Value Context::RunSync(Value Callee, Value SingleArg) {
1356+
Value CatchHandler = CreateLambda(
1357+
[](heavy::Context& C, ValueRefs Args) {
1358+
C.Yield(Args);
1359+
});
1360+
Value PrevHandlers = ExceptionHandlers;
1361+
ExceptionHandlers = CatchHandler;
13541362
PushBreak();
1355-
CallCC(CreateLambda([](Context& C, ValueRefs Args) {
1356-
heavy::Value Callee = C.getCapture(0);
1357-
heavy::Value SingleArg = C.getCapture(1);
1358-
Value CC = Args[0];
1359-
heavy::Value Thunk = C.CreateLambda([](heavy::Context& C,
1360-
heavy::ValueRefs) {
1361-
heavy::Value Callee = C.getCapture(0);
1362-
heavy::Value SingleArg = C.getCapture(1);
1363-
C.Apply(Callee, ValueRefs{SingleArg});
1364-
}, CaptureList{Callee, SingleArg});
1365-
heavy::Value HandleError = CC;
1366-
C.WithExceptionHandler(HandleError, Thunk);
1367-
}, CaptureList{Callee, SingleArg}));
1363+
Apply(Callee, ValueRefs{SingleArg});
13681364
Resume();
1365+
ExceptionHandlers = PrevHandlers;
13691366
return getCurrentResult();
13701367
}
13711368

heavy/lib/Mlir.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include <nbdl_gen/Dialect.h>
1314
#include <heavy/Context.h>
1415
#include <heavy/Dialect.h>
1516
#include <heavy/Mlir.h>
@@ -243,6 +244,8 @@ void create_op(Context& C, ValueRefs Args) { // Syntax
243244

244245
// Require the _name_ argument.
245246
mlir::Value OpName = OpGen.GetSingleResult(Input->Car);
247+
if (OpGen.CheckError())
248+
return;
246249

247250
// Process named arguments which are optional.
248251
for (auto [Loc, Arg] : WithSource(Input->Cdr)) {
@@ -275,6 +278,8 @@ void create_op(Context& C, ValueRefs Args) { // Syntax
275278
for (auto [Loc, X] : WithSource(Inputs)) {
276279
C.setLoc(Loc);
277280
mlir::Value V = OpGen.GetSingleResult(X);
281+
if (OpGen.CheckError())
282+
return mlir::Value();
278283
Vals.push_back(V);
279284
}
280285

@@ -291,6 +296,8 @@ void create_op(Context& C, ValueRefs Args) { // Syntax
291296
mlir::Value AttrVals = createInputVector(Attributes);
292297
mlir::Value OperandVals = createInputVector(Operands);
293298
mlir::Value NumRegionsVal = OpGen.GetSingleResult(NumRegions);
299+
if (OpGen.CheckError())
300+
return;
294301
mlir::Value ResultTypeVals = createInputVector(ResultTypes);
295302
mlir::Value SuccessorVals = createInputVector(Successors);
296303

@@ -610,7 +617,7 @@ void with_new_context(heavy::Context& C, heavy::ValueRefs Args) {
610617
if (!Thunk)
611618
return C.RaiseError("expecting thunk");
612619

613-
auto NewContextPtr = std::make_unique<mlir::MLIRContext>();
620+
auto NewContextPtr = std::make_unique<mlir::MLIRContext>(*C.DialectRegistry);
614621
heavy::Value NewMC = CreateTagged(C, kind::mlir_context,
615622
NewContextPtr.get());
616623
heavy::Value NewBuilder = CreateTagged(C, kind::mlir_builder,
@@ -728,12 +735,17 @@ void verify(Context& C, heavy::ValueRefs Args) {
728735
extern "C" {
729736
// initialize the module for run-time independent of the compiler
730737
void HEAVY_MLIR_INIT(heavy::Context& C) {
738+
// TODO Register dialects in their corresponding
739+
// scheme modules instead of here.
740+
C.DialectRegistry->insert<heavy::Dialect,
741+
nbdl::NbdlDialect>();
742+
731743
mlir::MLIRContext* MC = C.MLIRContext.get();
732744
heavy::Value MC_Val = CreateTagged(C, kind::mlir_context, MC);
733745
heavy::Value BuilderVal = CreateTagged(C, kind::mlir_builder,
734746
mlir::OpBuilder(MC));
735-
HEAVY_MLIR_VAR(current_context).init(C, MC_Val);
736-
HEAVY_MLIR_VAR(current_builder).init(C, BuilderVal);
747+
HEAVY_MLIR_VAR(current_context).init(C, C.CreateBinding(MC_Val));
748+
HEAVY_MLIR_VAR(current_builder).init(C, C.CreateBinding(BuilderVal));
737749

738750
HEAVY_MLIR_VAR(create_op) = heavy::mlir_bind::create_op;
739751
HEAVY_MLIR_VAR(create_op_impl) = heavy::mlir_bind::create_op_impl;

heavy/lib/NbdlDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#define GET_OP_CLASSES
2727
#include "nbdl_gen/NbdlOps.cpp.inc"
2828

29-
void nbdl_gen::NbdlDialect::initialize() {
29+
void nbdl::NbdlDialect::initialize() {
3030
addTypes<
3131
#define GET_TYPEDEF_LIST
3232
#include "nbdl_gen/NbdlTypes.cpp.inc"

0 commit comments

Comments
 (0)