Skip to content

Commit 0f06425

Browse files
committed
[Heavy] Finish subpattern expansion
1 parent 94d094d commit 0f06425

File tree

6 files changed

+130
-30
lines changed

6 files changed

+130
-30
lines changed

heavy/include/heavy/Ops.td

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,23 +378,46 @@ def heavy_MatchPairOp : HeavyOp<"match_pair", []> {
378378
];
379379
}
380380

381+
def heavy_MatchTailOp : HeavyOp<"match_tail", []> {
382+
let description = [{
383+
Destructure the last N lengthed list of the tail of the input list.
384+
The last cdr is counted in the length.
385+
(ie () or any element for improper list.)
386+
The tail will be matched against a pattern.
387+
If the tail is a pair, its address we serve as a sentinel
388+
for terminating the matching of elements against the subpattern.
389+
}];
390+
391+
let arguments = (ins UI32Attr:$length,
392+
HeavyValue:$input);
393+
let results = (outs HeavyValue:$tail);
394+
395+
let builders = [
396+
OpBuilder<(ins "uint32_t":$length,
397+
"::mlir::Value":$input)>
398+
];
399+
}
400+
381401
def heavy_SubpatternOp : HeavyOp<"subpattern", []> {
382402
let description = [{
383403
Match a list of elements over a subpattern that
384404
precedes the `...` ellipsis. Each pattern variable within
385405
the list is returned as a pack list that may only be expanded
386406
by a subtemplate.
387407

388-
The element after the matched list
389-
is destructured as `cdr` similar to match_pair.
408+
We stop at the elements that occur in the list after the
409+
subpattern by passing the tail value as a sentinel.
410+
It checks by the address of the pair or any cdr that is
411+
not a pair.
390412
}];
391413

392-
let arguments = (ins HeavyValue:$input);
393-
let results = (outs Variadic<HeavyValue>:$packs, HeavyValue:$cdr);
414+
let arguments = (ins HeavyValue:$input, HeavyValue:$tail);
415+
let results = (outs Variadic<HeavyValue>:$packs);
394416
let regions = (region SizedRegion<1>:$body);
395417

396418
let builders = [
397419
OpBuilder<(ins "::mlir::Value":$input,
420+
"::mlir::Value":$tail,
398421
"std::unique_ptr<::mlir::Region>":$body,
399422
"unsigned":$num_packs)>
400423
];

heavy/lib/Dialect.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,20 @@ void MatchPairOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
181181
input);
182182
}
183183

184+
void MatchTailOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
185+
uint32_t Length, mlir::Value Input) {
186+
mlir::Type HeavyValueT = B.getType<HeavyValueTy>();
187+
MatchTailOp::build(B, OpState, HeavyValueT,
188+
Length, Input);
189+
}
190+
184191
void SubpatternOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
185-
mlir::Value Input,
192+
mlir::Value Input, mlir::Value Tail,
186193
std::unique_ptr<mlir::Region> Body,
187194
unsigned NumPacks) {
188195
mlir::Type HeavyValueT = B.getType<HeavyValueTy>();
189-
// NumPacks + 1 because $cdr is also a result value.
190-
llvm::SmallVector<mlir::Type, 4> ResultTypes(NumPacks + 1, HeavyValueT);
191-
OpState.addOperands(Input);
196+
llvm::SmallVector<mlir::Type, 4> ResultTypes(NumPacks, HeavyValueT);
197+
OpState.addOperands({Input, Tail});
192198
OpState.addRegion(std::move(Body));
193199
OpState.addTypes(std::move(ResultTypes));
194200
}

heavy/lib/OpEval.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ class OpEvalImpl {
206206
else if (isa<FuncOp>(Op)) return next(Op); // skip functions
207207
else if (isa<SyntaxOp>(Op)) return Visit(cast<SyntaxOp>(Op));
208208
else if (isa<MatchPairOp>(Op)) return Visit(cast<MatchPairOp>(Op));
209+
else if (isa<MatchTailOp>(Op)) return Visit(cast<MatchTailOp>(Op));
209210
else if (isa<SubpatternOp>(Op)) return Visit(cast<SubpatternOp>(Op));
210211
else if (isa<ExpandPacksOp>(Op)) return Visit(cast<ExpandPacksOp>(Op));
211212
else if (isa<ResolveOp>(Op)) return Visit(cast<ResolveOp>(Op));
@@ -622,7 +623,8 @@ class OpEvalImpl {
622623
// or, in the case of a pack, it begins matching the rest of the list.
623624
BlockItrTy patternFail(mlir::Operation* Op) {
624625
// We should currently be in the scope of a PatternOp or SubpatternOp
625-
assert((isa<MatchOp, MatchPairOp, SubpatternOp, MatchIdOp>(Op)) &&
626+
assert((isa<MatchOp, MatchPairOp, MatchTailOp,
627+
SubpatternOp, MatchIdOp>(Op)) &&
626628
"Operation must be a pattern matcher");
627629

628630
mlir::Operation* PatternOp = Op->getParentOp();
@@ -672,6 +674,27 @@ class OpEvalImpl {
672674
return patternFail(Op);
673675
}
674676

677+
BlockItrTy Visit(MatchTailOp Op) {
678+
// Here we include the last cdr in the length
679+
// of a list proper or improper.
680+
heavy::Value E = getValue(Op.getInput());
681+
uint32_t TargetLen = Op.getLength();
682+
assert(TargetLen >= 1 && "expecting positive length");
683+
uint32_t TotalLen = 1;
684+
heavy::Value Cur = E;
685+
while (auto* Pair = dyn_cast<heavy::Pair>(Cur)) {
686+
++TotalLen;
687+
Cur = Pair->Cdr;
688+
}
689+
if (TotalLen < TargetLen)
690+
return patternFail(Op);
691+
Cur = E;
692+
for (uint32_t i = 0; i < TotalLen - TargetLen; i++)
693+
Cur = cast<heavy::Pair>(Cur)->Cdr;
694+
setValue(Op.getTail(), Cur);
695+
return next(Op);
696+
}
697+
675698
BlockItrTy Visit(ExpandPacksOp Op) {
676699
heavy::SourceLocation Loc = getSourceLocation(Op.getLoc());
677700
setValue(Op.getResult(), getValue(Op.getCdr()));
@@ -708,6 +731,7 @@ class OpEvalImpl {
708731

709732
BlockItrTy Visit(SubpatternOp Op) {
710733
heavy::Value E = getValue(Op.getInput());
734+
heavy::Pair* Tail = dyn_cast<heavy::Pair>(getValue(Op.getTail()));
711735

712736
// Match the empty list.
713737
if (isa<heavy::Empty>(E)) {
@@ -723,19 +747,20 @@ class OpEvalImpl {
723747
assert(Op.getBody().getNumArguments() == 1
724748
&& "body should have single argument");
725749
mlir::Value BodyArg = Op.getBody().getArgument(0);
726-
// Visit the subpattern body for each element in E.
750+
// Visit the subpattern body for each element in E
751+
// stopping when we find tail or a non-pair object.
727752
// Each "pack" should be a list.
728753
while (auto* Pair = dyn_cast<heavy::Pair>(E)) {
754+
if (Pair == Tail)
755+
break;
729756
push_scope();
730757
BlockItrTy Itr = Op.getBody().front().begin();
731-
setValue(BodyArg, Pair);
758+
setValue(BodyArg, Pair->Car);
732759
while (Itr != BlockItrTy())
733760
Itr = Visit(&*Itr);
734761
E = Pair->Cdr;
735762
}
736763

737-
setValue(Op.getCdr(), E);
738-
739764
return next(Op);
740765
}
741766

@@ -762,7 +787,7 @@ class OpEvalImpl {
762787
}
763788

764789
auto ExpandPacksOp = cast<heavy::ExpandPacksOp>(Op->getParentOp());
765-
assert(Op.getArgs().size() == 1 && "expecting single result");
790+
assert(Op.getArgs().size() == 1 && "expecting single result");
766791
heavy::Value CurrentResult = getValue(Op.getArgs().front());
767792
pop_scope();
768793
heavy::Value Result = Context.CreatePair(CurrentResult,

heavy/lib/PatternTemplate.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
6161
// VisitPatternTemplate should be called with OpGen's insertion point in
6262
// the body of PatternOp
6363
mlir::Value VisitPatternTemplate(heavy::Value Pattern,
64-
heavy::Value Template,
64+
heavy::Value Template,
6565
mlir::Value E) {
6666
heavy::SourceLocation Loc = Pattern.getSourceLocation();
6767
if (isa_and_nonnull<Symbol>(Pattern.car())) {
@@ -106,8 +106,33 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
106106
return OpGen.SetError("invalid pattern node", P);
107107
}
108108

109+
mlir::Value VisitTail(heavy::SourceLocation Loc,
110+
Value P, mlir::Value E) {
111+
// Match the list after the `...`.
112+
uint32_t Length = 1;
113+
114+
if (Pair* P2 = dyn_cast<Pair>(P))
115+
Loc = P2->getSourceLocation();
116+
117+
heavy::Value Cur = P;
118+
while (Pair* P2 = dyn_cast<Pair>(Cur)) {
119+
++Length;
120+
Cur = P2->Cdr;
121+
}
122+
mlir::Value Tail = OpGen.create<MatchTailOp>(Loc, Length, E);
123+
124+
// Actually match the tail part.
125+
Visit(P, Tail);
126+
127+
// The tail is passed to the SubpatternOp.
128+
return Tail;
129+
}
130+
109131
mlir::Value VisitSubpattern(heavy::SourceLocation Loc,
110132
Value P, Value Cdr, mlir::Value E) {
133+
// The Tail will be used as a sentinel value if it is a pair.
134+
mlir::Value Tail = VisitTail(Loc, Cdr, E);
135+
111136
// Visit the subpattern.
112137
auto Body = std::make_unique<mlir::Region>();
113138
llvm::SmallVector<mlir::Value, 4> Packs;
@@ -120,7 +145,7 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
120145
mlir::Value BodyArg = Block.addArgument(HeavyValueT, MLoc);
121146
Visit(P, BodyArg);
122147

123-
// Create range for pack values by finding the
148+
// Create range for pack values by finding the
124149
// SyntaxClosureOps in the Body region.
125150
for (mlir::Operation& Op : Block) {
126151
if (auto SC = dyn_cast<SyntaxClosureOp>(&Op))
@@ -132,11 +157,8 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
132157
}
133158

134159
// Create the SubpatternOp.
135-
auto SubpatternOp = OpGen.create<heavy::SubpatternOp>(
136-
Loc, E, std::move(Body), Packs.size());
137-
138-
if (!OpGen.CheckError())
139-
Visit(Cdr, SubpatternOp.getCdr());
160+
OpGen.create<heavy::SubpatternOp>(
161+
Loc, E, Tail, std::move(Body), Packs.size());
140162

141163
// In the template, the nested syntax closures
142164
// will be looked up and check if its parent is
@@ -169,7 +191,7 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
169191
if (P->equals("_")) {
170192
// Since _ always matches anything, there is
171193
// nothing to check.
172-
return mlir::Value();
194+
return mlir::Value();
173195
}
174196

175197
// <pattern identifier> (literal identifier)

heavy/lib/TemplateGen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TemplateGen : TemplateBase<TemplateGen>,
4949
void VisitTemplate(heavy::Value Template) {
5050
heavy::SourceLocation Loc = Template.getSourceLocation();
5151
ResultTy Result = Visit(Template);
52-
52+
5353
mlir::Value TransformedSyntax;
5454
if (mlir::Value* MVP = std::get_if<mlir::Value>(&Result))
5555
TransformedSyntax = *MVP;

heavy/test/Evaluate/define-syntax.scm

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
; RUN: heavy-scheme %s | FileCheck %s
22
(import (heavy base))
33

4+
;; TODO Local let-syntax needs to be tested/supported.
5+
46
(define ok 'ok!)
57
(define-syntax my-lambda
68
(syntax-rules (=>)
@@ -17,16 +19,38 @@
1719
'oops!)
1820
(newline)
1921

20-
;; TODO Local let-syntax needs to be supported.
21-
22-
#|
22+
; CHECK: 42"x has type Int""y has type Int"
23+
; CHECK-NEXT: 5
24+
; CHECK-NEXT: 6
2325
(define-syntax my-lambda
2426
(syntax-rules (:)
25-
((my-lambda (arg ...) body ...)
27+
((my-lambda ((arg : type) ...) body ...)
2628
(lambda (arg ...)
29+
(write 42)
2730
(write (string-append 'arg " has type " 'type)) ...
31+
(newline)
2832
body ...))))
29-
((my-lambda ((x) (y))
33+
((my-lambda ((x : Int) (y : Int))
3034
(write x)
31-
(write y)))
32-
|#
35+
(newline)
36+
(write y)) 5 6)
37+
(newline)
38+
39+
; CHECK: (0 1)
40+
; CHECK-NEXT: (0 1 2)
41+
(define-syntax ez
42+
(syntax-rules ()
43+
((ez 0 1 i ...)
44+
'(0 1 i ...))))
45+
(write (ez 0 1))
46+
(newline)
47+
(write (ez 0 1 2))
48+
(newline)
49+
50+
; CHECK: (0 1 2 3 4 9)
51+
(define-syntax ez
52+
(syntax-rules ()
53+
((ez 0 1 i ... 5 6)
54+
'(0 1 i ... 9))))
55+
(write (ez 0 1 2 3 4 5 6))
56+
(newline)

0 commit comments

Comments
 (0)