Skip to content

Commit b4c5d53

Browse files
committed
[Heavy] Add bindings for creating function types and module lookup
1 parent 35fd558 commit b4c5d53

File tree

4 files changed

+127
-50
lines changed

4 files changed

+127
-50
lines changed

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def Nbdl_Void : TypeDef<Nbdl_Dialect, "Void", []> {
4040

4141
// We keep cpp_typename around for when
4242
// we want to go back to c++.
43+
// DEPRECATED
4344
class Nbdl_TypeBase<string name, string type_mnemonic,
4445
list<Trait> traits = []>
4546
: TypeDef<Nbdl_Dialect, name, traits> {
@@ -48,7 +49,8 @@ class Nbdl_TypeBase<string name, string type_mnemonic,
4849
let assemblyFormat = "`<` $cpp_typename `>`";
4950
}
5051

51-
def Nbdl_Opaque : Nbdl_TypeBase<"Opaque", "opaque"> {
52+
def Nbdl_Opaque : TypeDef<Nbdl_Dialect, "Opaque", []> {
53+
let mnemonic = "opaque";
5254
let description = [{
5355
Represent an unknown C++ type.
5456
}];
@@ -64,46 +66,41 @@ def Nbdl_Struct : Nbdl_TypeBase<"Struct", "struct"> {
6466
}
6567
*/
6668

67-
/* FIXME This is not useful without the key.
68-
def Nbdl_State : Nbdl_TypeBase<"State", "state"> {
69-
let summary = "Nbdl State type";
69+
def Nbdl_State : TypeDef<Nbdl_Dialect, "State", []> {
70+
let mnemonic = "state";
7071
let description = [{
71-
In Nbdl, a State object uses a static key to access
72-
elements where paths are resolved at compile-time.
72+
A type satisfying the nbdl::State concept
7373
}];
7474
}
75-
*/
7675

77-
def Nbdl_Store : Nbdl_TypeBase<"Store", "store"> {
78-
let summary = "Nbdl Store type";
76+
def Nbdl_Store : TypeDef<Nbdl_Dialect, "Store", []> {
77+
let mnemonic = "store";
7978
let description = [{
80-
In Nbdl, a Store object has state that may have multiple
81-
alternative paths that are determined at run-time.
82-
A State object is also a Store.
79+
A type satisfying the nbdl::Store concept
8380
}];
84-
// TODO Stores should have a list of keys that we know they match.
8581
}
8682

87-
def Nbdl_Variant : Nbdl_TypeBase<"Variant", "variant"> {
88-
let summary = "Nbdl Variant type";
83+
def Nbdl_Variant : TypeDef<Nbdl_Dialect, "Variant", []> {
84+
let mnemonic = "variant";
8985
let description = [{
90-
Variant is a store that is a sum type with a type index.
86+
A type satisfying the nbdl::Variant concept
9187
}];
9288
}
9389

94-
def Nbdl_Tag : Nbdl_TypeBase<"Tag", "tag_type"> {
95-
let summary = "Nbdl tag type";
90+
def Nbdl_Tag : TypeDef<Nbdl_Dialect, "Tag", []> {
91+
let mnemonic = "tag";
9692
let description = [{
97-
In Nbdl, an empty type that only carries information
93+
An empty type that only carries information
9894
at compile-time is called a tag type.
9995
}];
10096
}
10197

102-
def Nbdl_Symbol : Nbdl_TypeBase<"Symbol", "symbol_type"> {
103-
let summary = "Nbdl symbol type";
98+
def Nbdl_Symbol : Nbdl_TypeBase<"Symbol", "symbol"> {
10499
let description = [{
105100
Represent a string that is a valid C++ identifier.
106101
The intended primary use case is to specify a member of a Struct.
102+
103+
FIXME using this?
107104
}];
108105
}
109106

@@ -113,7 +110,7 @@ def Nbdl_KeyType : AnyTypeOf<[Nbdl_Opaque, Nbdl_Tag,
113110
Nbdl_Symbol, NoneType]>;
114111

115112
def Nbdl_Type : AnyTypeOf<[Nbdl_Void, Nbdl_Empty, Nbdl_Opaque,
116-
Nbdl_Store, Nbdl_Tag]>;
113+
Nbdl_Store, Nbdl_State, Nbdl_Tag]>;
117114

118115
def Nbdl_TagAttr : AttrDef<Nbdl_Dialect, "tag_attr"> {
119116
let attrName = "nbdl.tag_attr";

heavy/lib/Mlir.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ heavy::ExternFunction block_op;
4848
heavy::ExternFunction set_insertion_point;
4949
heavy::ExternFunction set_insertion_after;
5050
heavy::ExternFunction type;
51+
heavy::ExternFunction function_type_impl;
5152
heavy::ExternFunction attr;
53+
heavy::ExternFunction type_attr;
5254
heavy::ExternFunction value_attr;
5355
template <typename AttrTy>
5456
heavy::ExternFunction string_attr;
@@ -58,6 +60,7 @@ heavy::ExternFunction load_dialect;
5860
heavy::ExternFunction parent_op;
5961
heavy::ExternFunction op_next;
6062
heavy::ExternFunction verify;
63+
heavy::ExternFunction module_lookup;
6164
}
6265

6366
namespace {
@@ -577,6 +580,45 @@ void type(Context& C, ValueRefs Args) {
577580
C.Cont(CreateTagged(C, kind::mlir_type, Type.getImpl()));
578581
}
579582

583+
// Create a function type (using vector literals.
584+
// (%function-type #(<arg-types>...) #(<result-types>...))
585+
void function_type_impl(Context& C, ValueRefs Args) {
586+
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
587+
if (Args.size() != 2)
588+
return C.RaiseError("invalid arity");
589+
auto* ArgTypeVals = heavy::dyn_cast<heavy::Vector>(Args[0]);
590+
auto* ResultTypeVals = heavy::dyn_cast<heavy::Vector>(Args[1]);
591+
if (!ArgTypeVals || !ResultTypeVals)
592+
return C.RaiseError("expecting vectors");
593+
594+
// Arg types
595+
llvm::SmallVector<mlir::Type, 8> ArgTypes;
596+
Args = Args.drop_front();
597+
for (heavy::Value Arg : ArgTypeVals->getElements()) {
598+
mlir::Type ArgType = GetTagged<mlir::Type>(C, kind::mlir_type, Arg);
599+
if (!ArgType)
600+
return C.RaiseError("expecting mlir.type");
601+
ArgTypes.push_back(ArgType);
602+
}
603+
604+
// Result types
605+
llvm::SmallVector<mlir::Type, 8> ResultTypes;
606+
for (heavy::Value Result : ResultTypeVals->getElements()) {
607+
mlir::Type ResultType = GetTagged<mlir::Type>(C, kind::mlir_type, Result);
608+
if (!ResultType)
609+
return C.RaiseError("expecting mlir.type");
610+
ResultTypes.push_back(ResultType);
611+
}
612+
613+
mlir::Type Type = mlir::FunctionType::get(MLIRContext, ArgTypes,
614+
ResultTypes);
615+
616+
if (!Type)
617+
return C.RaiseError("mlir build function type failed");
618+
619+
C.Cont(CreateTagged(C, kind::mlir_type, Type.getImpl()));
620+
}
621+
580622
// Get an attribute by parsing a string.
581623
// Usage: (attr _attr_str [_type_])
582624
// attr_str - the string to be parsed
@@ -621,6 +663,17 @@ void attr(Context& C, ValueRefs Args) {
621663
C.Cont(CreateTagged(C, kind::mlir_attr, Attr.getImpl()));
622664
}
623665

666+
// (type-attr (type "!type-goes-here"))
667+
void type_attr(Context& C, ValueRefs Args) {
668+
if (Args.size() != 1)
669+
return C.RaiseError("invalid arity");
670+
mlir::Type Type = GetTagged<mlir::Type>(C, kind::mlir_type, Args[0]);
671+
if (!Type)
672+
return C.RaiseError("expecting a mlir.type");
673+
mlir::Attribute Attr = mlir::TypeAttr::get(Type);
674+
C.Cont(CreateTagged(C, kind::mlir_attr, Attr.getImpl()));
675+
}
676+
624677
// Create a heavy scheme value attribute of type !heavy.value
625678
void value_attr(Context& C, ValueRefs Args) {
626679
if (Args.size() != 1)
@@ -769,6 +822,24 @@ void verify(Context& C, heavy::ValueRefs Args) {
769822
}
770823
}
771824

825+
// (symbol-table-lookup _module-op_ _"symbolname"_)
826+
void module_lookup(Context& C, heavy::ValueRefs Args) {
827+
if (Args.size() != 2)
828+
return C.RaiseError("invalid arity");
829+
830+
mlir::ModuleOp ModuleOp = dyn_cast_or_null<mlir::ModuleOp>(
831+
dyn_cast<mlir::Operation>(Args[0]));
832+
llvm::StringRef SymbolName = Args[1].getStringRef();
833+
if (!ModuleOp)
834+
return C.RaiseError("expecting mlir::ModuleOp");
835+
if (SymbolName.empty())
836+
return C.RaiseError("expecting nonempty string-like object");
837+
838+
mlir::Operation* Op = ModuleOp.lookupSymbol(SymbolName);
839+
heavy::Value Result = Op != nullptr ? heavy::Value(Op) : heavy::Empty();
840+
C.Cont(Result);
841+
}
842+
772843
} // namespace heavy::mlir_bind
773844

774845
extern "C" {
@@ -795,7 +866,9 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
795866
HEAVY_MLIR_VAR(set_insertion_point) = heavy::mlir_bind::set_insertion_point;
796867
HEAVY_MLIR_VAR(set_insertion_after) = heavy::mlir_bind::set_insertion_after;
797868
HEAVY_MLIR_VAR(type) = heavy::mlir_bind::type;
869+
HEAVY_MLIR_VAR(function_type_impl) = heavy::mlir_bind::function_type_impl;
798870
HEAVY_MLIR_VAR(attr) = heavy::mlir_bind::attr;
871+
HEAVY_MLIR_VAR(type_attr) = heavy::mlir_bind::type_attr;
799872
HEAVY_MLIR_VAR(value_attr) = heavy::mlir_bind::value_attr;
800873
HEAVY_MLIR_VAR(string_attr<mlir::StringAttr>)
801874
= heavy::mlir_bind::string_attr<mlir::StringAttr>;
@@ -805,6 +878,7 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
805878
HEAVY_MLIR_VAR(with_builder) = heavy::mlir_bind::with_builder;
806879
HEAVY_MLIR_VAR(with_new_context) = heavy::mlir_bind::with_new_context;
807880
HEAVY_MLIR_VAR(verify) = heavy::mlir_bind::verify;
881+
HEAVY_MLIR_VAR(module_lookup) = heavy::mlir_bind::module_lookup;
808882
}
809883

810884
void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
@@ -825,7 +899,9 @@ void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
825899
{"set-insertion-point", HEAVY_MLIR_VAR(set_insertion_point)},
826900
{"set-insertion-after", HEAVY_MLIR_VAR(set_insertion_after)},
827901
{"type", HEAVY_MLIR_VAR(type)},
902+
{"%function-type", HEAVY_MLIR_VAR(function_type_impl)},
828903
{"attr", HEAVY_MLIR_VAR(attr)},
904+
{"type-attr", HEAVY_MLIR_VAR(type_attr)},
829905
{"value-attr", HEAVY_MLIR_VAR(value_attr)},
830906
{"string-attr", HEAVY_MLIR_VAR(string_attr<mlir::StringAttr>)},
831907
{"flat-symbolref-attr",
@@ -834,6 +910,7 @@ void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
834910
{"with-builder", HEAVY_MLIR_VAR(with_builder)},
835911
{"load-dialect", HEAVY_MLIR_VAR(load_dialect)},
836912
{"verify", HEAVY_MLIR_VAR(verify)},
913+
{"module-lookup", HEAVY_MLIR_VAR(module_lookup)},
837914
});
838915
}
839916
}

heavy/lib/Nbdl.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,47 @@
1919
#include <tuple>
2020

2121
namespace nbdl_gen {
22-
std::tuple<std::string, mlir::Location, mlir::Operation*>
22+
std::tuple<std::string, heavy::SourceLocationEncoding*, mlir::Operation*>
2323
translate_cpp(llvm::raw_ostream& OS, mlir::Operation* Op);
2424
}
2525

2626
namespace heavy::nbdl_bind_var {
27-
heavy::ExternFunction write_cpp;
27+
heavy::ExternFunction translate_cpp;
2828
}
2929

3030
namespace heavy::nbdl_bind {
3131
// Translate a nbdl dialect operation to C++.
32-
void write_cpp(Context& C, ValueRefs Args) {
33-
if (Args.size() != 2)
32+
// (translate-cpp op port)
33+
// Currently the "port" has to be a tagged llvm::raw_ostream.
34+
void translate_cpp(Context& C, ValueRefs Args) {
35+
if (Args.size() != 2 && Args.size() != 1)
3436
return C.RaiseError("invalid arity");
3537
auto* Op = dyn_cast<mlir::Operation>(Args[0]);
36-
// Do not captured the emphemeral Tagged object.
37-
auto* Tagged = dyn_cast<heavy::Tagged>(Args[1]);
38-
heavy::Symbol* KindSym = C.CreateSymbol("::llvm::raw_ostream");
3938
if (!Op)
4039
return C.RaiseError("expecting mlir.operation");
41-
if (!Tagged || !Tagged->isa(KindSym))
42-
return C.RaiseError("expecting ::llvm::raw_ostream");
4340

44-
auto& OS = Tagged->cast<llvm::raw_ostream>();
45-
auto&& [ErrMsg, ErrLoc, Irritant] = nbdl_gen::translate_cpp(OS, Op);
41+
llvm::raw_ostream* OS = nullptr;
42+
43+
// Do not capture the emphemeral Tagged object.
44+
if (Args.size() == 2) {
45+
auto* Tagged = dyn_cast<heavy::Tagged>(Args[1]);
46+
heavy::Symbol* KindSym = C.CreateSymbol("::llvm::raw_ostream");
47+
if (!Tagged || !Tagged->isa(KindSym))
48+
return C.RaiseError("expecting ::llvm::raw_ostream");
49+
OS = &(Tagged->cast<llvm::raw_ostream>());
50+
} else {
51+
OS = &llvm::outs();
52+
}
53+
54+
auto&& [ErrMsg, ErrLoc, Irritant] = nbdl_gen::translate_cpp(*OS, Op);
4655
if (!ErrMsg.empty()) {
47-
auto Loc = heavy::SourceLocation(mlir::OpaqueLoc
48-
::getUnderlyingLocationOrNull<heavy::SourceLocationEncoding*>(
49-
mlir::dyn_cast<mlir::OpaqueLoc>(ErrLoc)));
56+
heavy::SourceLocation Loc(ErrLoc);
57+
#if 0
58+
if (mlir::isa<mlir::OpaqueLoc>(ErrLoc))
59+
Loc = heavy::SourceLocation(mlir::OpaqueLoc
60+
::getUnderlyingLocationOrNull<heavy::SourceLocationEncoding*>(
61+
mlir::dyn_cast<mlir::OpaqueLoc>(ErrLoc)));
62+
#endif
5063
heavy::Error* Err = C.CreateError(Loc, ErrMsg,
5164
Irritant ? heavy::Value(Irritant) : Undefined());
5265
return C.Raise(Err);
@@ -60,13 +73,13 @@ extern "C" {
6073
void HEAVY_NBDL_INIT(heavy::Context& C) {
6174
C.DialectRegistry->insert<nbdl_gen::NbdlDialect>();
6275

63-
heavy::nbdl_bind_var::write_cpp = heavy::nbdl_bind::write_cpp;
76+
heavy::nbdl_bind_var::translate_cpp = heavy::nbdl_bind::translate_cpp;
6477
}
6578

6679
void HEAVY_NBDL_LOAD_MODULE(heavy::Context& C) {
6780
HEAVY_NBDL_INIT(C);
6881
heavy::initModuleNames(C, HEAVY_NBDL_LIB_STR, {
69-
{"write-cpp", heavy::nbdl_bind_var::write_cpp},
82+
{"translate-cpp", heavy::nbdl_bind_var::translate_cpp},
7083
});
7184
}
7285
}

heavy/lib/NbdlWriter.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ class NbdlWriter {
264264
}
265265

266266
void VisitType(StoreOp Op) {
267-
mlir::Value Result = Op.getResult();
268-
VisitType(Op.getLoc(), Result);
267+
OS << Op.getName();
269268
}
270269

271270
void VisitType(VariantOp Op) {
@@ -307,16 +306,7 @@ class NbdlWriter {
307306
}
308307

309308
void VisitType(mlir::Location Loc, mlir::Type Type) {
310-
if (auto OpaqueType = dyn_cast<nbdl_gen::OpaqueType>(Type))
311-
OS << OpaqueType.getCppTypename();
312-
else if (auto StoreType = dyn_cast<nbdl_gen::StoreType>(Type))
313-
OS << StoreType.getCppTypename();
314-
else if (auto VariantType = dyn_cast<nbdl_gen::VariantType>(Type))
315-
OS << VariantType.getCppTypename();
316-
else if (auto TagType = dyn_cast<nbdl_gen::TagType>(Type))
317-
OS << TagType.getCppTypename();
318-
else
319-
SetError(Loc, "unprintable type");
309+
SetError(Loc, "unprintable type");
320310
}
321311
};
322312
}

0 commit comments

Comments
 (0)