Skip to content

Commit b6e58ab

Browse files
committed
[Heavy] Add type checking; Check for procedures before function application
1 parent a44e089 commit b6e58ab

File tree

6 files changed

+63
-2
lines changed

6 files changed

+63
-2
lines changed

heavy/include/heavy/Dialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ Heavy##NAME##Ty, mlir::Type, mlir::TypeStorage> { \
125125
} \
126126

127127
HEAVY_TYPE(Pair, "heavy.pair", "pair");
128+
HEAVY_TYPE(Procedure, "heavy.procedure", "procedure");
128129

129130
HEAVY_TYPE(Syntax, "heavy.syntax", "syntax");
130131
HEAVY_TYPE(OpGen, "heavy.opgen", "opgen");

heavy/include/heavy/OpGen.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ class OpGen : public ValueVisitor<OpGen, mlir::Value> {
184184

185185

186186
mlir::ModuleOp getModuleOp();
187+
188+
mlir::Value CheckType(heavy::SourceLocation Loc, mlir::Value V,
189+
mlir::Type Type);
187190

188191
// GetSingleResult
189192
// - visits a node expecting a single result

heavy/include/heavy/Ops.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def HeavyValueRefs : HeavyValueBase<"ValueRefs">; // For arguments array.
4545
def HeavyRest : HeavyValueBase<"Rest">; // For rest arguments.
4646
def HeavyPair : HeavyValueBase<"Pair">;
4747
def HeavySyntax : HeavyValueBase<"Syntax">;
48+
def HeavyProcedure : HeavyValueBase<"Procedure">;
49+
50+
def HeavyType : AnyTypeOf<[HeavyValue,
51+
HeavyPair,
52+
HeavyProcedure]>;
4853

4954
// If !heavy.value_refs is a function argument, the arguments should
5055
// only be (!heavy.context, !heavy.value_refs).
@@ -67,7 +72,7 @@ def heavy_ApplyOp : HeavyOp<"apply", [Terminator]> {
6772
of the current call arguments with arguments in that position.
6873
}];
6974

70-
let arguments = (ins HeavyValue:$fn,
75+
let arguments = (ins AnyTypeOf<[HeavyValue, HeavyProcedure]>:$fn,
7176
Variadic<AnyTypeOf<[HeavyValue, HeavyRest, HeavyValueRefs]>>:$args);
7277
let results = (outs);
7378
}
@@ -437,6 +442,17 @@ def heavy_MatchArgsOp : HeavyOp<"match_args"> {
437442
];
438443
}
439444

445+
def heavy_MatchTypeOp : HeavyOp<"match_type", []> {
446+
let description = [{
447+
Perform dynamic cast resulting in a value with a specified type.
448+
Like other Match operations this will raise an error unless it
449+
appears in the body of a PatternOp.
450+
}];
451+
452+
let arguments = (ins HeavyValue:$arg);
453+
let results = (outs HeavyType:$result);
454+
}
455+
440456
def heavy_SubpatternOp : HeavyOp<"subpattern", []> {
441457
let description = [{
442458
Match a list of elements over a subpattern that

heavy/lib/Dialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Dialect::Dialect(mlir::MLIRContext* Ctx)
3131
addTypes<HeavyValueRefsTy>();
3232
addTypes<HeavyRestTy>();
3333
addTypes<HeavyPairTy>();
34+
addTypes<HeavyProcedureTy>();
3435

3536
addTypes<HeavySyntaxTy>();
3637
addTypes<HeavyOpGenTy>();
@@ -97,6 +98,8 @@ void Dialect::printType(mlir::Type Type,
9798
Name = "rest";
9899
} else if (mlir::isa<HeavyPairTy>(Type)) {
99100
Name = "pair";
101+
} else if (mlir::isa<HeavyProcedureTy>(Type)) {
102+
Name = "procedure";
100103
} else if (mlir::isa<HeavySyntaxTy>(Type)) {
101104
Name = "syntax";
102105
} else if (mlir::isa<HeavyOpGenTy>(Type)) {

heavy/lib/OpEval.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class OpEvalImpl {
198198
else if (isa<MatchPairOp>(Op)) return Visit(cast<MatchPairOp>(Op));
199199
else if (isa<MatchTailOp>(Op)) return Visit(cast<MatchTailOp>(Op));
200200
else if (isa<MatchArgsOp>(Op)) return Visit(cast<MatchArgsOp>(Op));
201+
else if (isa<MatchTypeOp>(Op)) return Visit(cast<MatchTypeOp>(Op));
201202
else if (isa<SubpatternOp>(Op)) return Visit(cast<SubpatternOp>(Op));
202203
else if (isa<ExpandPacksOp>(Op)) return Visit(cast<ExpandPacksOp>(Op));
203204
else if (isa<ResolveOp>(Op)) return Visit(cast<ResolveOp>(Op));
@@ -656,7 +657,7 @@ class OpEvalImpl {
656657
BlockItrTy patternFail(mlir::Operation* Op, llvm::StringRef ErrMsg,
657658
llvm::ArrayRef<heavy::Value> Irr) {
658659
assert((isa<MatchOp, MatchPairOp, MatchTailOp, MatchArgsOp,
659-
MatchIdOp, SubpatternOp>(Op)) &&
660+
MatchTypeOp, MatchIdOp, SubpatternOp>(Op)) &&
660661
"Operation must be a pattern matcher");
661662

662663
mlir::Operation* ParentOp = Op->getParentOp();
@@ -775,6 +776,34 @@ class OpEvalImpl {
775776
return next(Op);
776777
}
777778

779+
BlockItrTy Visit(MatchTypeOp Op) {
780+
mlir::Type Type = Op.getResult().getType();
781+
heavy::Value Arg = getValue(Op.getArg());
782+
bool Result;
783+
switch (Arg.getKind()) {
784+
case ValueKind::Lambda:
785+
case ValueKind::Builtin:
786+
Result = isa<HeavyProcedureTy>(Type);
787+
break;
788+
case ValueKind::Pair:
789+
Result = isa<HeavyPairTy>(Type);
790+
break;
791+
default:
792+
Result = false;
793+
}
794+
795+
if (!Result) {
796+
std::string Str;
797+
llvm::raw_string_ostream Stream(Str);
798+
Type.print(Stream);
799+
Symbol* TypeStr = Context.CreateSymbol(Str);
800+
return patternFail(Op, "expecting object with type {}: {}",
801+
{TypeStr, Arg});
802+
}
803+
setValue(Op.getResult(), Arg);
804+
return next(Op);
805+
}
806+
778807
BlockItrTy Visit(ExpandPacksOp Op) {
779808
heavy::SourceLocation Loc = getSourceLocation(Op.getLoc());
780809
setValue(Op.getResult(), getValue(Op.getCdr()));

heavy/lib/OpGen.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ mlir::Value OpGen::GetSingleResult(heavy::Value V) {
185185
return LocalizeValue(Result);
186186
}
187187

188+
// Insert a check to assert the type of a value.
189+
mlir::Value OpGen::CheckType(heavy::SourceLocation Loc, mlir::Value V,
190+
mlir::Type Type) {
191+
return create<MatchTypeOp>(Loc, Type, V);
192+
}
193+
188194
// WithLibraryEnv - Call a thunk within the library environment
189195
// for <library spec>. (ie begin, import, export)
190196
void OpGen::WithLibraryEnv(Value Thunk) {
@@ -1259,6 +1265,9 @@ mlir::Value OpGen::HandleCall(Pair* P, heavy::EnvEntry FnEnvEntry) {
12591265
else
12601266
Fn = GetSingleResult(P->Car);
12611267

1268+
// Insert the check that Fn is a procedure.
1269+
Fn = CheckType(Loc, Fn, Builder.getType<HeavyProcedureTy>());
1270+
12621271
if (CheckError())
12631272
return Error();
12641273

0 commit comments

Comments
 (0)