Skip to content

Commit e60b23c

Browse files
committed
[Heavy] Support lists in create-op operands; Clean up expr print in NbdlWriter; Add more list tests
1 parent 4c9a9fb commit e60b23c

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

heavy/lib/Mlir.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ void create_op_impl(Context& C, ValueRefs Args) {
136136
return C.RaiseError("expecting mlir.value", V2);
137137
OpState.operands.push_back(MVal);
138138
}
139+
break;
139140
}
140141
auto MVal = getTagged<mlir::Value>(C, kind::mlir_value, V);
141142
if (!MVal)

heavy/lib/Nbdl/NbdlWriter.cpp

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,50 @@ class NbdlWriter {
124124
return SetLocalVarName(V, "anon_");
125125
}
126126

127-
void WriteForwardedExpr(mlir::Value V) {
128-
llvm::StringRef Expr = GetLocalVal(V);
127+
/************************************
128+
*********** Expr Printing **********
129+
************************************/
129130

131+
void WriteExpr(mlir::Value V, bool IsFwd = false) {
130132
// We do not need to forward literals and junk.
131-
if (mlir::Operation* Op = V.getDefiningOp()) {
132-
if (isa<LiteralOp, ConstexprOp>(Op)) {
133+
if (auto Op = V.getDefiningOp<LiteralOp>()) {
134+
WriteExpr(Op);
135+
} else if (auto Op = V.getDefiningOp<ConstexprOp>()) {
136+
WriteExpr(Op);
137+
} else {
138+
llvm::StringRef Expr = GetLocalVal(V);
139+
if (IsFwd) {
140+
OS << "static_cast<decltype(" << Expr << ")>("
141+
<< Expr
142+
<< ")";
143+
} else {
133144
OS << Expr;
134-
return;
135145
}
136146
}
147+
}
148+
149+
void WriteForwardedExpr(mlir::Value V) {
150+
WriteExpr(V, /*IsFwd=*/true);
151+
}
152+
153+
void WriteExpr(ConstexprOp Op) {
154+
llvm::StringRef Expr = Op.getExpr();
155+
if (Expr.empty())
156+
SetError("expecting expr", Op);
157+
OS << Expr;
158+
}
137159

138-
OS << "static_cast<decltype(" << Expr << ")>("
139-
<< Expr
140-
<< ")";
160+
void WriteExpr(LiteralOp Op) {
161+
mlir::Attribute Attr = Op.getValue();
162+
if (auto IA = dyn_cast<mlir::IntegerAttr>(Attr);
163+
IA &&
164+
(IA.getType().isIndex() || IA.getType().isSignlessInteger())) {
165+
OS << IA.getInt();
166+
} else if (auto SA = dyn_cast<mlir::StringAttr>(Attr)) {
167+
OS << llvm::StringRef(SA);
168+
} else {
169+
SetError("unknown literal type", Op);
170+
}
141171
}
142172

143173
/************************************
@@ -272,13 +302,16 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
272302
}
273303

274304
void Visit(ConstexprOp Op) {
305+
#if 0
275306
llvm::StringRef Expr = Op.getExpr();
276307
if (Expr.empty())
277308
SetError("expecting expr", Op);
278309
SetLocalVal(Op.getResult(), llvm::Twine(Expr));
310+
#endif
279311
}
280312

281313
void Visit(LiteralOp Op) {
314+
#if 0
282315
mlir::Attribute Attr = Op.getValue();
283316
if (auto IA = dyn_cast<mlir::IntegerAttr>(Attr);
284317
IA &&
@@ -289,6 +322,7 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
289322
} else {
290323
SetError("unknown literal type", Op);
291324
}
325+
#endif
292326
}
293327

294328
void Visit(GetOp Op) {
@@ -349,8 +383,11 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
349383
void Visit(MatchIfOp Op) {
350384
mlir::Region& Then = Op.getThenRegion();
351385
mlir::Region& Else = Op.getElseRegion();
352-
OS << "if (" << GetLocalVal(Op.getPred()) << '('
353-
<< GetLocalVal(Op.getInput()) << ")) {\n";
386+
OS << "if (";
387+
WriteExpr(Op.getPred());
388+
OS << '(';
389+
WriteExpr(Op.getInput());
390+
OS << ")) {\n";
354391
VisitRegion(Then);
355392

356393
// Check if the the else region is a single MatchIfOp
@@ -370,11 +407,11 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
370407
<< SetLocalVarName(Op.getResult(), "apply_")
371408
<< " = ";
372409
// No forwarding stuff here
373-
OS << GetLocalVal(Op.getFn());
410+
WriteExpr(Op.getFn());
374411
OS << '(';
375412
llvm::interleaveComma(Op.getArgs(), OS,
376413
[&](mlir::Value Val) {
377-
OS << GetLocalVal(Val);
414+
WriteExpr(Val);
378415
});
379416
OS << ");\n";
380417
}

heavy/test/Evaluate/list.scm

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
1))
5454
(newline)
5555

56+
; CHECK-NEXT: #(1 2 ())
57+
(write
58+
((lambda (arg1 arg2 . args)
59+
#(arg1 arg2 ()))
60+
1 2))
61+
(newline)
62+
5663
; CHECK-NEXT: #(1 (2))
5764
(write
5865
((lambda (arg . args)

0 commit comments

Comments
 (0)