Skip to content

Commit caafb6a

Browse files
committed
[Heavy] Consolidate syntax template generation
1 parent 7a0022a commit caafb6a

File tree

5 files changed

+84
-51
lines changed

5 files changed

+84
-51
lines changed

heavy/include/heavy/ValueVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ class ValueVisitor {
3131
RetTy Visit ## NAME(cast_ty<NAME> V, Args&& ...args) { \
3232
return getDerived().VisitValue(V, std::forward<Args>(args)...); }
3333

34+
protected:
3435
Derived& getDerived() { return static_cast<Derived&>(*this); }
3536
Derived const& getDerived() const { return static_cast<Derived>(*this); }
3637

37-
protected:
3838
// Derived must implement VisitValue OR all of the
3939
// concrete visitors
4040
template <typename T>

heavy/lib/OpGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ mlir::Value OpGen::createSyntaxRules(SourceLocation Loc,
677677
// Terminate with error call for when no patterns match.
678678
mlir::Value ErrorMsg = createLiteral(Loc,
679679
Context.CreateString("no matching pattern for syntax"));
680-
std::array<mlir::Value, 3> ErrorArgs{ErrorMsg, ExprArg, SyntaxOp.getResult()};
680+
std::array<mlir::Value, 2> ErrorArgs{ErrorMsg, ExprArg};
681681
createSyntaxError(Loc, ErrorArgs);
682682

683683
if (!isa<Empty>(PatternDefs) || Body.empty()) {

heavy/lib/PatternTemplate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class PatternTemplate : ValueVisitor<PatternTemplate, mlir::Value> {
8181

8282
if (!OpGen.CheckError()) {
8383
TemplateGen TG(OpGen, PatternVars, Ellipsis);
84-
TG.VisitTemplate(Template);
84+
TG.BuildTemplate(Template);
8585
}
8686

8787
return mlir::Value();

heavy/lib/TemplateBase.h

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <heavy/Context.h>
1616
#include <heavy/OpGen.h>
1717
#include <heavy/Value.h>
18+
#include <heavy/ValueVisitor.h>
1819
#include <mlir/IR/Value.h>
1920
#include <llvm/Support/Casting.h>
2021
#include <variant>
@@ -36,15 +37,30 @@ using TemplateResult = std::variant<TemplateError, mlir::Value,
3637
// utility functions for creating operations that build
3738
// literals interspersed with values.
3839
template <typename Derived>
39-
class TemplateBase {
40+
class TemplateBase : protected ValueVisitor<Derived, TemplateResult> {
4041
protected:
42+
friend ValueVisitor<Derived, TemplateResult>;
43+
using ValueVisitor<Derived, TemplateResult>::getDerived;
44+
4145
heavy::OpGen& OpGen;
4246
llvm::SmallVector<mlir::Value, 8> RenameOps;
4347

4448
TemplateBase(heavy::OpGen& O)
4549
: OpGen(O)
4650
{ }
4751

52+
// Template is the input "form" with syntax closures.
53+
mlir::Value VisitTemplate(heavy::Value Template) {
54+
TemplateResult Result = getDerived().Visit(Template);
55+
56+
mlir::Value TransformedSyntax;
57+
if (mlir::Value* MVP = std::get_if<mlir::Value>(&Result))
58+
TransformedSyntax = *MVP;
59+
else
60+
TransformedSyntax = createLiteral(std::get<heavy::Value>(Result));
61+
return TransformedSyntax;
62+
}
63+
4864
mlir::Location createLoc(heavy::SourceLocation Loc) {
4965
return mlir::OpaqueLoc::get(Loc.getOpaqueEncoding(),
5066
OpGen.Builder.getContext());
@@ -133,6 +149,58 @@ class TemplateBase {
133149
RenameOps.push_back(R);
134150
return LiteralResult;
135151
}
152+
153+
mlir::Value createRenameEnv() {
154+
heavy::SourceLocation Loc;
155+
// Move the RenameOps to the end of the current block.
156+
mlir::Block* Block = OpGen.Builder.getBlock();
157+
auto InsertionPoint = OpGen.Builder.getInsertionPoint();
158+
for (mlir::Value RV : RenameOps)
159+
RV.getDefiningOp()->moveBefore(Block, InsertionPoint);
160+
161+
// Add the RenameOps to an EnvFrame
162+
// to serve as the base of the environment.
163+
return OpGen.create<EnvFrameOp>(Loc, RenameOps);
164+
}
165+
166+
// ValueVisitor functions
167+
168+
TemplateResult VisitValue(heavy::Value P) {
169+
return P;
170+
}
171+
172+
TemplateResult VisitSyntaxClosure(SyntaxClosure* SC) {
173+
mlir::Value EnvVal = any_cast<mlir::Value>(SC->Env);
174+
if (!EnvVal)
175+
return OpGen.SetError("expecting env mlir.value");
176+
177+
if (!isa<Symbol>(SC->Node))
178+
return OpGen.SetError("expecting a symbol");
179+
180+
heavy::SourceLocation Loc = SC->Node.getSourceLocation();
181+
mlir::Value Node = createLiteral(SC->Node);
182+
return OpGen.create<SyntaxClosureOp>(Loc, Node, EnvVal);
183+
}
184+
185+
TemplateResult VisitPair(Pair* P) {
186+
heavy::SourceLocation Loc = P->getSourceLocation();
187+
188+
TemplateResult CarResult = getDerived().Visit(P->Car);
189+
TemplateResult CdrResult = getDerived().Visit(P->Cdr);
190+
191+
// If nothing changed
192+
auto* HCar = std::get_if<heavy::Value>(&CarResult);
193+
auto* HCdr = std::get_if<heavy::Value>(&CdrResult);
194+
if (HCar && HCdr && *HCar == P->Car && *HCdr == P->Cdr)
195+
return P;
196+
197+
if (OpGen.CheckError())
198+
return mlir::Value();
199+
200+
return createCons(Loc, CarResult, CdrResult);
201+
}
202+
203+
TemplateResult VisitSymbol(Symbol* P) = delete;
136204
};
137205

138206
}

heavy/lib/TemplateGen.h

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,16 @@ namespace heavy {
2424
// TemplateGen
2525
// - Substitute a template using the syntactic environment
2626
// created by a pattern as part of the syntax-rules syntax.
27-
class TemplateGen : TemplateBase<TemplateGen>,
28-
ValueVisitor<TemplateGen, TemplateResult> {
27+
class TemplateGen : TemplateBase<TemplateGen> {
2928
friend TemplateBase<TemplateGen>;
3029
friend ValueVisitor<TemplateGen, TemplateResult>;
30+
using Base = TemplateBase<TemplateGen>;
3131

3232
Symbol* Ellipsis;
3333
NameSet& PatternVarNames;
3434
llvm::SmallVectorImpl<mlir::Value>* CurrentPacks = nullptr;
3535

3636
public:
37-
using ResultTy = TemplateResult;
3837
using ErrorTy = TemplateError;
3938

4039
TemplateGen(heavy::OpGen& O, NameSet& PVNames,
@@ -44,40 +43,18 @@ class TemplateGen : TemplateBase<TemplateGen>,
4443
PatternVarNames(PVNames)
4544
{ }
4645

47-
// VisitTemplate - Create operations to transform syntax and
48-
// evaluate it.
49-
void VisitTemplate(heavy::Value Template) {
46+
// Create operations to transform syntax and compile the result.
47+
void BuildTemplate(heavy::Value Template) {
5048
heavy::SourceLocation Loc = Template.getSourceLocation();
51-
ResultTy Result = Visit(Template);
52-
53-
mlir::Value TransformedSyntax;
54-
if (mlir::Value* MVP = std::get_if<mlir::Value>(&Result))
55-
TransformedSyntax = *MVP;
56-
else
57-
TransformedSyntax = createLiteral(std::get<heavy::Value>(Result));
58-
59-
{
60-
// Move the RenameOps to the end of the PatternOp.
61-
mlir::Block* Block = OpGen.Builder.getBlock();
62-
auto InsertionPoint = OpGen.Builder.getInsertionPoint();
63-
for (mlir::Value RV : RenameOps)
64-
RV.getDefiningOp()->moveBefore(Block, InsertionPoint);
65-
}
66-
67-
// Add the RenameOps to an EnvFrame
68-
// to serve as the base of the environment.
69-
mlir::Value TemplateEnv = OpGen.create<EnvFrameOp>(Loc, RenameOps);
70-
OpGen.createOpGen(Loc, TransformedSyntax, TemplateEnv);
49+
mlir::Value TransformedSyntax = VisitTemplate(Template);
50+
mlir::Value RenameEnv = createRenameEnv();
51+
OpGen.createOpGen(Loc, TransformedSyntax, RenameEnv);
7152
}
7253

7354
private:
74-
ResultTy VisitValue(Value P) {
75-
return P;
76-
}
77-
7855
mlir::Value ExpandPack(heavy::SourceLocation Loc,
7956
heavy::Value Car, heavy::Value Cdr) {
80-
ResultTy CdrResult = Visit(Cdr);
57+
TemplateResult CdrResult = Visit(Cdr);
8158
auto Body = std::make_unique<mlir::Region>();
8259
llvm::SmallVector<mlir::Value, 4> Packs;
8360
llvm::SmallVectorImpl<mlir::Value>* PrevPacks = CurrentPacks;
@@ -86,7 +63,7 @@ class TemplateGen : TemplateBase<TemplateGen>,
8663
{
8764
mlir::OpBuilder::InsertionGuard IG(OpGen.Builder);
8865
OpGen.Builder.setInsertionPointToStart(&Block);
89-
ResultTy Last = Visit(Car);
66+
TemplateResult Last = Visit(Car);
9067
if (OpGen.CheckError())
9168
return mlir::Value();
9269
if (std::holds_alternative<heavy::Value>(Last))
@@ -100,27 +77,15 @@ class TemplateGen : TemplateBase<TemplateGen>,
10077
return EPO.getResult();
10178
}
10279

103-
ResultTy VisitPair(Pair* P) {
80+
TemplateResult VisitPair(Pair* P) {
10481
heavy::SourceLocation Loc = P->getSourceLocation();
10582
if (auto* P2 = dyn_cast<Pair>(P->Cdr);
10683
P2 && isa<Symbol>(P2->Car) &&
10784
cast<Symbol>(P2->Car)->Equiv(Ellipsis)) {
10885
return ExpandPack(Loc, P->Car, P2->Cdr);
10986
}
11087

111-
ResultTy CarResult = Visit(P->Car);
112-
ResultTy CdrResult = Visit(P->Cdr);
113-
114-
// If nothing changed
115-
auto* HCar = std::get_if<heavy::Value>(&CarResult);
116-
auto* HCdr = std::get_if<heavy::Value>(&CdrResult);
117-
if (HCar && HCdr && *HCar == P->Car && *HCdr == P->Cdr)
118-
return P;
119-
120-
if (OpGen.CheckError())
121-
return mlir::Value();
122-
123-
return createCons(Loc, CarResult, CdrResult);
88+
return Base::VisitPair(P);
12489
}
12590

12691
mlir::Value GetPatternVar(heavy::Symbol* S) {
@@ -161,7 +126,7 @@ class TemplateGen : TemplateBase<TemplateGen>,
161126
return Block->addArgument(HeavyValueT, MLoc);
162127
}
163128

164-
ResultTy VisitSymbol(Symbol* P) {
129+
TemplateResult VisitSymbol(Symbol* P) {
165130
if (PatternVarNames.contains(P->getString()))
166131
return CurrentPacks ? CaptureExpandArg(P) : GetPatternVar(P);
167132

0 commit comments

Comments
 (0)