Skip to content

Commit 7a0022a

Browse files
committed
[Heavy] Improve OpEval scope mapping; Fix OpEval ownership
1 parent 62066a2 commit 7a0022a

File tree

3 files changed

+41
-64
lines changed

3 files changed

+41
-64
lines changed

heavy/include/heavy/Context.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ class Context;
5454
void compile(Context&, Value V, Value Env, Value Handler);
5555
void eval(Context&, Value V, Value Env);
5656
void write(llvm::raw_ostream&, Value);
57+
void opEval(mlir::Operation*);
5758

5859
class OpEvalImpl;
59-
void opEval(mlir::Operation*);
60+
struct OpEvalDeleter {
61+
void operator()(OpEvalImpl*) const;
62+
};
63+
using OpEvalPtr = std::unique_ptr<OpEvalImpl, OpEvalDeleter>;
6064

6165
class ContextLocalLookup {
6266
friend struct ContextLocal;
@@ -120,8 +124,7 @@ class Context : public ContinuationStack<Context>,
120124

121125
public:
122126
heavy::OpGen* OpGen = nullptr;
123-
// FIXME OpEval is not cleaned up or owned by anything.
124-
heavy::OpEvalImpl* OpEval = nullptr;
127+
heavy::OpEvalPtr OpEval;
125128

126129
// Work around DidCallContinuation being set with compiler errors.
127130
bool CheckError();

heavy/lib/Context.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "llvm/Support/raw_ostream.h"
4848
#include <algorithm>
4949
#include <cstring>
50+
#include <memory>
5051
#include <string>
5152

5253
using namespace heavy;
@@ -901,7 +902,7 @@ namespace {
901902

902903
// Append to Output the path to a module sans any file extension.
903904
void getModulePath(heavy::Context& C, llvm::StringRef MangledName,
904-
llvm::SmallVectorImpl<char>& Output) {
905+
llvm::SmallVectorImpl<char>& Output) {
905906
llvm::StringRef ModulePath =
906907
HEAVY_BASE_VAR(module_path).get(C).getStringRef();
907908
if (!ModulePath.empty())
@@ -1280,7 +1281,7 @@ void Context::RaiseError(String* Msg, llvm::ArrayRef<Value> IrrArgs) {
12801281
}
12811282

12821283
// ManagedObjectWind - Manage the lifetime of a C++ object within a dynamic
1283-
// extent via a provided type-erased desctructor.
1284+
// extent via a provided type-erased desctructor.
12841285
void Context::ManagedObjectWind(void* Ptr, DestructorTy Destructor,
12851286
Value Before, Value Thunk, Value After) {
12861287
// Sentinel is referenced by each lambda

heavy/lib/OpEval.cpp

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/ADT/DenseMap.h"
1818
#include "llvm/ADT/ScopedHashTable.h"
1919
#include "llvm/Support/Casting.h"
20+
#include <memory>
2021
#include <stack>
2122

2223
namespace heavy {
@@ -28,14 +29,23 @@ class OpEvalImpl {
2829

2930
heavy::Context& Context;
3031
ValueMapTy ValueMap;
31-
std::stack<ValueMapScope> ValueMapScopes;
3232

3333
void setValue(mlir::Value M, heavy::Value H) {
3434
assert(M && "must be set to a valid value");
3535
assert(H && "must be set to a valid value");
3636
ValueMap.insert(M, H);
3737
}
3838

39+
void setValueParentScope(mlir::Value M, heavy::Value H) {
40+
assert(M && "must be set to a valid value");
41+
assert(H && "must be set to a valid value");
42+
ValueMapScope* CurScope = ValueMap.getCurScope();
43+
assert(CurScope && "expecting a scope");
44+
ValueMapScope* ParentScope = CurScope->getParentScope();
45+
assert(ParentScope && "expecting parent scope");
46+
ValueMap.insertIntoScope(ParentScope, M, H);
47+
}
48+
3949
heavy::Value getBindingOrValue(mlir::Value M) {
4050
assert(M && "lookup requires a valid value");
4151
heavy::Value V = ValueMap.lookup(M);
@@ -76,22 +86,12 @@ class OpEvalImpl {
7686
public:
7787
OpEvalImpl(heavy::Context& C)
7888
: Context(C),
79-
ValueMap(),
80-
ValueMapScopes()
81-
{
82-
// there has to be at least one scope on the stack
83-
ValueMapScopes.emplace(ValueMap);
84-
}
89+
ValueMap()
90+
{ }
8591

8692
// Prevent copy/move since we capture this in a few lambdas.
8793
OpEvalImpl(OpEvalImpl const&) = delete;
8894

89-
~OpEvalImpl() {
90-
// Pop the scopes in order.
91-
while (!ValueMapScopes.empty())
92-
ValueMapScopes.pop();
93-
}
94-
9595
void Eval(mlir::Operation* Op) {
9696
if (isa<GlobalOp, CommandOp, LoadModuleOp>(Op)) {
9797
Visit(Op);
@@ -119,16 +119,6 @@ class OpEvalImpl {
119119
}
120120

121121
private:
122-
void push_scope() {
123-
ValueMapScopes.emplace(ValueMap);
124-
}
125-
126-
void pop_scope() {
127-
ValueMapScopes.pop();
128-
assert(ValueMapScopes.size() > 0 &&
129-
"scope stack must be balanced");
130-
}
131-
132122
BlockItrTy next(mlir::Operation* Op) {
133123
return ++BlockItrTy(Op);
134124
}
@@ -243,7 +233,7 @@ class OpEvalImpl {
243233
}
244234
}
245235

246-
push_scope();
236+
auto Scope = ValueMapScope(ValueMap);
247237
mlir::Block& Body = F.getBody().front();
248238
if (NumParams != NumParamsMax)
249239
LoadArgs(Body, Args);
@@ -253,7 +243,6 @@ class OpEvalImpl {
253243
BlockItrTy Itr = Body.begin();
254244
while (Itr != BlockItrTy())
255245
Itr = Visit(&*Itr);
256-
pop_scope();
257246

258247
// The terminator operation should have
259248
// called Context.Apply() or one of those
@@ -483,20 +472,20 @@ class OpEvalImpl {
483472
C.Cont(Undefined());
484473
}, ValueRefs());
485474

486-
push_scope();
475+
auto Scope = ValueMapScope(ValueMap);
487476
BlockItrTy Itr = Op.getInitializer().front().begin();
488477
while (Itr != BlockItrTy())
489478
Itr = Visit(&*Itr);
490-
pop_scope();
479+
491480
return BlockItrTy();
492481
}
493482

494483
BlockItrTy Visit(CommandOp Op) {
495-
push_scope();
484+
auto Scope = ValueMapScope(ValueMap);
496485
BlockItrTy Itr = Op.getBody().front().begin();
497486
while (Itr != BlockItrTy())
498487
Itr = Visit(&*Itr);
499-
pop_scope();
488+
500489
return BlockItrTy();
501490
}
502491

@@ -672,7 +661,7 @@ class OpEvalImpl {
672661
// Recall that each pack is a list in reverse order.
673662
while (isa<heavy::Pair>(Packs.front())) {
674663
// Set each block argument to the cdr of each pack.
675-
push_scope();
664+
auto Scope = ValueMapScope(ValueMap);
676665
for (unsigned i = 0; i < Packs.size(); i++) {
677666
auto* Pair = dyn_cast<heavy::Pair>(Packs[i]);
678667
if (!Pair)
@@ -683,7 +672,6 @@ class OpEvalImpl {
683672
BlockItrTy Itr = Op.getBody().front().begin();
684673
while (Itr != BlockItrTy())
685674
Itr = Visit(&*Itr);
686-
// ResolveOp calls pop_scope.
687675
}
688676

689677
for (unsigned i = 0; i < Packs.size(); i++) {
@@ -718,13 +706,12 @@ class OpEvalImpl {
718706
while (auto* Pair = dyn_cast<heavy::Pair>(E)) {
719707
if (Pair == Tail)
720708
break;
721-
push_scope();
709+
auto Scope = ValueMapScope(ValueMap);
722710
BlockItrTy Itr = Op.getBody().front().begin();
723711
setValue(BodyArg, Pair->Car);
724712
while (Itr != BlockItrTy())
725713
Itr = Visit(&*Itr);
726714
E = Pair->Cdr;
727-
// ResolveOp calls pop_scope.
728715
}
729716

730717
return next(Op);
@@ -736,21 +723,14 @@ class OpEvalImpl {
736723
assert(ParentOp.getPacks().size() ==
737724
Op.getArgs().size() &&
738725
"expecting value ranges of equal size");
739-
// Store packs on the stack until we can pop the scope.
740-
llvm::SmallVector<heavy::Value, 4> Packs(
741-
ParentOp.getPacks().size(), nullptr);
742726
mlir::ValueRange ResolveArgs = Op.getArgs();
743727
mlir::ResultRange MResults = ParentOp.getPacks();
744728
// Construct each pack as a reversed ordered list.
745-
for (unsigned i = 0; i < Packs.size(); i++)
746-
Packs[i] = Context.CreatePair(getValue(ResolveArgs[i]),
747-
getValue(MResults[i]));
748-
749-
// Set the results in the parent scope.
750-
pop_scope();
751-
for (unsigned i = 0; i < Packs.size(); i++)
752-
setValue(MResults[i], Packs[i]);
753-
729+
for (unsigned i = 0; i < ResolveArgs.size(); i++) {
730+
Value Pack = Context.CreatePair(getValue(ResolveArgs[i]),
731+
getValue(MResults[i]));
732+
setValueParentScope(MResults[i], Pack);
733+
}
754734
return BlockItrTy();
755735
}
756736

@@ -759,11 +739,7 @@ class OpEvalImpl {
759739
heavy::Value CurrentResult = getValue(Op.getArgs().front());
760740
heavy::Value Result = Context.CreatePair(CurrentResult,
761741
getValue(ExpandPacksOp.getResult()));
762-
763-
// Set the result in the parent scope.
764-
pop_scope();
765-
setValue(ExpandPacksOp.getResult(), Result);
766-
742+
setValueParentScope(ExpandPacksOp.getResult(), Result);
767743
return BlockItrTy();
768744
}
769745

@@ -829,6 +805,10 @@ class OpEvalImpl {
829805
}
830806
};
831807

808+
void OpEvalDeleter::operator()(OpEvalImpl* Ptr) const {
809+
delete Ptr;
810+
}
811+
832812
namespace base {
833813
void op_eval(Context& C, ValueRefs Args) {
834814
if (Args.size() != 1) {
@@ -839,16 +819,9 @@ void op_eval(Context& C, ValueRefs Args) {
839819
return C.RaiseError("expecting operation");
840820
}
841821

842-
// TODO Make the module own the OpEval instance. (maybe?)
843822
if (!C.OpEval) {
844-
C.OpEval = new OpEvalImpl(C);
845-
C.PushModuleCleanup(HEAVY_BASE_LIB_STR,
846-
C.CreateLambda([](Context& C, ValueRefs) {
847-
llvm_unreachable("FIXME find out why we do not get here");
848-
assert(C.OpEval && "OpEval should be set.");
849-
// Cleanup the OpEval object.
850-
delete C.OpEval;
851-
}, CaptureList{}));
823+
// Still easier than extern template specialization of std::default_delete
824+
C.OpEval = OpEvalPtr(new OpEvalImpl(C), OpEvalDeleter());
852825
}
853826

854827
C.OpEval->Eval(Op);

0 commit comments

Comments
 (0)