Skip to content

Commit eb21b73

Browse files
committed
[Heavy] Add MemberName as callee to Visit
1 parent 5b255ab commit eb21b73

File tree

5 files changed

+84
-32
lines changed

5 files changed

+84
-32
lines changed

heavy/include/nbdl_gen/Nbdl.td

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,28 @@ def Nbdl_Variant : TypeDef<Nbdl_Dialect, "Variant", []> {
9494
*/
9595
}
9696

97-
def Nbdl_Symbol : TypeDef<Nbdl_Dialect, "Symbol"> {
98-
let mnemonic = "symbol";
97+
def Nbdl_MemberName : TypeDef<Nbdl_Dialect, "member_name"> {
98+
let mnemonic = "member_name";
9999
let description = [{
100100
Represent a string that is a valid C++ identifier.
101101
The intended primary use case is to specify a member of a Struct.
102102
}];
103103
}
104104

105-
// Keys are optional, but we are using variadic arguments
106-
// so we resort to using this sum type.
107-
def Nbdl_KeyType : AnyTypeOf<[Nbdl_Store, Nbdl_Tag,
108-
Nbdl_Symbol, Nbdl_Unit]>;
105+
// TODO Eventually add more specific constraints like variant<...>.
106+
def Nbdl_Type : AnyTypeOf<[Nbdl_Store,
107+
Nbdl_Unit,
108+
Nbdl_Tag]>;
109109

110-
def Nbdl_Type : AnyTypeOf<[Nbdl_Unit, Nbdl_Empty,
111-
Nbdl_Store, Nbdl_Tag]>;
110+
// TypePlus is for keys callable objects which allow member names
111+
// used for accessing members or calling member functions.
112+
def Nbdl_TypePlus : AnyTypeOf<[Nbdl_Store,
113+
Nbdl_Unit,
114+
Nbdl_Tag,
115+
Nbdl_MemberName]>;
116+
117+
// Implicitly default to the unit type for keys.
118+
def Nbdl_Key : Optional<Nbdl_TypePlus>;
112119

113120
class Nbdl_Op<string mnemonic, list<Trait> traits = []> :
114121
Op<Nbdl_Dialect, mnemonic, traits>;
@@ -129,8 +136,14 @@ def Nbdl_GetOp : Nbdl_Op<"get", []> {
129136
let description = [{
130137
}];
131138
let arguments = (ins Nbdl_Type:$state,
132-
Nbdl_KeyType:$key);
139+
Nbdl_Key:$key);
133140
let results = (outs Nbdl_Type:$result);
141+
let extraClassDeclaration = [{
142+
bool hasUnitKey() {
143+
return !getKey() ||
144+
::llvm::isa<::nbdl_gen::UnitType>(getKey().getType());
145+
}
146+
}];
134147
}
135148

136149
def Nbdl_MatchOp : Nbdl_Op<"match", [Terminator, NoTerminator]> {
@@ -154,9 +167,15 @@ def Nbdl_MatchOp : Nbdl_Op<"match", [Terminator, NoTerminator]> {
154167
}];
155168

156169
let arguments = (ins Nbdl_Type:$store,
157-
Nbdl_KeyType:$key);
170+
Nbdl_Key:$key);
158171
let results = (outs);
159172
let regions = (region SizedRegion<1>:$overloads);
173+
let extraClassDeclaration = [{
174+
bool hasUnitKey() {
175+
return !getKey() ||
176+
::llvm::isa<::nbdl_gen::UnitType>(getKey().getType());
177+
}
178+
}];
160179
}
161180

162181
def Nbdl_OverloadOp : Nbdl_Op<"overload", []> {
@@ -204,7 +223,7 @@ def Nbdl_MemberNameOp : Nbdl_Op<"member_name", []> {
204223
}];
205224

206225
let arguments = (ins StrAttr:$name);
207-
let results = (outs Nbdl_Symbol:$result);
226+
let results = (outs Nbdl_MemberName:$result);
208227
}
209228

210229
def Nbdl_StoreComposeOp : Nbdl_Op<"store_compose", []> {
@@ -218,10 +237,16 @@ def Nbdl_StoreComposeOp : Nbdl_Op<"store_compose", []> {
218237
}];
219238

220239
let arguments = (ins
221-
Nbdl_KeyType:$key,
240+
Nbdl_Key:$key,
222241
Nbdl_Type:$lhs,
223242
Nbdl_Type:$rhs);
224243
let results = (outs Nbdl_Type:$result);
244+
let extraClassDeclaration = [{
245+
bool hasUnitKey() {
246+
return !getKey() ||
247+
::llvm::isa<::nbdl_gen::UnitType>(getKey().getType());
248+
}
249+
}];
225250
}
226251

227252
def Nbdl_VariantOp : Nbdl_Op<"variant", []> {
@@ -353,7 +378,7 @@ def Nbdl_VisitOp : Nbdl_Op<"visit", []> {
353378
returning a result. This is analogous to std::visit
354379
except only one function is generated.
355380
}];
356-
let arguments = (ins Nbdl_Type:$fn,
381+
let arguments = (ins Nbdl_TypePlus:$fn,
357382
Variadic<Nbdl_Type>:$args);
358383
let results = (outs Nbdl_Type:$result);
359384
}

heavy/lib/Nbdl/NbdlWriter.cpp

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,9 @@ class NbdlWriter {
320320

321321
void VisitType(StoreComposeOp Op) {
322322
OS << "::nbdl::store_composite<";
323-
VisitType(Op.getKey());
323+
assert(!Op.hasUnitKey() &&
324+
"unit key for store compose currently not supported");
325+
VisitTypeOrUnitType(Op.getKey());
324326
OS << ", ";
325327
VisitType(Op.getLhs());
326328
OS << ", ";
@@ -329,14 +331,19 @@ class NbdlWriter {
329331
}
330332

331333
void VisitType(ConstexprOp Op) {
332-
VisitType(Op.getResult());
334+
VisitType(Op.getLoc(), Op.getResult().getType());
335+
}
336+
337+
void VisitTypeOrUnitType(mlir::Value V) {
338+
assert(V && isa<nbdl_gen::UnitType>(V.getType())
339+
&& "printing unit type currently not supported");
340+
341+
if (!V || !isa<nbdl_gen::UnitType>(V.getType()))
342+
VisitType(V);
333343
}
334344

335345
void VisitType(mlir::Location Loc, mlir::Type Type) {
336-
if (auto OpaqueType = dyn_cast<nbdl_gen::OpaqueType>(Type))
337-
OS << "decltype(auto)";
338-
else
339-
SetError(Loc, "unsupported type");
346+
OS << "auto&&";
340347
}
341348
};
342349

@@ -412,7 +419,11 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
412419
}
413420

414421
void Visit(GetOp Op) {
415-
auto MemberNameOp = Op.getKey().getDefiningOp<nbdl_gen::MemberNameOp>();
422+
bool HasUnitKey = Op.hasUnitKey();
423+
nbdl_gen::MemberNameOp MemberNameOp;
424+
if (!HasUnitKey)
425+
MemberNameOp = Op.getKey().getDefiningOp<nbdl_gen::MemberNameOp>();
426+
416427
OS << "auto&& "
417428
<< SetLocalVarName(Op.getResult(), "get_")
418429
<< " = ";
@@ -423,7 +434,7 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
423434
} else {
424435
OS << "::nbdl::get(";
425436
WriteExpr(Op.getState());
426-
if (!isa<nbdl_gen::UnitType>(Op.getKey().getType())) {
437+
if (!HasUnitKey) {
427438
OS << ", ";
428439
WriteExpr(Op.getKey());
429440
}
@@ -440,7 +451,20 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
440451
<< " = ";
441452
}
442453

443-
WriteExpr(Op.getFn());
454+
mlir::Value Fn = Op.getFn();
455+
mlir::OperandRange Args = Op.getArgs();
456+
457+
if (auto MemberNameOp = Fn.getDefiningOp<nbdl_gen::MemberNameOp>()) {
458+
if (Args.empty()) {
459+
SetError("member literal callee expects at least one argument", Op);
460+
return;
461+
}
462+
WriteExpr(Args.front());
463+
OS << '.' << MemberNameOp.getName();
464+
Args = Args.drop_front();
465+
} else {
466+
WriteExpr(Op.getFn());
467+
}
444468
OS << '(';
445469
llvm::interleave(Op.getArgs(), OS,
446470
[&](mlir::Value V) {
@@ -458,8 +482,10 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
458482
OS << "auto ";
459483
OS << SetLocalVarName(Op.getResult(), "result_");
460484
OS << " = ::nbdl::store_compose(";
461-
WriteExpr(Op.getKey());
462-
OS << ",";
485+
if (!Op.hasUnitKey()) {
486+
WriteExpr(Op.getKey());
487+
OS << ",";
488+
}
463489
WriteExpr(Op.getRhs());
464490
OS << ",";
465491
WriteExpr(Op.getLhs());
@@ -475,7 +501,7 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
475501
void Visit(MatchOp Op) {
476502
OS << "::nbdl::match(";
477503
WriteExpr(Op.getStore());
478-
if (!isa<nbdl_gen::UnitType>(Op.getKey().getType())) {
504+
if (!Op.hasUnitKey()) {
479505
OS << ", ";
480506
WriteExpr(Op.getKey());
481507
}
@@ -659,11 +685,12 @@ class DefineStoreWriter : public NbdlWriter<DefineStoreWriter> {
659685
heavy::SourceLocation PrevLoc = CurLoc;
660686
SetLoc(Op.getLoc());
661687

662-
mlir::Value Key = Op.getKey();
688+
mlir::Value Key = Op.getKey(); // Could be mlir::Value().
663689
mlir::Value Lhs = Op.getLhs();
664690
llvm::StringRef Name;
665691

666-
if (auto MemberNameOp = Key.getDefiningOp<nbdl_gen::MemberNameOp>()) {
692+
if (Key && Key.getDefiningOp<nbdl_gen::MemberNameOp>()) {
693+
auto MemberNameOp = Key.getDefiningOp<nbdl_gen::MemberNameOp>();
667694
// It would be more consistent with our definition of StoreCompose
668695
// to support shadowing here, but since it is more work to check
669696
// and very suboptimal for the C++ compiler

heavy/test/Nbdl/close_previous_scope.scm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
(define !nbdl.store (type "!nbdl.store"))
1111
(define !nbdl.unit (type "!nbdl.unit"))
12-
(define !nbdl.symbol (type "!nbdl.symbol"))
12+
(define !nbdl.member_name (type "!nbdl.member_name"))
1313
(define !nbdl.unit (type "!nbdl.unit"))
1414
(define i32 (type "i32"))
1515

heavy/test/Nbdl/context.scm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
(define !nbdl.store (type "!nbdl.store"))
1111
(define !nbdl.tag (type "!nbdl.tag"))
12-
(define !nbdl.symbol (type "!nbdl.symbol"))
12+
(define !nbdl.member_name (type "!nbdl.member_name"))
1313
(define !nbdl.unit (type "!nbdl.unit"))
1414
(define i32 (type "i32"))
1515

@@ -20,7 +20,7 @@
2020
(operands:)
2121
(attributes:
2222
("name" (string-attr name)))
23-
(result-types: !nbdl.symbol))))
23+
(result-types: !nbdl.member_name))))
2424

2525
(define my_store
2626
(with-builder

heavy/test/Nbdl/match_params.scm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
(define !nbdl.store (type "!nbdl.store"))
1111
(define !nbdl.tag (type "!nbdl.tag"))
12-
(define !nbdl.symbol (type "!nbdl.symbol"))
12+
(define !nbdl.member_name (type "!nbdl.member_name"))
1313
(define !nbdl.unit (type "!nbdl.unit"))
1414
(define i32 (type "i32"))
1515

@@ -43,7 +43,7 @@
4343
(old-create-op "nbdl.member_name"
4444
(attributes
4545
`("name", (string-attr "bar")))
46-
(result-types !nbdl.symbol))))
46+
(result-types !nbdl.member_name))))
4747
(define key3
4848
(result
4949
(old-create-op "nbdl.literal"

0 commit comments

Comments
 (0)