Skip to content

Commit f3969c6

Browse files
committed
[Heavy] Add new create-op syntax; Fix nested packs
1 parent 4a5180e commit f3969c6

File tree

18 files changed

+367
-150
lines changed

18 files changed

+367
-150
lines changed

heavy/include/heavy/Mlir.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
#ifndef LLVM_HEAVY_MLIR_H
1414
#define LLVM_HEAVY_MLIR_H
1515

16-
#define HEAVY_MLIR_LIB _HEAVYL5SheavyL4Smlir
17-
#define HEAVY_MLIR_LIB_(NAME) _HEAVYL5SheavyL4Smlir ## NAME
18-
#define HEAVY_MLIR_LIB_STR "_HEAVYL5SheavyL4Smlir"
16+
#define HEAVY_MLIR_LIB _HEAVYL5SheavyL4SmlirL8Sbuiltins
17+
#define HEAVY_MLIR_LIB_(NAME) _HEAVYL5SheavyL4SmlirL8Sbuiltins ## NAME
18+
#define HEAVY_MLIR_LIB_STR "_HEAVYL5SheavyL4SmlirL8Sbuiltins"
1919
#define HEAVY_MLIR_LOAD_MODULE HEAVY_MLIR_LIB_(_load_module)
2020
#define HEAVY_MLIR_INIT HEAVY_MLIR_LIB_(_init)
2121

heavy/include/heavy/OpGen.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
241241
// An error may occur here, but we always want
242242
// this function to return a valid operation.
243243
// Just continue until the error is realized.
244-
FinishLocalDefines();
244+
if (!CheckError())
245+
FinishLocalDefines();
245246
}
246247
return createHelper<Op>(Builder, Loc, std::forward<Args>(args)...);
247248
}
@@ -345,14 +346,23 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
345346
}
346347

347348
template <typename T>
348-
mlir::Value SetError(SourceLocation Loc, T Str, Value V = Undefined()) {
349-
heavy::Error* E = Context.CreateError(Loc, Str, Context.CreatePair(V));
349+
mlir::Value SetError(SourceLocation Loc, T Str,
350+
llvm::ArrayRef<Value> IrrArgs = {}) {
351+
heavy::Error* E = Context.CreateError(Loc, Str, Context.CreateList(IrrArgs));
350352
return SetError(E);
351353
}
352354

353355
template <typename T>
354-
mlir::Value SetError(T Str, Value V = Undefined()) {
355-
return SetError(V.getSourceLocation(), Str, V);
356+
mlir::Value SetError(T Str, llvm::ArrayRef<Value> IrrArgs = {}) {
357+
heavy::SourceLocation Loc;
358+
for (Value Irr : llvm::reverse(IrrArgs))
359+
Loc = Loc.isValid() ? Loc : Irr.getSourceLocation();
360+
return SetError(Loc, Str, IrrArgs);
361+
}
362+
363+
template <typename T>
364+
mlir::Value SetError(T Str, Value V) {
365+
return SetError(Str, llvm::ArrayRef<Value>(V));
356366
}
357367

358368
mlir::Value Error() {

heavy/include/heavy/base.sld

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
*
3535
< <= > >=
3636
positive? zero?
37+
range
3738
apply
3839
append
3940
call-with-values
@@ -64,6 +65,7 @@
6465
let letrec letrec*
6566
cond case
6667
and or when unless
68+
do
6769

6870
; (heavy base list)
6971
caar cadr cdar cddr

heavy/include/heavy/base/int.sld

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
(import (heavy builtins)
55
(heavy base r7rs-syntax))
66
(begin
7+
(define range ; Copied from R7RS
8+
(case-lambda
9+
((e) (range 0 e))
10+
((b e) (do ((r '() (cons e r))
11+
(e (- e 1) (- e 1)))
12+
((< e b) r)))))
13+
714
(define (< x1 x2 . xN)
815
(if (positive? (- x2 x1))
916
(if (pair? xN)
@@ -37,4 +44,5 @@
3744
) ; end begin
3845
(export
3946
< <= > >=
47+
range
4048
))

heavy/include/heavy/base/r7rs-syntax.sld

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,29 @@
2525
tag)
2626
val ...))))
2727

28+
(define-syntax do
29+
(syntax-rules ()
30+
((do ((var init step ...) ...)
31+
(test expr ...)
32+
command ...)
33+
(letrec
34+
((loop
35+
(lambda (var ...)
36+
(if test
37+
(begin
38+
(if #f #f)
39+
expr ...)
40+
(begin
41+
command
42+
...
43+
(loop (do "step" var step ...)
44+
...))))))
45+
(loop init ...)))
46+
((do "step" x)
47+
x)
48+
((do "step" x y)
49+
y)))
50+
2851
; FIXME cond should be in the environment within syntax body.
2952
; (this applies to all define-syntax)
3053
(define-syntax cond
@@ -175,4 +198,5 @@
175198
let letrec letrec*
176199
cond case
177200
and or when unless
201+
do
178202
))

heavy/include/heavy/mlir.sld

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
(import (heavy builtins))
2+
3+
(define-library (heavy mlir)
4+
(import (heavy base)
5+
(heavy mlir builtins))
6+
(begin
7+
(define (init-regions Op BlockArgTypesList UserFns)
8+
(define Is (range (length BlockArgTypesList)))
9+
(define (InitRegion RegionIndex BlockArgTypes UserFn)
10+
(with-builder
11+
(lambda ()
12+
(define Region (get-region Op RegionIndex))
13+
(define Block (entry-block Region))
14+
(define Args
15+
(let ()
16+
(define (Proc BlockArgType)
17+
(add-argument Block BlockArgType Loc))
18+
(map Proc BlockArgTypes)))
19+
(at-block-begin Block)
20+
(apply UserFn Args))))
21+
(map InitRegion Is BlockArgTypesList UserFns))
22+
23+
(define-syntax create-op
24+
(syntax-rules (: loc attributes operands result-types region)
25+
((create-op Name
26+
(loc Loc)
27+
(operands Operands ...)
28+
(attributes (AttrName Attr) ...)
29+
(result-types ResultTypes ...)
30+
(region RegionName ((BlockArg : BlockArgType) ...)
31+
RegionBody ...) ...)
32+
(let ((Op
33+
(old-create-op Name
34+
(loc Loc)
35+
(operands Operands ...)
36+
(attributes (list 'AttrName Attr) ...)
37+
(result-types ResultTypes ...)
38+
(regions (length '(RegionName ...)))
39+
))
40+
(BlockArgsTypesList (list (list BlockArgType ...) ...))
41+
(UserFns (list (lambda (BlockArg ...)
42+
RegionBody ...) ...)))
43+
(init-regions Op BlockArgsTypesList UserFns)
44+
Op)
45+
)))
46+
) ; end of begin
47+
48+
(export
49+
create-op
50+
; (heavy mlir builtins)
51+
old-create-op
52+
current-builder
53+
get-region
54+
entry-block
55+
add-argument
56+
results
57+
result
58+
at-block-begin
59+
at-block-end
60+
block-op
61+
op-next
62+
parent-op
63+
set-insertion-point
64+
set-insertion-after
65+
type
66+
%function-type
67+
attr
68+
type-attr
69+
value-attr
70+
string-attr
71+
flat-symbolref-attr
72+
with-new-context
73+
with-builder
74+
load-dialect
75+
verify
76+
module-lookup
77+
value?
78+
)
79+
)
80+

heavy/lib/Mlir.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ heavy::ContextLocal current_context;
3838
heavy::ContextLocal current_builder;
3939
heavy::ExternSyntax<> create_op;
4040
heavy::ExternFunction create_op_impl;
41-
heavy::ExternFunction region;
41+
heavy::ExternFunction get_region;
4242
heavy::ExternFunction entry_block;
4343
heavy::ExternFunction add_argument;
4444
heavy::ExternFunction results;
@@ -296,19 +296,19 @@ void create_op(Context& C, ValueRefs Args) { // Syntax
296296
C.Cont(OpGen.fromValue(Call));
297297
}
298298

299-
// Get an operation region by index (defaulting to 0).
300-
// (region _op_)
301-
// (region op _index_)
302-
void region(Context& C, ValueRefs Args) {
299+
// Get an operation get_region by index (defaulting to 0).
300+
// (get-region _op_)
301+
// (get-region op _index_)
302+
void get_region(Context& C, ValueRefs Args) {
303303
if (Args.size() != 1 && Args.size() != 2)
304304
return C.RaiseError("invalid arity");
305305

306-
mlir::Operation* Op = heavy::dyn_cast<mlir::Operation>(Args[1]);
306+
mlir::Operation* Op = heavy::dyn_cast<mlir::Operation>(Args[0]);
307307
if (!Op)
308-
return C.RaiseError("expecting mlir.op");
308+
return C.RaiseError("expecting mlir.op: {}", {Args[0]});
309309

310310
if (Args.size() > 1 && !heavy::isa<heavy::Int>(Args[1]))
311-
return C.RaiseError("expecting index");
311+
return C.RaiseError("expecting index: {}", {Args[1]});
312312

313313
int32_t Index = heavy::isa<heavy::Int>(Args[1]) ?
314314
int32_t{heavy::cast<heavy::Int>(Args[1])} : 0;
@@ -843,7 +843,7 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
843843

844844
HEAVY_MLIR_VAR(create_op) = heavy::mlir_bind::create_op;
845845
HEAVY_MLIR_VAR(create_op_impl) = heavy::mlir_bind::create_op_impl;
846-
HEAVY_MLIR_VAR(region) = heavy::mlir_bind::region;
846+
HEAVY_MLIR_VAR(get_region) = heavy::mlir_bind::get_region;
847847
HEAVY_MLIR_VAR(entry_block) = heavy::mlir_bind::entry_block;
848848
HEAVY_MLIR_VAR(add_argument) = heavy::mlir_bind::add_argument;
849849
HEAVY_MLIR_VAR(results) = heavy::mlir_bind::results;
@@ -875,10 +875,10 @@ void HEAVY_MLIR_INIT(heavy::Context& C) {
875875
void HEAVY_MLIR_LOAD_MODULE(heavy::Context& C) {
876876
HEAVY_MLIR_INIT(C);
877877
heavy::initModuleNames(C, HEAVY_MLIR_LIB_STR, {
878-
{"create-op", HEAVY_MLIR_VAR(create_op)},
878+
{"old-create-op", HEAVY_MLIR_VAR(create_op)},
879879
{"%create-op", HEAVY_MLIR_VAR(create_op_impl)},
880880
{"current-builder", HEAVY_MLIR_VAR(current_builder).get_binding(C)},
881-
{"region", HEAVY_MLIR_VAR(region)},
881+
{"get-region", HEAVY_MLIR_VAR(get_region)},
882882
{"entry-block", HEAVY_MLIR_VAR(entry_block)},
883883
{"add-argument", HEAVY_MLIR_VAR(add_argument)},
884884
{"results", HEAVY_MLIR_VAR(results)},

heavy/lib/OpEval.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ class OpEvalImpl {
610610
BlockItrTy patternFail(mlir::Operation* Op, llvm::StringRef ErrMsg,
611611
llvm::ArrayRef<heavy::Value> Irr) {
612612
assert((isa<MatchOp, MatchPairOp, MatchTailOp, MatchArgsOp,
613-
MatchIdOp>(Op)) &&
613+
MatchIdOp, SubpatternOp>(Op)) &&
614614
"Operation must be a pattern matcher");
615615

616616
mlir::Operation* ParentOp = Op->getParentOp();
@@ -623,8 +623,9 @@ class OpEvalImpl {
623623
return BlockItrTy();
624624
}
625625

626+
// Use the SubPatternOp itself as the sentinel for pattern failure.
626627
if (auto SubpatternOp = dyn_cast<heavy::SubpatternOp>(ParentOp))
627-
return BlockItrTy();
628+
return BlockItrTy(SubpatternOp);
628629

629630
assert(isa<heavy::PatternOp>(ParentOp) && "expecting a PatternOp.");
630631

@@ -636,27 +637,36 @@ class OpEvalImpl {
636637
heavy::Value P = Op.getVal().getValue(Context);
637638
heavy::Value E = getValue(Op.getInput());
638639
heavy::EnvEntry Entry;
639-
if (Symbol* S = dyn_cast<Symbol>(E)) {
640+
if (isIdentifier(E) && isIdentifier(P)) {
640641
// If the symbol is in the environment we can skip
641642
// because it will not match a literal.
642-
Entry = Context.Lookup(S);
643-
if (Entry)
644-
E = Entry.Value;
643+
heavy::EnvEntry EntryE = Context.Lookup(E);
644+
heavy::EnvEntry EntryP = Context.Lookup(P);
645+
if (EntryE && EntryP && EntryE.Value == EntryP.Value)
646+
return next(Op);
647+
else if (EntryE || EntryP)
648+
return patternFail(Op,
649+
"expecting identifier with identical binding: {1}", {E, P});
650+
// Unwrap syntax closures.
651+
P = Context.RebuildLiteral(P);
652+
E = Context.RebuildLiteral(E);
645653
}
646-
if (equal(P, E))
654+
if (equal(P, E)) {
647655
return next(Op);
648-
else
649-
return patternFail(Op, "expecting literal", {P, E});
656+
} else {
657+
Context.setLoc(E.getSourceLocation());
658+
return patternFail(Op, "expecting literal: {1}", {E, P});
659+
}
650660
}
651661

652662
BlockItrTy Visit(MatchPairOp Op) {
653663
heavy::Value E = getValue(Op.getInput());
654664
if (auto* Pair = dyn_cast<heavy::Pair>(E)) {
655-
Context.setLoc(E.getSourceLocation());
656665
setValue(Op.getCar(), Pair->Car);
657666
setValue(Op.getCdr(), Pair->Cdr);
658667
return next(Op);
659668
}
669+
Context.setLoc(E.getSourceLocation());
660670
return patternFail(Op, "expecting pair", E);
661671
}
662672

@@ -755,7 +765,7 @@ class OpEvalImpl {
755765

756766
BlockItrTy Visit(SubpatternOp Op) {
757767
heavy::Value E = getValue(Op.getInput());
758-
heavy::Pair* Tail = dyn_cast<heavy::Pair>(getValue(Op.getTail()));
768+
heavy::Value Tail = getValue(Op.getTail());
759769

760770
// Match the empty list.
761771
if (isa<heavy::Empty>(E)) {
@@ -775,16 +785,24 @@ class OpEvalImpl {
775785
// stopping when we find tail or a non-pair object.
776786
// Each "pack" should be a list.
777787
while (auto* Pair = dyn_cast<heavy::Pair>(E)) {
778-
if (Pair == Tail)
788+
if (E == Tail)
779789
break;
780790
auto Scope = ValueMapScope(ValueMap);
781791
BlockItrTy Itr = Op.getBody().front().begin();
782792
setValue(BodyArg, Pair->Car);
783-
while (Itr != BlockItrTy())
793+
while (Itr != BlockItrTy() && Itr != BlockItrTy(Op))
784794
Itr = Visit(&*Itr);
795+
// Check subpattern failure.
796+
if (Itr == BlockItrTy(Op))
797+
break;
785798
E = Pair->Cdr;
786799
}
787800

801+
// Fail if E still has junk that is not Tail and
802+
// did not get gobbled up by the subpattern.
803+
if (E != Tail)
804+
return patternFail(Op, "unexpected elements after subpattern", {E});
805+
788806
return next(Op);
789807
}
790808

@@ -820,7 +838,7 @@ class OpEvalImpl {
820838
if (P == E) {
821839
return next(Op);
822840
}
823-
return patternFail(Op, "identifier does not match", {P, E});
841+
return patternFail(Op, "identifier does not match", {E, P});
824842
}
825843

826844
BlockItrTy Visit(SyntaxClosureOp Op) {

0 commit comments

Comments
 (0)