Skip to content

Commit 36fe13b

Browse files
committed
[Heavy] Support multivariate continuations. Fix tail calls.
1 parent a2df047 commit 36fe13b

File tree

9 files changed

+213
-81
lines changed

9 files changed

+213
-81
lines changed

heavy/include/heavy/ContinuationStack.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class ContinuationStack {
135135
DidCallContinuation = true; // debug mode only
136136
if (isa<Undefined>(Callee))
137137
return getDerived().RaiseError("callee is undefined");
138-
if (Args.data() != ApplyArgs.data()) {
138+
if (Args.data() != ValueRefs(ApplyArgs).drop_front().data()) {
139139
ApplyArgs.resize(Args.size() + 1);
140140
std::copy(Args.begin(), Args.end(), ApplyArgs.begin() + 1);
141141
}
@@ -194,7 +194,10 @@ class ContinuationStack {
194194
assert(ApplyArgs[0] && "callee must not be null");
195195
return ApplyArgs[0];
196196
}
197-
ValueRefs getCaptures() {
197+
heavy::ValueRefs getCallArgs() {
198+
return heavy::ValueRefs(ApplyArgs).drop_front();
199+
}
200+
heavy::ValueRefs getCaptures() {
198201
return cast<Lambda>(ApplyArgs[0])->getCaptures();
199202
}
200203
heavy::Value getCapture(unsigned I) {
@@ -243,7 +246,7 @@ class ContinuationStack {
243246
// Debug mode only
244247
DidCallContinuation = false;
245248

246-
ValueRefs Args = ValueRefs(ApplyArgs).drop_front();
249+
ValueRefs Args = getCallArgs();
247250
switch (Callee.getKind()) {
248251
case ValueKind::Lambda: {
249252
Lambda* L = cast<Lambda>(Callee);
@@ -302,10 +305,6 @@ class ContinuationStack {
302305
void PushCont(heavy::Value Callable) {
303306
PushCont([](Derived& Context, ValueRefs Args) {
304307
heavy::Value Callable = Context.getCapture(0);
305-
// FIXME Remove this drop_front when is certain
306-
// that it was vestigal
307-
// (we used to include the callee in Args)
308-
//Context.Apply(Callable, Args.drop_front());
309308
Context.Apply(Callable, Args);
310309
}, Callable);
311310
}
@@ -381,21 +380,21 @@ class ContinuationStack {
381380
C.Apply(InputProc, Proc);
382381
}
383382

384-
// SaveEscapeProc - Push Proc as current continuation for use
383+
// SaveEscapeProc - Push Proc as current continuation for use
385384
// as an escape procedure bound to Var, then
386385
// call the continuation we had before so
387386
// that Proc is only called explicitly.
388387
template <typename F>
389388
void SaveEscapeProc(Value Var, F Proc, CaptureList Captures) {
390-
assert(isa<Binding>(Var) && "expecting a binding for Var");
389+
assert(isa<Binding>(Var) && "expecting a binding for Var");
391390
Derived& C = getDerived();
392391
PushCont(Proc, Captures);
393392
CallCC(C.CreateLambda([](Derived& C, ValueRefs Args) {
394393
cast<Binding>(C.getCapture(0))->setValue(Args[0]);
395394
// Remove Proc from the stack before continuing.
396395
C.PopCont();
397396
C.Cont();
398-
}, CaptureList{Var}));
397+
}, CaptureList{Var}));
399398
}
400399

401400
void DynamicWind(Value Before, Value Thunk, Value After) {
@@ -422,7 +421,7 @@ class ContinuationStack {
422421
C.PushCont([](Derived& C, ValueRefs) {
423422
ValueRefs ThunkResults = C.getCaptures();
424423
C.Cont(ThunkResults);
425-
}, /*Captures=*/ThunkResults);
424+
}, /*Captures=*/ThunkResults);
426425
C.CurDW = PrevDW;
427426
C.Apply(After, {});
428427
}, CaptureList{After, PrevDW});

heavy/include/heavy/Dialect.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ struct Dialect : public mlir::Dialect {
5151
void printType(mlir::Type, mlir::DialectAsmPrinter&) const override;
5252
};
5353

54+
struct HeavyContextTy : public mlir::Type::TypeBase<
55+
HeavyContextTy,
56+
mlir::Type,
57+
mlir::TypeStorage> {
58+
static constexpr llvm::StringLiteral name = "heavy.context";
59+
static constexpr llvm::StringRef getMnemonic() { return "context"; }
60+
using Base::Base;
61+
};
62+
5463
struct HeavyValueTy : public mlir::Type::TypeBase<
5564
HeavyValueTy,
5665
mlir::Type,
@@ -60,6 +69,19 @@ struct HeavyValueTy : public mlir::Type::TypeBase<
6069
using Base::Base;
6170
};
6271

72+
// Represent a variadiac array of arguments for use with
73+
// continuation arguments.
74+
struct HeavyValueRefsTy : public mlir::Type::TypeBase<
75+
HeavyValueRefsTy,
76+
mlir::Type,
77+
mlir::TypeStorage> {
78+
static constexpr llvm::StringLiteral name = "heavy.value_refs";
79+
static constexpr llvm::StringRef getMnemonic() { return "value_refs"; }
80+
using Base::Base;
81+
};
82+
83+
// Represent a variadic list of arguments which are provided
84+
// as a scheme linked list.
6385
struct HeavyRestTy : public mlir::Type::TypeBase<
6486
HeavyRestTy,
6587
mlir::Type,

heavy/include/heavy/OpGen.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
257257
return heavy::Value(OpResult.getOwner());
258258
}
259259
if (auto BlockArg = mlir::dyn_cast<mlir::BlockArgument>(V)) {
260+
// FIXME Do we even get here? Is ContArg ever needed?
260261
mlir::Block* B = BlockArg.getOwner();
261262
return heavy::Value(reinterpret_cast<heavy::ContArg*>(B));
262263
}
@@ -276,8 +277,14 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
276277
Value Else);
277278
mlir::Value createContinuation(mlir::Operation* CallOp);
278279

280+
enum class RestParamKind {
281+
None = 0,
282+
List,
283+
ValueRefs,
284+
};
285+
279286
mlir::FunctionType createFunctionType(unsigned Arity,
280-
bool HasRestParam);
287+
RestParamKind RPK);
281288
heavy::FuncOp createFunction(SourceLocation Loc,
282289
llvm::StringRef MangledName,
283290
mlir::FunctionType FT);

heavy/include/heavy/Ops.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def HeavyValueAttr : AttrDef<HeavyDialect, "HeavyValue"> {
4040
}
4141

4242
// Additional Types
43-
def HeavyRest : HeavyValueBase<"Pair">; // For rest arguments.
43+
def HeavyContext : HeavyValueBase<"Context">;
44+
def HeavyValueRefs : HeavyValueBase<"ValueRefs">; // For arguments array.
45+
def HeavyRest : HeavyValueBase<"Rest">; // For rest arguments.
4446
def HeavyPair : HeavyValueBase<"Pair">;
4547
def HeavySyntax : HeavyValueBase<"Syntax">;
4648

@@ -59,7 +61,8 @@ def heavy_ApplyOp : HeavyOp<"apply", [ReturnLike, Terminator]> {
5961
a continuation.
6062
}];
6163

62-
let arguments = (ins HeavyValue:$fn, Variadic<HeavyValue>:$args);
64+
let arguments = (ins HeavyValue:$fn,
65+
Variadic<AnyTypeOf<[HeavyValue, HeavyValueRefs]>>:$args);
6366
let results = (outs);
6467
}
6568

@@ -266,13 +269,16 @@ def heavy_LiteralOp : HeavyOp<"literal", [Pure]> {
266269
];
267270
}
268271

269-
def heavy_LoadClosureOp : HeavyOp<"load_closure"> {
270-
let summary = "load_closure";
272+
def heavy_LoadRefOp : HeavyOp<"load_ref"> {
273+
let summary = "load_ref";
271274
let description = [{
272-
Read an element from lambda's capture list.
275+
Read an element from a ValueRefs object which can be
276+
the context's callee capture list or a variadic argument
277+
list which is primarily used for continuations.
273278
}];
274279

275-
let arguments = (ins HeavyValue:$closure, UI32Attr:$index);
280+
let arguments = (ins AnyTypeOf<[HeavyContext, HeavyValueRefs]>:$value_refs,
281+
UI32Attr:$index);
276282
let results = (outs HeavyValue:$result);
277283

278284
let builders = [

heavy/lib/Builtins.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ heavy::ExternFunction eq;
6767
heavy::ExternFunction equal;
6868
heavy::ExternFunction eqv;
6969
heavy::ExternFunction call_cc;
70+
heavy::ExternFunction values;
71+
heavy::ExternFunction call_with_values;
7072
heavy::ExternFunction with_exception_handler;
7173
heavy::ExternFunction raise;
7274
heavy::ExternFunction error;
@@ -385,6 +387,19 @@ void call_cc(Context& C, ValueRefs Args) {
385387
C.CallCC(Args[0]);
386388
}
387389

390+
void values(Context& C, ValueRefs Args) {
391+
C.Cont(Args);
392+
}
393+
394+
void call_with_values(Context& C, ValueRefs Args) {
395+
if (Args.size() != 2)
396+
return C.RaiseError("invalid arity");
397+
heavy::Value Producer = Args[0];
398+
heavy::Value Consumer = Args[1];
399+
C.PushCont(Consumer);
400+
C.Apply(Producer, {});
401+
}
402+
388403
void dump(Context& C, ValueRefs Args) {
389404
if (Args.size() != 1) return C.RaiseError("invalid arity");
390405
Args[0].dump();
@@ -888,6 +903,8 @@ void HEAVY_BASE_INIT(heavy::Context& Context) {
888903
HEAVY_BASE_VAR(equal) = heavy::base::equal;
889904
HEAVY_BASE_VAR(eqv) = heavy::base::eqv;
890905
HEAVY_BASE_VAR(call_cc) = heavy::base::call_cc;
906+
HEAVY_BASE_VAR(values) = heavy::base::values;
907+
HEAVY_BASE_VAR(call_with_values) = heavy::base::call_with_values;
891908
HEAVY_BASE_VAR(with_exception_handler)
892909
= heavy::base::with_exception_handler;
893910
HEAVY_BASE_VAR(raise) = heavy::base::raise;
@@ -965,6 +982,8 @@ void HEAVY_BASE_LOAD_MODULE(heavy::Context& Context) {
965982
{"equal?", HEAVY_BASE_VAR(equal)},
966983
{"eqv?", HEAVY_BASE_VAR(eqv)},
967984
{"call/cc", HEAVY_BASE_VAR(call_cc)},
985+
{"values", HEAVY_BASE_VAR(values)},
986+
{"call-with-values", HEAVY_BASE_VAR(call_with_values)},
968987
{"with-exception-handler", HEAVY_BASE_VAR(with_exception_handler)},
969988
{"raise", HEAVY_BASE_VAR(raise)},
970989
{"error", HEAVY_BASE_VAR(error)},

heavy/lib/Dialect.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Dialect::Dialect(mlir::MLIRContext* Ctx)
2727
addTypes<HeavyValueTy>();
2828
addAttributes<HeavyValueAttr>();
2929

30+
addTypes<HeavyContextTy>();
31+
addTypes<HeavyValueRefsTy>();
3032
addTypes<HeavyRestTy>();
3133
addTypes<HeavyPairTy>();
3234

@@ -58,6 +60,12 @@ mlir::Type Dialect::parseType(mlir::DialectAsmParser& P) const {
5860
mlir::Builder B = P.getBuilder();
5961
if (Name == HeavyValueTy::getMnemonic())
6062
return B.getType<HeavyValueTy>();
63+
if (Name == HeavyContextTy::getMnemonic())
64+
return B.getType<HeavyContextTy>();
65+
if (Name == HeavyValueRefsTy::getMnemonic())
66+
return B.getType<HeavyValueRefsTy>();
67+
if (Name == HeavyRestTy::getMnemonic())
68+
return B.getType<HeavyRestTy>();
6169
if (Name == HeavySyntaxTy::getMnemonic())
6270
return B.getType<HeavySyntaxTy>();
6371
if (Name == HeavyPairTy::getMnemonic())
@@ -81,6 +89,10 @@ void Dialect::printType(mlir::Type Type,
8189
char const* Name;
8290
if (mlir::isa<HeavyValueTy>(Type)) {
8391
Name = "value";
92+
} else if (mlir::isa<HeavyContextTy>(Type)) {
93+
Name = "context";
94+
} else if (mlir::isa<HeavyValueRefsTy>(Type)) {
95+
Name = "value_refs";
8496
} else if (mlir::isa<HeavyRestTy>(Type)) {
8597
Name = "rest";
8698
} else if (mlir::isa<HeavyPairTy>(Type)) {
@@ -147,10 +159,10 @@ void LiteralOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
147159
HeavyValueAttr::get(B.getContext(), V));
148160
}
149161

150-
void LoadClosureOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
151-
mlir::Value Closure, uint32_t Index) {
162+
void LoadRefOp::build(mlir::OpBuilder& B, mlir::OperationState& OpState,
163+
mlir::Value ValueRefs, uint32_t Index) {
152164
mlir::Type HeavyValueT = B.getType<HeavyValueTy>();
153-
LoadClosureOp::build(B, OpState, HeavyValueT, Closure,
165+
LoadRefOp::build(B, OpState, HeavyValueT, ValueRefs,
154166
B.getUI32IntegerAttr(Index));
155167
}
156168

0 commit comments

Comments
 (0)