Skip to content

Commit 11f67f7

Browse files
committed
[Heavy] Fix defines with syntax closures
1 parent 165ffc3 commit 11f67f7

File tree

6 files changed

+73
-38
lines changed

6 files changed

+73
-38
lines changed

heavy/include/heavy/Context.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,7 @@ class Context : public ContinuationStack<Context>,
391391
}
392392

393393
SyntaxClosure* CreateSyntaxClosure(SourceLocation Loc, Value Node,
394-
Value Env) {
395-
return new (*this) SyntaxClosure(Loc, Env, Node);
396-
}
394+
Value Env);
397395

398396
SourceValue* CreateSourceValue(SourceLocation Loc) {
399397
return new (*this) SourceValue(Loc);

heavy/include/heavy/OpGen.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
300300

301301
mlir::Value createGlobal(SourceLocation Loc, llvm::StringRef MangledName);
302302
mlir::Value createBinding(Binding *B, mlir::Value Init);
303-
mlir::Value createDefine(Symbol* S, Value Args, Value OrigCall);
304-
mlir::Value createTopLevelDefine(Symbol* S, Value Args, Value OrigCall);
303+
mlir::Value createDefine(Value Id, Value Args, Value OrigCall);
304+
mlir::Value createTopLevelDefine(Value Id, Value Args, Value OrigCall);
305305
mlir::Value createUndefined();
306306
mlir::Value createSet(SourceLocation Loc, Value LHS, Value RHS);
307307

heavy/lib/Builtins.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,30 +113,33 @@ void op_eval(Context& C, ValueRefs Args);
113113

114114
mlir::Value define(OpGen& OG, Pair* P) {
115115
Pair* P2 = dyn_cast<Pair>(P->Cdr);
116-
Symbol* S = nullptr;
116+
Value Id = nullptr;
117117

118118
if (!P2)
119119
return OG.SetError("invalid syntax for define", P);
120120
if (Pair* LambdaSpec = dyn_cast<Pair>(P2->Car))
121-
S = dyn_cast<Symbol>(LambdaSpec->Car);
121+
Id = LambdaSpec->Car;
122122
else
123-
S = dyn_cast<Symbol>(P2->Car);
123+
Id = P2->Car;
124124

125-
if (!S)
125+
if (!isIdentifier(Id))
126126
return OG.SetError("invalid syntax for define", P);
127-
return OG.createDefine(S, P2, P);
127+
return OG.createDefine(Id, P2, P);
128128
}
129129

130130
mlir::Value define_syntax(OpGen& OG, Pair* P) {
131131
Pair* P2 = dyn_cast<Pair>(P->Cdr);
132132
if (!P2) return OG.SetError("invalid define-syntax syntax", P);
133-
Symbol* S = dyn_cast<Symbol>(P2->Car);
134-
if (!S) return OG.SetError("expecting name for define-syntax", P);
133+
Value Id = P2->Car;
134+
if (!isIdentifier(Id))
135+
return OG.SetError("expecting name for define-syntax", P);
135136

136137
return OG.createSyntaxSpec(P2, P);
137138
}
138139

139140
mlir::Value syntax_rules(OpGen& OG, Pair* P) {
141+
// TODO Support SyntaxClosures.
142+
140143
// The input is the <Syntax Spec> (Keyword (syntax-rules ...))
141144
// <Syntax Spec> has its own checks in createSyntaxSpec
142145
Symbol* Keyword = dyn_cast<Symbol>(P->Car);

heavy/lib/Context.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,6 +1628,14 @@ EnvFrame* Context::CreateEnvFrame(llvm::ArrayRef<Value> Ids) {
16281628
return E;
16291629
}
16301630

1631+
SyntaxClosure* Context::CreateSyntaxClosure(SourceLocation Loc, Value Node,
1632+
Value Env) {
1633+
// Prevent direct nesting of SyntaxClosures.
1634+
if (auto* SC = dyn_cast<SyntaxClosure>(Node))
1635+
return SC;
1636+
return new (*this) SyntaxClosure(Loc, Env, Node);
1637+
}
1638+
16311639
bool Context::OutputModule(llvm::StringRef MangledName,
16321640
llvm::StringRef ModulePath) {
16331641
#if 0

heavy/lib/OpGen.cpp

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -771,32 +771,44 @@ mlir::Value OpGen::createBinding(Binding *B, mlir::Value Init) {
771771
return BVal;
772772
}
773773

774-
mlir::Value OpGen::createDefine(Symbol* S, Value DefineArgs,
774+
mlir::Value OpGen::createDefine(Value Id, Value DefineArgs,
775775
Value OrigCall) {
776-
if (isTopLevel()) return createTopLevelDefine(S, DefineArgs, OrigCall);
776+
if (isTopLevel()) return createTopLevelDefine(Id, DefineArgs, OrigCall);
777777
if (!IsLocalDefineAllowed) return SetError("unexpected define", OrigCall);
778778
// Create the binding with a lazy init.
779779
// (Include everything after the define
780780
// keyword to visit it later because it could
781781
// be a terse lambda syntax.)
782-
Binding* B = Context.CreateBinding(S, DefineArgs);
782+
Binding* B = Context.CreateBinding(Id, DefineArgs);
783783
// Push to the local environment.
784784
Context.PushLocalBinding(B);
785785
mlir::Value BVal = createBinding(B, mlir::Value());
786786
assert(IsLocalDefineAllowed && "define should still be allowed");
787787
return BVal;
788788
}
789789

790-
mlir::Value OpGen::createTopLevelDefine(Symbol* S, Value DefineArgs,
790+
mlir::Value OpGen::createTopLevelDefine(Value Id, Value DefineArgs,
791791
Value OrigCall) {
792792
SourceLocation DefineLoc = OrigCall.getSourceLocation();
793793
if (LibraryEnvProc) {
794794
return SetError("unexpected define", OrigCall);
795795
}
796796

797+
assert(isTopLevel() && "expecting top level");
798+
Environment* Env = nullptr;
799+
if (auto* SC = dyn_cast<SyntaxClosure>(Id))
800+
Env = cast<Environment>(SC->Env);
801+
else if (CurSyntaxClosure)
802+
Env = cast<Environment>(CurSyntaxClosure->Env);
803+
else
804+
Env = cast<Environment>(Context.EnvStack);
805+
806+
// Unwrap SyntaxClosures.
807+
assert(isIdentifier(Id) && "expecting identifier");
808+
Symbol* S = cast<Symbol>(Context.RebuildLiteral(Id));
809+
797810
// If EnvStack isn't an Environment then there is local
798811
// scope information on top of it
799-
Environment* Env = cast<Environment>(Context.EnvStack);
800812

801813
EnvEntry Entry = Env->Lookup(Context, S);
802814
if (Entry.Value && Entry.MangledName) {
@@ -1336,32 +1348,30 @@ void OpGen::Export(Value NameList) {
13361348
return Context.Cont();
13371349
}
13381350

1339-
// Expect a Symbol or SyntaxClosure wrapping a Symbol.
13401351
heavy::EnvEntry OpGen::LookupEnv(heavy::Value Id) {
1352+
assert(isIdentifier(Id) && "expecting an identifier");
1353+
// For any given SyntaxClosure, use the object as the lookup
1354+
// in the current environment, and then use the raw Symbol
1355+
// in the closed environment.
13411356
heavy::EnvEntry Result;
1342-
if (CurSyntaxClosure) {
1343-
SyntaxClosure StackSC;
1344-
SyntaxClosure* SC = nullptr;
1345-
// Id could be a Symbol or a closed (wrapped) Symbol.
1346-
Symbol* S = dyn_cast<Symbol>(Id);
1347-
if (S) {
1348-
StackSC.Env = CurSyntaxClosure->Env;
1349-
StackSC.Node = S;
1350-
SC = &StackSC;
1351-
} else {
1352-
SC = cast<SyntaxClosure>(Id);
1353-
S = cast<Symbol>(SC->Node);
1354-
}
1357+
heavy::Value ClosedEnv;
1358+
if (auto* SC = dyn_cast<SyntaxClosure>(Id)) {
1359+
ClosedEnv = SC->Env;
1360+
Id = SC->Node;
1361+
} else if (CurSyntaxClosure) {
1362+
ClosedEnv = CurSyntaxClosure->Env;
1363+
}
13551364

1356-
// Invalid input
1357-
assert((SC && S) && "expecting identifier");
1365+
assert(isa<Symbol>(Id) && "syntax closure should be unwrapped");
13581366

1359-
// Perform lookup using the SyntaxClosure object in the
1360-
// primary EnvStack. Then use the raw symbol and look
1361-
// in the closed environment.
1362-
Result = Context.Lookup(SC);
1367+
if (ClosedEnv) {
1368+
SyntaxClosure StackSC;
1369+
Symbol* S = cast<Symbol>(Id);
1370+
StackSC.Env = ClosedEnv;
1371+
StackSC.Node = S;
1372+
Result = Context.Lookup(&StackSC);
13631373
if (!Result)
1364-
Result = Context.Lookup(S, SC->Env);
1374+
Result = Context.Lookup(S, ClosedEnv);
13651375
} else {
13661376
Symbol* S = cast<Symbol>(Id);
13671377
Result = Context.Lookup(S);

heavy/test/Evaluate/define-syntax.scm

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,19 @@
7474
'(0 1 i ... 9))))
7575
(write (ez 0 1 2 3 4 5 6))
7676
(newline)
77+
78+
(define-syntax my-define
79+
(syntax-rules ()
80+
((my-define name x)
81+
(define name '(my name x)))))
82+
83+
; CHECK: (my my-tl 42)
84+
(my-define my-tl 42)
85+
(write my-tl)
86+
(newline)
87+
88+
; CHECK: Undefined
89+
((lambda ()
90+
(my-define not-my-local 12)
91+
(write not-my-local)
92+
(newline)))

0 commit comments

Comments
 (0)