Skip to content

Commit 2fef733

Browse files
committed
[Heavy] Fix mlir builder insertion points; Add various attribute types
1 parent 7a48bdf commit 2fef733

File tree

2 files changed

+112
-60
lines changed

2 files changed

+112
-60
lines changed

heavy/lib/Mlir.cpp

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ heavy::ContextLocal current_builder;
3939
heavy::ExternSyntax<> create_op;
4040
heavy::ExternFunction create_op_impl;
4141
heavy::ExternFunction region;
42+
heavy::ExternFunction entry_block;
4243
heavy::ExternFunction results;
4344
heavy::ExternFunction result;
4445
heavy::ExternFunction at_block_begin;
@@ -49,6 +50,9 @@ heavy::ExternFunction set_insertion_point;
4950
heavy::ExternFunction set_insertion_after;
5051
heavy::ExternFunction type;
5152
heavy::ExternFunction attr;
53+
heavy::ExternFunction value_attr;
54+
template <typename AttrTy>
55+
heavy::ExternFunction string_attr;
5256
heavy::ExternFunction with_new_context;
5357
heavy::ExternFunction with_builder;
5458
heavy::ExternFunction load_dialect;
@@ -222,7 +226,7 @@ void create_op_impl(Context& C, ValueRefs Args) {
222226
// (create-op _name_
223227
// (attributes _attrs_ ...)
224228
// (operands _values_ ...)
225-
// (regions _regions_ ...)
229+
// (regions _regions_ )
226230
// (result-types _types_ ...)
227231
// (successors _blocks_ ...))
228232
//
@@ -260,9 +264,11 @@ void create_op(Context& C, ValueRefs Args) { // Syntax
260264
else if (ArgName == "operands")
261265
Operands = Arg;
262266
else if (ArgName == "regions") {
263-
if (!isa<heavy::Int>(Arg))
264-
return C.RaiseError("expecting single argument");
265-
NumRegions = cast<heavy::Int>(Arg);
267+
// Support improper list.
268+
if (auto* PArg = dyn_cast<Pair>(Arg))
269+
NumRegions = PArg->Car;
270+
else
271+
NumRegions = Arg;
266272
}
267273
else if (ArgName == "result-types")
268274
ResultTypes = Arg;
@@ -343,6 +349,7 @@ void region(Context& C, ValueRefs Args) {
343349

344350
// Get entry block from region/op by index.
345351
// If an op is provided the first region is used.
352+
// Add block to region if empty.
346353
void entry_block(Context& C, ValueRefs Args) {
347354
if (Args.size() != 1)
348355
return C.RaiseError("invalid arity");
@@ -359,7 +366,7 @@ void entry_block(Context& C, ValueRefs Args) {
359366
if (!Region)
360367
return C.RaiseError("expecting mlir.op/mlir.region");
361368
if (Region->empty())
362-
return C.RaiseError("mlir.region has no entry block");
369+
Region->emplaceBlock();
363370
mlir::Block* Block = &(Region->front());
364371
if (!Block)
365372
return C.RaiseError("invalid mlir.block");
@@ -448,14 +455,26 @@ static void with_builder_impl(Context& C, mlir::OpBuilder const& Builder,
448455
C.DynamicWind(Before, Thunk, After);
449456
}
450457

451-
// (with-builder _builder_ _thunk_)
458+
// Copy the builder. (ie Do not modify.)
459+
// (with-builder [_builder_] _thunk_)
452460
void with_builder(Context& C, ValueRefs Args) {
453-
if (Args.size() != 2)
461+
if (Args.empty() || Args.size() > 2)
454462
return C.RaiseError("expecting 2 arguments");
455-
mlir::OpBuilder* Builder = getBuilder(C, Args[0]);
456-
if (!Builder)
457-
return;
458-
return with_builder_impl(C, *Builder, Args[1]);
463+
464+
heavy::Value Thunk;
465+
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
466+
mlir::OpBuilder Builder(MLIRContext);
467+
if (Args.size() == 2) {
468+
mlir::OpBuilder* BuilderPtr = getBuilder(C, Args[0]);
469+
if (!BuilderPtr)
470+
return;
471+
Thunk = Args[1];
472+
Builder = *BuilderPtr;
473+
} else {
474+
Thunk = Args[0];
475+
}
476+
477+
return with_builder_impl(C, Builder, Thunk);
459478
}
460479

461480
static mlir::Block* get_arg_block(Context& C, ValueRefs Args) {
@@ -473,7 +492,7 @@ void at_block_begin(Context& C, ValueRefs Args) {
473492
if (mlir::Block* Block = get_arg_block(C, Args)) {
474493
mlir::OpBuilder* Builder = getCurrentBuilder(C);
475494
if (!Builder) return;
476-
Builder->atBlockBegin(Block);
495+
*Builder = mlir::OpBuilder::atBlockBegin(Block);
477496
C.Cont();
478497
}
479498
// Note: error raised in get_arg_block
@@ -485,7 +504,7 @@ void at_block_end(Context& C, ValueRefs Args) {
485504
if (mlir::Block* Block = get_arg_block(C, Args)) {
486505
mlir::OpBuilder* Builder = getCurrentBuilder(C);
487506
if (!Builder) return;
488-
Builder->atBlockEnd(Block);
507+
*Builder = mlir::OpBuilder::atBlockEnd(Block);
489508
C.Cont();
490509
}
491510
// Note: error raised in get_arg_block
@@ -497,7 +516,7 @@ void at_block_terminator(Context& C, ValueRefs Args) {
497516
if (mlir::Block* Block = get_arg_block(C, Args)) {
498517
mlir::OpBuilder* Builder = getCurrentBuilder(C);
499518
if (!Builder) return;
500-
Builder->atBlockTerminator(Block);
519+
*Builder = mlir::OpBuilder::atBlockTerminator(Block);
501520
C.Cont();
502521
}
503522
// Note: error raised in get_arg_block
@@ -561,20 +580,21 @@ void type(Context& C, ValueRefs Args) {
561580
}
562581

563582
// Get an attribute by parsing a string.
564-
// Usage: (attr _type_ _attr_str)
565-
// type - a string or a mlir.type object
583+
// Usage: (attr _attr_str [_type_])
566584
// attr_str - the string to be parsed
567-
// Usage: (attr _val_)
568-
// val - The scheme value to convert to !heavy.value
585+
// type - a string or a mlir.type object (defaults to NoneType)
569586
void attr(Context& C, ValueRefs Args) {
570587
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
571588
mlir::Attribute Attr;
589+
if (Args.size() > 2 || Args.empty())
590+
return C.RaiseError("invalid arity");
591+
592+
heavy::Value AttrStrArg = Args[0];
572593

594+
mlir::Type Type;
573595
if (Args.size() == 2) {
574-
heavy::Value TypeArg = Args[0];
575-
heavy::Value AttrStrArg = Args[1];
596+
heavy::Value TypeArg = Args[1];
576597
llvm::StringRef TypeStr = TypeArg.getStringRef();
577-
mlir::Type Type;
578598
if (!TypeStr.empty()) {
579599
Type = mlir::parseType(TypeStr, MLIRContext, nullptr,
580600
heavy::String::IsNullTerminated);
@@ -586,22 +606,41 @@ void attr(Context& C, ValueRefs Args) {
586606
if (!Type)
587607
return C.RaiseError("invalid mlir type");
588608
}
609+
}
610+
if (!Type)
611+
Type = mlir::NoneType::get(MLIRContext);
612+
613+
llvm::StringRef AttrStr = AttrStrArg.getStringRef();
614+
if (AttrStr.empty())
615+
return C.RaiseError("expecting string");
589616

590-
llvm::StringRef AttrStr = AttrStrArg.getStringRef();
591-
if (AttrStr.empty())
592-
return C.RaiseError("expecting string");
617+
Attr = mlir::parseAttribute(AttrStr, MLIRContext,
618+
Type, nullptr,
619+
heavy::String::IsNullTerminated);
620+
if (!Attr)
621+
return C.RaiseError("mlir attribute parse failed");
593622

594-
Attr = mlir::parseAttribute(AttrStr, MLIRContext,
595-
Type, nullptr,
596-
heavy::String::IsNullTerminated);
597-
if (!Attr)
598-
return C.RaiseError("mlir attribute parse failed");
599-
} else if (Args.size() == 1) {
600-
Attr = HeavyValueAttr::get(MLIRContext, Args[0]);
601-
}
602-
else
623+
C.Cont(CreateTagged(C, kind::mlir_attr, Attr.getImpl()));
624+
}
625+
626+
// Create a heavy scheme value attribute of type !heavy.value
627+
void value_attr(Context& C, ValueRefs Args) {
628+
if (Args.size() != 1)
603629
return C.RaiseError("invalid arity");
630+
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
631+
mlir::Attribute Attr = HeavyValueAttr::get(MLIRContext, Args[0]);
632+
C.Cont(CreateTagged(C, kind::mlir_attr, Attr.getImpl()));
633+
}
604634

635+
template <typename AttrTy>
636+
void string_attr(Context& C, ValueRefs Args) {
637+
if (Args.size() != 1)
638+
return C.RaiseError("invalid arity");
639+
if (!isa<heavy::String, heavy::Symbol>(Args[0]))
640+
return C.RaiseError("expecting string-like object");
641+
llvm::StringRef Str = Args[0].getStringRef();
642+
mlir::MLIRContext* MLIRContext = getCurrentContext(C);
643+
mlir::Attribute Attr = AttrTy::get(MLIRContext, Str);
605644
C.Cont(CreateTagged(C, kind::mlir_attr, Attr.getImpl()));
606645
}
607646

@@ -750,6 +789,7 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
750789
HEAVY_MLIR_VAR(create_op) = heavy::mlir_bind::create_op;
751790
HEAVY_MLIR_VAR(create_op_impl) = heavy::mlir_bind::create_op_impl;
752791
HEAVY_MLIR_VAR(region) = heavy::mlir_bind::region;
792+
HEAVY_MLIR_VAR(entry_block) = heavy::mlir_bind::entry_block;
753793
HEAVY_MLIR_VAR(results) = heavy::mlir_bind::results;
754794
HEAVY_MLIR_VAR(result) = heavy::mlir_bind::result;
755795
HEAVY_MLIR_VAR(at_block_begin) = heavy::mlir_bind::at_block_begin;
@@ -761,6 +801,11 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
761801
HEAVY_MLIR_VAR(set_insertion_after) = heavy::mlir_bind::set_insertion_after;
762802
HEAVY_MLIR_VAR(type) = heavy::mlir_bind::type;
763803
HEAVY_MLIR_VAR(attr) = heavy::mlir_bind::attr;
804+
HEAVY_MLIR_VAR(value_attr) = heavy::mlir_bind::value_attr;
805+
HEAVY_MLIR_VAR(string_attr<mlir::StringAttr>)
806+
= heavy::mlir_bind::string_attr<mlir::StringAttr>;
807+
HEAVY_MLIR_VAR(string_attr<mlir::FlatSymbolRefAttr>)
808+
= heavy::mlir_bind::string_attr<mlir::FlatSymbolRefAttr>;
764809
HEAVY_MLIR_VAR(load_dialect) = heavy::mlir_bind::load_dialect;
765810
HEAVY_MLIR_VAR(with_builder) = heavy::mlir_bind::with_builder;
766811
HEAVY_MLIR_VAR(with_new_context) = heavy::mlir_bind::with_new_context;
@@ -773,6 +818,7 @@ void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
773818
{"create-op", HEAVY_MLIR_VAR(create_op)},
774819
{"%create-op", HEAVY_MLIR_VAR(create_op_impl)},
775820
{"region", HEAVY_MLIR_VAR(region)},
821+
{"entry-block", HEAVY_MLIR_VAR(entry_block)},
776822
{"results", HEAVY_MLIR_VAR(results)},
777823
{"result", HEAVY_MLIR_VAR(result)},
778824
{"at-block-begin", HEAVY_MLIR_VAR(at_block_begin)},
@@ -784,6 +830,10 @@ void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
784830
{"set-insertion-after", HEAVY_MLIR_VAR(set_insertion_after)},
785831
{"type", HEAVY_MLIR_VAR(type)},
786832
{"attr", HEAVY_MLIR_VAR(attr)},
833+
{"value-attr", HEAVY_MLIR_VAR(value_attr)},
834+
{"string-attr", HEAVY_MLIR_VAR(string_attr<mlir::StringAttr>)},
835+
{"flat-symbolref-attr",
836+
HEAVY_MLIR_VAR(string_attr<mlir::FlatSymbolRefAttr>)},
787837
{"with-new-context", HEAVY_MLIR_VAR(with_new_context)},
788838
{"with-builder", HEAVY_MLIR_VAR(with_builder)},
789839
{"load-dialect", HEAVY_MLIR_VAR(load_dialect)},

heavy/test/Evaluate/create-op.scm

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@
77
; CHECK: #op{"heavy.literal"() {info = #heavy<"\22foo\22">} : () -> ()
88
(write (create-op "heavy.literal"
99
(attributes
10-
`("info", (attr "foo")))))
10+
`("info", (value-attr "foo")))))
1111

1212
(newline)
1313

1414
; CHECK: #op{"heavy.literal"() {info = #heavy<"5">} : () -> ()
1515
(write (create-op "heavy.literal"
1616
(attributes
17-
`("info", (attr 5)))))
17+
`("info", (value-attr 5)))))
1818

1919
(newline)
2020

2121
; CHECK: #op{"heavy.literal"() {info = #heavy<"5000">} : () -> ()
2222
(write (create-op "heavy.literal"
2323
(attributes
24-
`("info", (attr !heavy.value "#heavy<\"5000\">")))))
24+
`("info", (attr "#heavy<\"5000\">" !heavy.value)))))
2525

2626
(define the-answer
2727
(create-op "heavy.literal"
2828
(attributes
29-
`("info", (attr 41)))))
29+
`("info", (value-attr 41)))))
3030

3131
(newline)
3232
; CHECK: non-parent:()
@@ -35,28 +35,30 @@
3535

3636
(newline)
3737

38-
; CHECK "heavy.command"()({
39-
; CHECK-NEXT %1 = "heavy.load_global"() <{name = @_HEAVYL5SheavyL4SbaseV5Swrite
40-
; CHECK-NEXT %2 = "heavy.literal"() <{input = #heavy<"42">
41-
; CHECK-NEXT "heavy.apply"(%1, %2) : (!heavy.value, !heavy.value) -> ()
42-
; CHECK-NEXT }) : () -> ()
43-
(define command
44-
(create-op "heavy.command"))
45-
46-
(set-insertion-point command)
38+
; CHECK: "heavy.command"() ({
39+
; CHECK-NEXT: %0 = "heavy.load_global"() <{name = @_HEAVYL5SheavyL4SbaseV5Swrite
40+
; CHECK-NEXT: %1 = "heavy.literal"() {info = #heavy<"42">}
41+
; CHECK-NEXT: "heavy.apply"(%0, %1) : (!heavy.value, !heavy.value) -> ()
42+
; CHECK-NEXT: }) : () -> ()
43+
(define the-number-one 1)
4744

48-
(define callee
49-
(create-op "heavy.load_global"
50-
(result-types !heavy.value)
51-
(attributes
52-
`("name", (attr "_HEAVYL5SheavyL4SbaseV5Swrite")))))
53-
54-
(define arg1
55-
(create-op "heavy.literal"
56-
(result-types !heavy.value)
57-
(attributes
58-
`("info", (attr 42)))))
59-
60-
(create-op "heavy.apply"
61-
(operands (result callee) (result arg1)))
45+
(define command
46+
(create-op "heavy.command" (regions the-number-one)))
47+
(with-builder (lambda ()
48+
(at-block-begin (entry-block command))
49+
((lambda ()
50+
(define callee
51+
(create-op "heavy.load_global"
52+
(result-types !heavy.value)
53+
(attributes
54+
`("name", (flat-symbolref-attr "_HEAVYL5SheavyL4SbaseV5Swrite")))))
55+
(define arg1
56+
(create-op "heavy.literal"
57+
(result-types !heavy.value)
58+
(attributes
59+
`("info", (value-attr 42)))))
60+
(create-op "heavy.apply"
61+
(operands (result callee) (result arg1)))))))
62+
(write command)
63+
(newline)
6264

0 commit comments

Comments
 (0)