Skip to content

Commit 10c1b39

Browse files
committed
[Heavy] Add DefineStoreOp
1 parent d12d8ed commit 10c1b39

File tree

9 files changed

+215
-229
lines changed

9 files changed

+215
-229
lines changed

heavy/include/heavy/Value.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,10 +1399,17 @@ class Any final :
13991399

14001400
// Note that for stored pointers, this will return a pointer to a pointer.
14011401
void* getOpaquePtr() { return getTrailingObjects<char>(); }
1402+
14021403
llvm::StringRef getObjData() {
14031404
return llvm::StringRef(static_cast<char*>(getOpaquePtr()), getObjectSize());
14041405
};
14051406

1407+
bool equal(Any* Other) {
1408+
return TypeId == Other->TypeId &&
1409+
StorageLen == Other->StorageLen &&
1410+
getObjData() == Other->getObjData();
1411+
}
1412+
14061413
template <typename T>
14071414
bool isa() {
14081415
return TypeId == &AnyTypeId<T>::Id;

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,6 @@ def Nbdl_Store : TypeDef<Nbdl_Dialect, "Store", []> {
7171
}];
7272
}
7373

74-
def Nbdl_Variant : TypeDef<Nbdl_Dialect, "Variant", []> {
75-
let mnemonic = "variant";
76-
let description = [{
77-
A type satisfying the nbdl::Variant concept
78-
}];
79-
}
80-
8174
def Nbdl_Tag : TypeDef<Nbdl_Dialect, "Tag", []> {
8275
let mnemonic = "tag";
8376
let description = [{
@@ -86,6 +79,21 @@ def Nbdl_Tag : TypeDef<Nbdl_Dialect, "Tag", []> {
8679
}];
8780
}
8881

82+
def Nbdl_Variant : TypeDef<Nbdl_Dialect, "Variant", []> {
83+
let mnemonic = "variant";
84+
let description = [{
85+
A type satisfying the nbdl::Variant concept
86+
}];
87+
88+
/*
89+
// TODO Eventually it would be nice to specify the alternatives
90+
// Also some parameterization for knowing the key =>
91+
let parameters = (ins Variadic<AnyTypeOf<[Nbdl_Store, Nbdl_Tag,
92+
Nbdl_Unit>:$alts);
93+
let assemblyFormat = "`<`$alts`>`";
94+
*/
95+
}
96+
8997
def Nbdl_Symbol : TypeDef<Nbdl_Dialect, "Symbol"> {
9098
let mnemonic = "symbol";
9199
let description = [{
@@ -236,7 +244,7 @@ def Nbdl_VariantOp : Nbdl_Op<"variant", []> {
236244
let results = (outs Nbdl_Type:$result);
237245
}
238246

239-
def Nbdl_ContextOp : Nbdl_Op<"context", [Symbol, IsolatedFromAbove]> {
247+
def Nbdl_DefineStoreOp : Nbdl_Op<"define_store", [Symbol, IsolatedFromAbove]> {
240248
let description = [{
241249
Define a store object, and expose its interface (type) to the user.
242250

@@ -268,6 +276,33 @@ def Nbdl_ContextOp : Nbdl_Op<"context", [Symbol, IsolatedFromAbove]> {
268276
}];
269277
}
270278

279+
def Nbdl_ContextOp : Nbdl_Op<"context", [Symbol]> {
280+
let description = [{
281+
Encapsulate a store object to serve as a root node in a state tree.
282+
283+
Nested objects are still accessible via `match` and `get`,
284+
but the context wrapper supports the convention that the root
285+
node cannot be a non-root member of a path.
286+
287+
An instance can be created with nbdl.store.
288+
}];
289+
290+
let arguments = (ins StrAttr:$sym_name, FlatSymbolRefAttr:$implName);
291+
let results = (outs);
292+
293+
let extraClassDeclaration = [{
294+
// Get the fully qualified name.
295+
// This is needed for defining template specializations.
296+
llvm::StringRef getFullName() { return getSymName(); }
297+
298+
// Get the unqualified name.
299+
llvm::StringRef getName() {
300+
llvm::StringRef SymName = getSymName();
301+
return SymName.take_back(SymName.rfind(':'));
302+
}
303+
}];
304+
}
305+
271306
def Nbdl_ContOp : Nbdl_Op<"cont", [Terminator]> {
272307
let description = [{
273308
The terminator we need.

heavy/lib/Context.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,8 @@ bool eqv_slow(Value V1, Value V2) {
727727
case ValueKind::Float:
728728
return cast<Float>(V1)->getVal() ==
729729
cast<Float>(V2)->getVal();
730+
case ValueKind::Any:
731+
return cast<Any>(V1)->equal(cast<Any>(V2));
730732
default:
731733
return false;
732734
}

heavy/lib/Nbdl.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void translate_cpp(Context& C, ValueRefs Args) {
215215
return C.RaiseError("invalid arity");
216216
auto* Op = dyn_cast<mlir::Operation>(Args[0]);
217217
if (!Op)
218-
return C.RaiseError("expecting mlir.operation");
218+
return C.RaiseError("expecting mlir.operation: {}", Args[0]);
219219

220220
llvm::raw_ostream* OS = nullptr;
221221

@@ -285,6 +285,8 @@ void close_previous_scope(Context& C, ValueRefs Args) {
285285
}
286286

287287
void build_context_impl(Context& C, ValueRefs Args) {
288+
return C.RaiseError("deprecated");
289+
#if 0
288290
if (Args.size() != 3)
289291
return C.RaiseError("invalid arity");
290292

@@ -337,6 +339,7 @@ void build_context_impl(Context& C, ValueRefs Args) {
337339
// Call the thunk with a Builder at the entry point.
338340
Builder = mlir::OpBuilder(ContextOp.getBody());
339341
mlir_helper::with_builder_impl(C, Builder, Thunk);
342+
#endif
340343
}
341344

342345

heavy/lib/Nbdl/NbdlWriter.cpp

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class NbdlWriter {
301301
}
302302

303303
void VisitType(ContextOp Op) {
304-
OS << Op.getName();
304+
OS << Op.getSymName();
305305
}
306306

307307
void VisitType(StoreOp Op) {
@@ -560,43 +560,92 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
560560

561561
class ContextWriter : public NbdlWriter<ContextWriter> {
562562
public:
563-
// Record the values for each member in order.
564-
llvm::SmallVector<mlir::Value, 8> Members;
565-
566563
using NbdlWriter<ContextWriter>::NbdlWriter;
567564

568565
void VisitContext(ContextOp Op) {
569566
SetLoc(Op.getLoc());
567+
568+
llvm::StringRef ClassName = Op.getName();
569+
llvm::StringRef ImplName = Op.getImplName();
570+
571+
OS << "class " << ClassName << " : nbdl::context_alias<";
572+
OS << ImplName;
573+
OS << ", /*is_moveable=*/false> {\n"
574+
"using Base = ";
575+
OS << ImplName;
576+
OS << ";\n using Base::Base;\n";
577+
OS << "};\n";
578+
Flush();
579+
}
580+
};
581+
582+
class DefineStoreWriter : public NbdlWriter<DefineStoreWriter> {
583+
public:
584+
// Record the values for each member in order.
585+
llvm::SmallVector<mlir::Value, 8> Members;
586+
587+
using NbdlWriter<DefineStoreWriter>::NbdlWriter;
588+
589+
void VisitDefineStore(DefineStoreOp Op) {
570590
// Skip externally defined stores.
571591
if (Op.isExternal())
572592
return;
573593

594+
SetLoc(Op.getLoc());
574595
ValueMapScope Scope(ValueMap);
575596

576597
// Set the arg names first.
577598
for (mlir::BlockArgument BlockArg : Op.getBody().getArguments())
578599
SetLocalVarName(BlockArg, "arg_");
579600

580-
// Delete both copy constructors to support subsumption with `auto&&`.
581-
OS << "class " << Op.getName() << " {\n";
601+
auto ContOp = getContOp(Op);
602+
llvm::TypeSwitch<mlir::Operation*>(ContOp.getArg().getDefiningOp())
603+
.Case<UnitOp>([&, this](auto) {
604+
this->CreateTag(Op.getName());
605+
})
606+
.Case<StoreOp, VariantOp>([&, this](auto ResultOp) {
607+
this->CreateStrongAlias(Op.getName(), ResultOp.getResult());
608+
})
609+
.Case<StoreComposeOp>([&, this](auto) {
610+
this->CreateClass(Op);
611+
});
612+
}
613+
614+
void CreateTag(llvm::StringRef Name) {
615+
OS << "struct " << Name << " { };\n";
616+
}
617+
618+
void CreateStrongAlias(llvm::StringRef Name, mlir::Value V) {
619+
OS << "class " << Name << " : public nbdl::strong_alias<";
620+
VisitType(V);
621+
OS << "> {\n"
622+
"using Base = ";
623+
VisitType(V);
624+
OS << ";\n using Base::Base;\n";
625+
OS << "};\n";
626+
}
627+
628+
void CreateClass(DefineStoreOp Op) {
629+
llvm::StringRef Name = Op.getName();
630+
OS << "class " << Name << " {\n";
582631
OS << "public:\n";
583632
WriteMemberDecls(Op);
584-
OS << Op.getName() << '(' << Op.getName() << " const&) = delete;\n";
585-
OS << Op.getName() << '(' << Op.getName() << "&) = delete;\n";
633+
OS << Name << "(" << Name << " const&) = default;\n";
634+
OS << Name << "(" << Name << "&&) = default;\n";
586635
WriteConstructor(Op);
587636
OS << "};\n";
588637
Flush();
589638
}
590639

591-
nbdl_gen::ContOp getContOp(ContextOp Op) {
640+
nbdl_gen::ContOp getContOp(DefineStoreOp Op) {
592641
mlir::Operation* Terminator = Op.getBody().front().getTerminator();
593642
auto ContOp = dyn_cast<nbdl_gen::ContOp>(Terminator);
594643
if (!ContOp)
595644
SetError("expecting nbdl.cont as terminator", Op);
596645
return ContOp;
597646
}
598647

599-
void WriteMemberDecls(ContextOp Op) {
648+
void WriteMemberDecls(DefineStoreOp Op) {
600649
// Get the ContOp and work backwards
601650
// saving the member names as we go.
602651
auto ContOp = getContOp(Op);
@@ -651,12 +700,12 @@ class ContextWriter : public NbdlWriter<ContextWriter> {
651700
CurLoc = PrevLoc;
652701
}
653702

654-
void WriteConstructor(ContextOp Op) {
703+
void WriteConstructor(DefineStoreOp Op) {
655704
auto ContOp = getContOp(Op);
656705
if (!ContOp)
657706
return;
658707

659-
OS << Op.getName();
708+
OS << "explicit " << Op.getName();
660709
OS << '(';
661710
llvm::interleaveComma(Op.getBody().getArguments(), OS,
662711
[&](mlir::BlockArgument const& Arg) {
@@ -748,6 +797,11 @@ translate_cpp(heavy::LexerWriterFnRef LexerWriter, mlir::Operation* Op) {
748797
Writer.Visit(Op);
749798
return std::make_tuple(std::move(Writer.ErrMsg),
750799
Writer.ErrLoc, Writer.Irritant);
800+
} else if (auto DefineStoreOp = dyn_cast<nbdl_gen::DefineStoreOp>(Op)) {
801+
DefineStoreWriter Writer(LexerWriter);
802+
Writer.VisitDefineStore(DefineStoreOp);
803+
return std::make_tuple(std::move(Writer.ErrMsg),
804+
Writer.ErrLoc, Writer.Irritant);
751805
} else if (auto ContextOp = dyn_cast<nbdl_gen::ContextOp>(Op)) {
752806
ContextWriter Writer(LexerWriter);
753807
Writer.VisitContext(ContextOp);

heavy/test/Nbdl/apply_action.scm

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
(operands: Store Num42)
4040
(attributes:)
4141
(result-types: !nbdl.store))))
42+
(close-previous-scope) ; Should do nothing.
4243
(create-op "nbdl.match"
4344
(loc: 0)
4445
(operands: Store UnitKey)
@@ -59,6 +60,7 @@
5960
(result-types:))))))
6061
; Foo is not allowed after this (because it could be invalidated.)
6162
(close-previous-scope)
63+
(close-previous-scope) ; Should do nothing.
6264
(let ((SomeTag (result
6365
(create-op "nbdl.constexpr"
6466
(loc: 0)

heavy/test/Nbdl/context.cpp

Lines changed: 0 additions & 97 deletions
This file was deleted.

0 commit comments

Comments
 (0)