From fe964e2026448f9de12f348c0377f01e7c991214 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 22 Jan 2025 15:41:40 -0500 Subject: [PATCH 01/88] Another IR, maybe for real this time --- ir.py | 241 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 ir.py diff --git a/ir.py b/ir.py new file mode 100644 index 00000000..3cc91ac2 --- /dev/null +++ b/ir.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +from __future__ import annotations +import dataclasses +import io +import itertools +import json +import os +import typing +import unittest + +from typing import Dict, Optional, Tuple + +from scrapscript import ( + Access, + Apply, + Assign, + Binop, + BinopKind, + Function, + Hole, + Int, + List, + MatchFunction, + Object, + Record, + Spread, + String, + Var, + Variant, + Where, + free_in, + type_of, + IntType, + StringType, + parse, + tokenize, +) + +@dataclasses.dataclass +class Instr: + pass + +@dataclasses.dataclass +class Const(Instr): + value: Object + +@dataclasses.dataclass +class Param(Instr): + idx: int + name: str + +@dataclasses.dataclass +class MatchFail(Instr): + pass + +@dataclasses.dataclass +class HasOperands(Instr): + operands: list[Instr] = dataclasses.field(init=False, default_factory=list) + + def __init__(self, *operands: Instr) -> None: + self.operands = list(operands) + +@dataclasses.dataclass(init=False) +class IntAdd(HasOperands): + pass + +@dataclasses.dataclass(init=False) +class IntLess(HasOperands): + pass + +@dataclasses.dataclass(init=False) +class IsNumEqualWord(HasOperands): + expected: int + + def __init__(self, value: Instr, expected: int) -> None: + self.operands = [value] + self.expected = expected + +@dataclasses.dataclass +class Control(Instr): + pass + +Env = Dict[str, Instr] + +@dataclasses.dataclass +class Block: + id: int + instrs: list[Instr] = dataclasses.field(init=False, default_factory=list) + + def append(self, instr: Instr) -> None: + self.instrs.append(instr) + +@dataclasses.dataclass +class Jump(Control): + target: Block + +@dataclasses.dataclass(init=False) +class Return(HasOperands, Control): + pass + +@dataclasses.dataclass(init=False) +class CondBranch(Control, HasOperands): + conseq: Block + alt: Block + + def __init__(self, cond: Instr, conseq: Block, alt: Block) -> None: + self.conseq = conseq + self.alt = alt + self.operands = [cond] + +@dataclasses.dataclass +class CFG: + blocks: list[Block] = dataclasses.field(init=False, default_factory=list) + entry: Block = dataclasses.field(init=False) + + def __init__(self) -> None: + self.blocks = [] + self.entry = self.new_block() + + def new_block(self) -> Block: + result = Block(len(self.blocks)) + self.blocks.append(result) + return result + +@dataclasses.dataclass +class Function(Instr): + params: list[str] + cfg: CFG = dataclasses.field(init=False, default_factory=CFG) + + # def initial_env(self) -> Env: + # result = {} + # for idx, name in enumerate(self.params): + # instr = Param(idx, name) + # result[name] = self.cfg.emit(Param(idx, name)) + # return result + +class Compiler: + def __init__(self, entry: Function) -> None: + self.gensym_counter: int = 0 + self.fn: Function = entry + self.block: Block = entry.cfg.entry + self.fns: list[Function] = [entry] + + def gensym(self, stem: str = "tmp") -> str: + self.gensym_counter += 1 + return f"{stem}_{self.gensym_counter-1}" + + def push_fn(self, fn: Function) -> Function: + self.fns.append(fn) + prev_fn = self.fn + self.restore_fn(fn) + return prev_fn + + def restore_fn(self, fn: Function) -> None: + self.fn = fn + self.block = fn.cfg.entry + + def emit(self, instr: Instr) -> Instr: + self.block.append(instr) + return instr + + def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success: Block, fallthrough: Block) -> Env: + if isinstance(pattern, Int): + cond = self.emit(IsNumEqualWord(param, pattern.value)) + self.emit(CondBranch(cond, success, fallthrough)) + return {} + raise NotImplementedError(f"pattern {type(pattern)} {pattern}") + + def compile(self, env: Env, exp: Object) -> Instr: + if isinstance(exp, Int): + return self.emit(Const(exp)) + if isinstance(exp, Binop): + left = self.compile(env, exp.left) + right = self.compile(env, exp.right) + if exp.op == BinopKind.ADD: + return IntAdd(left, right) + if exp.op == BinopKind.LESS: + return IntLess(left, right) + if isinstance(exp, MatchFunction): + param = self.gensym("arg") + fn = Function([param]) + prev_fn = self.push_fn(fn) + self.block = fn.cfg.entry + # + funcenv = {} + for idx, name in enumerate(fn.params): + funcenv[name] = self.emit(Param(idx, name)) + no_match = self.fn.cfg.new_block() + no_match.append(MatchFail()) + case_blocks = [self.fn.cfg.new_block() for case in exp.cases] + case_blocks.append(no_match) + self.emit(Jump(case_blocks[0])) + for i, case in enumerate(exp.cases): + self.block = case_blocks[i] + fallthrough = case_blocks[i+1] + body_block = self.fn.cfg.new_block() + env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) + self.block = body_block + case_result = self.compile({**funcenv, **env_updates}, case.body) + self.emit(Return(case_result)) + # + self.restore_fn(prev_fn) + return fn + raise NotImplementedError(f"exp {type(exp)} {exp}") + + +class IRTests(unittest.TestCase): + def _parse(self, source: str) -> Object: + return parse(tokenize(source)) + + def test_int(self) -> None: + compiler = Compiler(Function([])) + result = compiler.compile({}, Int(1)) + self.assertEqual(result, Const(Int(1))) + + def test_add_int(self) -> None: + compiler = Compiler(Function([])) + result = compiler.compile({}, self._parse("1 + 2")) + self.assertEqual(result, IntAdd(Const(Int(1)), Const(Int(2)))) + + def test_less_int(self) -> None: + compiler = Compiler(Function([])) + result = compiler.compile({}, self._parse("1 < 2")) + self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) + + # def test_match_no_cases(self) -> None: + # compiler = Compiler() + # result = compiler.compile({}, MatchFunction([])) + # self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) + + def test_match_one_case(self) -> None: + compiler = Compiler(Function([])) + result = compiler.compile({}, self._parse("| 1 -> 2")) + self.assertIsInstance(result, Function) + self.assertEqual(result.cfg.entry.instrs, [ + Param(0, "arg_0") + ]) + +if __name__ == "__main__": + __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 + unittest.main() From 6214f3b38f7142ea7e50ff056f88c66daf43fe0a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 22 Jan 2025 17:12:44 -0500 Subject: [PATCH 02/88] Pretty printing --- ir.py | 152 +++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 39 deletions(-) diff --git a/ir.py b/ir.py index 3cc91ac2..c673d2c0 100644 --- a/ir.py +++ b/ir.py @@ -36,39 +36,79 @@ tokenize, ) + @dataclasses.dataclass +class InstrId: + data: dict[Instr, int] = dataclasses.field(default_factory=dict) + + def __getitem__(self, instr: Instr) -> int: + id = self.data.get(instr) + if id is not None: + return id + id = len(self.data) + self.data[instr] = id + return id + + +@dataclasses.dataclass(eq=False) class Instr: - pass + def __hash__(self) -> int: + return id(self) -@dataclasses.dataclass + def __eq__(self, other: object) -> bool: + return self is other + + def to_string(self, gvn: InstrId) -> str: + return type(self).__name__ + + +@dataclasses.dataclass(eq=False) class Const(Instr): value: Object -@dataclasses.dataclass + def to_string(self, gvn: InstrId) -> str: + return f"{type(self).__name__}<{self.value}>" + + +@dataclasses.dataclass(eq=False) class Param(Instr): idx: int name: str -@dataclasses.dataclass + def to_string(self, gvn: InstrId) -> str: + return f"{type(self).__name__}<{self.idx}; {self.name}>" + + +@dataclasses.dataclass(eq=False) class MatchFail(Instr): pass -@dataclasses.dataclass + +@dataclasses.dataclass(eq=False) class HasOperands(Instr): operands: list[Instr] = dataclasses.field(init=False, default_factory=list) def __init__(self, *operands: Instr) -> None: self.operands = list(operands) -@dataclasses.dataclass(init=False) + def to_string(self, gvn: InstrId) -> str: + stem = f"{type(self).__name__}" + if not self.operands: + return stem + return stem + " " + ", ".join(f"v{gvn[op]}" for op in self.operands) + + +@dataclasses.dataclass(init=False, eq=False) class IntAdd(HasOperands): pass -@dataclasses.dataclass(init=False) + +@dataclasses.dataclass(init=False, eq=False) class IntLess(HasOperands): pass -@dataclasses.dataclass(init=False) + +@dataclasses.dataclass(init=False, eq=False) class IsNumEqualWord(HasOperands): expected: int @@ -76,13 +116,19 @@ def __init__(self, value: Instr, expected: int) -> None: self.operands = [value] self.expected = expected -@dataclasses.dataclass + def to_string(self, gvn: InstrId) -> str: + return super().to_string(gvn) + f", {self.expected}" + + +@dataclasses.dataclass(eq=False) class Control(Instr): pass + Env = Dict[str, Instr] -@dataclasses.dataclass + +@dataclasses.dataclass(eq=False) class Block: id: int instrs: list[Instr] = dataclasses.field(init=False, default_factory=list) @@ -90,15 +136,24 @@ class Block: def append(self, instr: Instr) -> None: self.instrs.append(instr) -@dataclasses.dataclass + def name(self) -> str: + return f"bb{self.id}" + + +@dataclasses.dataclass(eq=False) class Jump(Control): target: Block -@dataclasses.dataclass(init=False) + def to_string(self, gvn: InstrId) -> str: + return super().to_string(gvn) + f" {self.target.name()}" + + +@dataclasses.dataclass(init=False, eq=False) class Return(HasOperands, Control): pass -@dataclasses.dataclass(init=False) + +@dataclasses.dataclass(init=False, eq=False) class CondBranch(Control, HasOperands): conseq: Block alt: Block @@ -108,6 +163,10 @@ def __init__(self, cond: Instr, conseq: Block, alt: Block) -> None: self.alt = alt self.operands = [cond] + def to_string(self, gvn: InstrId) -> str: + return super().to_string(gvn) + f", {self.conseq.name()}, {self.alt.name()}" + + @dataclasses.dataclass class CFG: blocks: list[Block] = dataclasses.field(init=False, default_factory=list) @@ -122,36 +181,48 @@ def new_block(self) -> Block: self.blocks.append(result) return result -@dataclasses.dataclass -class Function(Instr): + def to_string(self, fn: IRFunction, gvn: InstrId) -> str: + result = "" + for block in self.blocks: + result += f" {block.name()} {{\n" + for instr in block.instrs: + if isinstance(instr, Control): + result += f" {instr.to_string(gvn)}\n" + else: + result += f" v{gvn[instr]} = {instr.to_string(gvn)}\n" + result += " }\n" + return result + + +@dataclasses.dataclass(eq=False) +class IRFunction(Instr): params: list[str] cfg: CFG = dataclasses.field(init=False, default_factory=CFG) - # def initial_env(self) -> Env: - # result = {} - # for idx, name in enumerate(self.params): - # instr = Param(idx, name) - # result[name] = self.cfg.emit(Param(idx, name)) - # return result + def to_string(self, gvn: InstrId) -> str: + result = f"fn {gvn[self]} {{\n" + result += self.cfg.to_string(self, gvn) + return result + "}" + class Compiler: - def __init__(self, entry: Function) -> None: + def __init__(self, entry: IRFunction) -> None: self.gensym_counter: int = 0 - self.fn: Function = entry + self.fn: IRFunction = entry self.block: Block = entry.cfg.entry - self.fns: list[Function] = [entry] + self.fns: list[IRFunction] = [entry] def gensym(self, stem: str = "tmp") -> str: self.gensym_counter += 1 return f"{stem}_{self.gensym_counter-1}" - def push_fn(self, fn: Function) -> Function: + def push_fn(self, fn: IRFunction) -> IRFunction: self.fns.append(fn) prev_fn = self.fn self.restore_fn(fn) return prev_fn - def restore_fn(self, fn: Function) -> None: + def restore_fn(self, fn: IRFunction) -> None: self.fn = fn self.block = fn.cfg.entry @@ -173,12 +244,12 @@ def compile(self, env: Env, exp: Object) -> Instr: left = self.compile(env, exp.left) right = self.compile(env, exp.right) if exp.op == BinopKind.ADD: - return IntAdd(left, right) + return self.emit(IntAdd(left, right)) if exp.op == BinopKind.LESS: - return IntLess(left, right) + return self.emit(IntLess(left, right)) if isinstance(exp, MatchFunction): param = self.gensym("arg") - fn = Function([param]) + fn = IRFunction([param]) prev_fn = self.push_fn(fn) self.block = fn.cfg.entry # @@ -192,7 +263,7 @@ def compile(self, env: Env, exp: Object) -> Instr: self.emit(Jump(case_blocks[0])) for i, case in enumerate(exp.cases): self.block = case_blocks[i] - fallthrough = case_blocks[i+1] + fallthrough = case_blocks[i + 1] body_block = self.fn.cfg.new_block() env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) self.block = body_block @@ -209,17 +280,17 @@ def _parse(self, source: str) -> Object: return parse(tokenize(source)) def test_int(self) -> None: - compiler = Compiler(Function([])) + compiler = Compiler(IRFunction([])) result = compiler.compile({}, Int(1)) self.assertEqual(result, Const(Int(1))) def test_add_int(self) -> None: - compiler = Compiler(Function([])) + compiler = Compiler(IRFunction([])) result = compiler.compile({}, self._parse("1 + 2")) self.assertEqual(result, IntAdd(Const(Int(1)), Const(Int(2)))) def test_less_int(self) -> None: - compiler = Compiler(Function([])) + compiler = Compiler(IRFunction([])) result = compiler.compile({}, self._parse("1 < 2")) self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) @@ -229,12 +300,15 @@ def test_less_int(self) -> None: # self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) def test_match_one_case(self) -> None: - compiler = Compiler(Function([])) - result = compiler.compile({}, self._parse("| 1 -> 2")) - self.assertIsInstance(result, Function) - self.assertEqual(result.cfg.entry.instrs, [ - Param(0, "arg_0") - ]) + compiler = Compiler(IRFunction([])) + result = compiler.compile({}, self._parse("| 1 -> 2 + 3")) + self.assertIsInstance(result, IRFunction) + gvn = InstrId() + self.assertEqual(result.to_string(gvn), "") + # self.assertEqual(result.cfg.entry.instrs, [ + # Param(0, "arg_0") + # ]) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From cd64df3956df4825b35768f1f130c2498ceb5c24 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 22 Jan 2025 17:36:21 -0500 Subject: [PATCH 03/88] Fix tests --- ir.py | 150 ++++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 120 insertions(+), 30 deletions(-) diff --git a/ir.py b/ir.py index c673d2c0..d35e4770 100644 --- a/ir.py +++ b/ir.py @@ -125,6 +125,14 @@ class Control(Instr): pass +@dataclasses.dataclass(eq=False) +class NewClosure(Instr): + fn: IRFunction + + def to_string(self, gvn: InstrId) -> str: + return super().to_string(gvn) + f", {self.fn.name()}" + + Env = Dict[str, Instr] @@ -195,22 +203,32 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: @dataclasses.dataclass(eq=False) -class IRFunction(Instr): +class IRFunction: + id: int params: list[str] cfg: CFG = dataclasses.field(init=False, default_factory=CFG) + def name(self) -> str: + return f"fn{self.id}" + def to_string(self, gvn: InstrId) -> str: - result = f"fn {gvn[self]} {{\n" + result = f"{self.name()} {{\n" result += self.cfg.to_string(self, gvn) return result + "}" class Compiler: - def __init__(self, entry: IRFunction) -> None: + def __init__(self) -> None: + self.fns: list[IRFunction] = [] + entry = self.new_function([]) self.gensym_counter: int = 0 self.fn: IRFunction = entry self.block: Block = entry.cfg.entry - self.fns: list[IRFunction] = [entry] + + def new_function(self, params: list[str]) -> IRFunction: + result = IRFunction(len(self.fns), params) + self.fns.append(result) + return result def gensym(self, stem: str = "tmp") -> str: self.gensym_counter += 1 @@ -237,6 +255,9 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success return {} raise NotImplementedError(f"pattern {type(pattern)} {pattern}") + def compile_body(self, env: Env, exp: Object) -> None: + self.emit(Return(self.compile(env, exp))) + def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, Int): return self.emit(Const(exp)) @@ -249,7 +270,7 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.emit(IntLess(left, right)) if isinstance(exp, MatchFunction): param = self.gensym("arg") - fn = IRFunction([param]) + fn = self.new_function([param]) prev_fn = self.push_fn(fn) self.block = fn.cfg.entry # @@ -267,11 +288,10 @@ def compile(self, env: Env, exp: Object) -> Instr: body_block = self.fn.cfg.new_block() env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) self.block = body_block - case_result = self.compile({**funcenv, **env_updates}, case.body) - self.emit(Return(case_result)) + self.compile_body({**funcenv, **env_updates}, case.body) # self.restore_fn(prev_fn) - return fn + return self.emit(NewClosure(fn)) raise NotImplementedError(f"exp {type(exp)} {exp}") @@ -280,34 +300,104 @@ def _parse(self, source: str) -> Object: return parse(tokenize(source)) def test_int(self) -> None: - compiler = Compiler(IRFunction([])) - result = compiler.compile({}, Int(1)) - self.assertEqual(result, Const(Int(1))) + compiler = Compiler() + compiler.compile_body({}, Int(1)) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + Return v0 + } +}""", + ) def test_add_int(self) -> None: - compiler = Compiler(IRFunction([])) - result = compiler.compile({}, self._parse("1 + 2")) - self.assertEqual(result, IntAdd(Const(Int(1)), Const(Int(2)))) + compiler = Compiler() + compiler.compile_body({}, self._parse("1 + 2")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = Const<2> + v2 = IntAdd v0, v1 + Return v2 + } +}""", + ) def test_less_int(self) -> None: - compiler = Compiler(IRFunction([])) - result = compiler.compile({}, self._parse("1 < 2")) - self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) - - # def test_match_no_cases(self) -> None: - # compiler = Compiler() - # result = compiler.compile({}, MatchFunction([])) - # self.assertEqual(result, IntLess(Const(Int(1)), Const(Int(2)))) + compiler = Compiler() + compiler.compile_body({}, self._parse("1 < 2")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = Const<2> + v2 = IntLess v0, v1 + Return v2 + } +}""", + ) + + def test_match_no_cases(self) -> None: + compiler = Compiler() + compiler.compile_body({}, MatchFunction([])) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure, fn1 + Return v0 + } +}""", + ) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; arg_0> + Jump bb1 + } + bb1 { + v1 = MatchFail + } +}""", + ) def test_match_one_case(self) -> None: - compiler = Compiler(IRFunction([])) - result = compiler.compile({}, self._parse("| 1 -> 2 + 3")) - self.assertIsInstance(result, IRFunction) - gvn = InstrId() - self.assertEqual(result.to_string(gvn), "") - # self.assertEqual(result.cfg.entry.instrs, [ - # Param(0, "arg_0") - # ]) + compiler = Compiler() + compiler.compile_body({}, self._parse("| 1 -> 2 + 3")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; arg_0> + Jump bb2 + } + bb1 { + v1 = MatchFail + } + bb2 { + v2 = IsNumEqualWord v0, 1 + CondBranch v2, bb3, bb1 + } + bb3 { + v3 = Const<2> + v4 = Const<3> + v5 = IntAdd v3, v4 + Return v5 + } +}""", + ) if __name__ == "__main__": From 1e6c8a3b69408728924324dbeab9019419782136 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 00:48:32 -0500 Subject: [PATCH 04/88] . --- ir.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/ir.py b/ir.py index d35e4770..cecc4e11 100644 --- a/ir.py +++ b/ir.py @@ -399,6 +399,39 @@ def test_match_one_case(self) -> None: }""", ) + def test_match_two_cases(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("| 1 -> 2 | 3 -> 4")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; arg_0> + Jump bb2 + } + bb1 { + v1 = MatchFail + } + bb2 { + v2 = IsNumEqualWord v0, 1 + CondBranch v2, bb4, bb3 + } + bb3 { + v3 = IsNumEqualWord v0, 3 + CondBranch v3, bb5, bb1 + } + bb4 { + v4 = Const<2> + Return v4 + } + bb5 { + v5 = Const<4> + Return v5 + } +}""", + ) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 7c97f7c50fec54cfe4cb04caaef8d545b6577e51 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 00:54:07 -0500 Subject: [PATCH 05/88] Add var --- ir.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ir.py b/ir.py index cecc4e11..89508f0a 100644 --- a/ir.py +++ b/ir.py @@ -253,6 +253,9 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success cond = self.emit(IsNumEqualWord(param, pattern.value)) self.emit(CondBranch(cond, success, fallthrough)) return {} + if isinstance(pattern, Var): + self.emit(Jump(success)) + return {pattern.name: param} raise NotImplementedError(f"pattern {type(pattern)} {pattern}") def compile_body(self, env: Env, exp: Object) -> None: @@ -261,6 +264,8 @@ def compile_body(self, env: Env, exp: Object) -> None: def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, Int): return self.emit(Const(exp)) + if isinstance(exp, Var): + return env[exp.name] if isinstance(exp, Binop): left = self.compile(env, exp.left) right = self.compile(env, exp.right) @@ -432,6 +437,31 @@ def test_match_two_cases(self) -> None: }""", ) + def test_match_var(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("| a -> a + 1")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; arg_0> + Jump bb2 + } + bb1 { + v1 = MatchFail + } + bb2 { + Jump bb3 + } + bb3 { + v2 = Const<1> + v3 = IntAdd v0, v2 + Return v3 + } +}""", + ) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 682c19435e8e8b887caa70590e9ef788a5b50af4 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 00:55:02 -0500 Subject: [PATCH 06/88] Add project section for uv --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 22b75785..94d70cfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,3 +52,7 @@ line-length = 120 [tool.ruff.lint] ignore = ["E741"] + +[project] +name = "scrapscript" +version = "0.1.1" From 4ae6d330a2aae00167b0e0f1bd3ee65b5da7096a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 01:01:46 -0500 Subject: [PATCH 07/88] . --- ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 89508f0a..b00dbf3a 100644 --- a/ir.py +++ b/ir.py @@ -130,7 +130,7 @@ class NewClosure(Instr): fn: IRFunction def to_string(self, gvn: InstrId) -> str: - return super().to_string(gvn) + f", {self.fn.name()}" + return super().to_string(gvn) + f" {self.fn.name()}" Env = Dict[str, Instr] @@ -358,7 +358,7 @@ def test_match_no_cases(self) -> None: """\ fn0 { bb0 { - v0 = NewClosure, fn1 + v0 = NewClosure fn1 Return v0 } }""", From 34a528b542b50e2dfc5aa5f61f45930f321ce144 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 01:01:51 -0500 Subject: [PATCH 08/88] Support id function --- ir.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ir.py b/ir.py index b00dbf3a..4d706bf5 100644 --- a/ir.py +++ b/ir.py @@ -297,6 +297,20 @@ def compile(self, env: Env, exp: Object) -> Instr: # self.restore_fn(prev_fn) return self.emit(NewClosure(fn)) + if isinstance(exp, Function): + assert isinstance(exp.arg, Var) + param = exp.arg.name + fn = self.new_function([param]) + prev_fn = self.push_fn(fn) + self.block = fn.cfg.entry + # + funcenv = {} + for idx, name in enumerate(fn.params): + funcenv[name] = self.emit(Param(idx, name)) + self.compile_body(funcenv, exp.body) + # + self.restore_fn(prev_fn) + return self.emit(NewClosure(fn)) raise NotImplementedError(f"exp {type(exp)} {exp}") @@ -350,6 +364,28 @@ def test_less_int(self) -> None: }""", ) + def test_fun_id(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("a -> a")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure fn1 + Return v0 + } +}""") + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; a> + Return v0 + } +}""") + def test_match_no_cases(self) -> None: compiler = Compiler() compiler.compile_body({}, MatchFunction([])) From 95291319d30841efd36ba741e5b649387571d181 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 01:05:03 -0500 Subject: [PATCH 09/88] Support let binding --- ir.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 4d706bf5..f1166de3 100644 --- a/ir.py +++ b/ir.py @@ -273,6 +273,11 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.emit(IntAdd(left, right)) if exp.op == BinopKind.LESS: return self.emit(IntLess(left, right)) + if isinstance(exp, Where): + assert isinstance(exp.binding, Assign) + name, value_exp, body_exp = exp.binding.name.name, exp.binding.value, exp.body + value = self.compile(env, value_exp) + return self.compile({**env, name: value}, body_exp) if isinstance(exp, MatchFunction): param = self.gensym("arg") fn = self.new_function([param]) @@ -364,6 +369,20 @@ def test_less_int(self) -> None: }""", ) + def test_let(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("a . a = 1")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + Return v0 + } +}""", + ) + def test_fun_id(self) -> None: compiler = Compiler() compiler.compile_body({}, self._parse("a -> a")) @@ -375,7 +394,8 @@ def test_fun_id(self) -> None: v0 = NewClosure fn1 Return v0 } -}""") +}""", + ) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -384,7 +404,8 @@ def test_fun_id(self) -> None: v0 = Param<0; a> Return v0 } -}""") +}""", + ) def test_match_no_cases(self) -> None: compiler = Compiler() From b2f4c6ca892c760af47788e683eac296f2058b73 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 01:19:50 -0500 Subject: [PATCH 10/88] . --- ir.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index f1166de3..cc9aac00 100644 --- a/ir.py +++ b/ir.py @@ -262,7 +262,7 @@ def compile_body(self, env: Env, exp: Object) -> None: self.emit(Return(self.compile(env, exp))) def compile(self, env: Env, exp: Object) -> Instr: - if isinstance(exp, Int): + if isinstance(exp, (Int, String)): return self.emit(Const(exp)) if isinstance(exp, Var): return env[exp.name] @@ -337,6 +337,20 @@ def test_int(self) -> None: }""", ) + def test_str(self) -> None: + compiler = Compiler() + compiler.compile_body({}, String("hello")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<"hello"> + Return v0 + } +}""", + ) + def test_add_int(self) -> None: compiler = Compiler() compiler.compile_body({}, self._parse("1 + 2")) From 0e5fa90e642b9f86e4a91dc818659220c740c5e3 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 01:46:08 -0500 Subject: [PATCH 11/88] Add closures/closureref --- ir.py | 149 ++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 114 insertions(+), 35 deletions(-) diff --git a/ir.py b/ir.py index cc9aac00..a8170534 100644 --- a/ir.py +++ b/ir.py @@ -120,17 +120,38 @@ def to_string(self, gvn: InstrId) -> str: return super().to_string(gvn) + f", {self.expected}" +@dataclasses.dataclass(eq=False) +class ClosureRef(HasOperands): + idx: int + name: str + + def __init__(self, closure: Instr, idx: int, name: str) -> None: + self.operands = [closure] + self.idx = idx + self.name = name + + def to_string(self, gvn: InstrId) -> str: + return f"{type(self).__name__}<{self.idx}; {self.name}> v{gvn[self.operands[0]]}" + + @dataclasses.dataclass(eq=False) class Control(Instr): pass @dataclasses.dataclass(eq=False) -class NewClosure(Instr): +class NewClosure(HasOperands): fn: IRFunction + def __init__(self, fn: IRFunction, bound: list[Instr]) -> None: + self.fn = fn + self.operands = bound.copy() + def to_string(self, gvn: InstrId) -> str: - return super().to_string(gvn) + f" {self.fn.name()}" + stem = f"{type(self).__name__}<{self.fn.name()}>" + if not self.operands: + return stem + return f"{stem} " + ", ".join(f"v{gvn[op]}" for op in self.operands) Env = Dict[str, Instr] @@ -235,7 +256,6 @@ def gensym(self, stem: str = "tmp") -> str: return f"{stem}_{self.gensym_counter-1}" def push_fn(self, fn: IRFunction) -> IRFunction: - self.fns.append(fn) prev_fn = self.fn self.restore_fn(fn) return prev_fn @@ -280,13 +300,19 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.compile({**env, name: value}, body_exp) if isinstance(exp, MatchFunction): param = self.gensym("arg") - fn = self.new_function([param]) + clo = "$clo" + fn = self.new_function([clo, param]) prev_fn = self.push_fn(fn) self.block = fn.cfg.entry # funcenv = {} for idx, name in enumerate(fn.params): funcenv[name] = self.emit(Param(idx, name)) + closure = funcenv[clo] + freevars = sorted(free_in(exp)) + for idx, name in enumerate(freevars): + funcenv[name] = self.emit(ClosureRef(closure, idx, name)) + # no_match = self.fn.cfg.new_block() no_match.append(MatchFail()) case_blocks = [self.fn.cfg.new_block() for case in exp.cases] @@ -301,21 +327,29 @@ def compile(self, env: Env, exp: Object) -> Instr: self.compile_body({**funcenv, **env_updates}, case.body) # self.restore_fn(prev_fn) - return self.emit(NewClosure(fn)) + bound = [env[name] for name in freevars] + return self.emit(NewClosure(fn, bound)) if isinstance(exp, Function): assert isinstance(exp.arg, Var) param = exp.arg.name - fn = self.new_function([param]) + clo = "$clo" + fn = self.new_function([clo, param]) prev_fn = self.push_fn(fn) self.block = fn.cfg.entry # funcenv = {} for idx, name in enumerate(fn.params): funcenv[name] = self.emit(Param(idx, name)) + closure = funcenv[clo] + freevars = sorted(free_in(exp)) + for idx, name in enumerate(freevars): + funcenv[name] = self.emit(ClosureRef(closure, idx, name)) + # self.compile_body(funcenv, exp.body) # self.restore_fn(prev_fn) - return self.emit(NewClosure(fn)) + bound = [env[name] for name in freevars] + return self.emit(NewClosure(fn, bound)) raise NotImplementedError(f"exp {type(exp)} {exp}") @@ -405,7 +439,7 @@ def test_fun_id(self) -> None: """\ fn0 { bb0 { - v0 = NewClosure fn1 + v0 = NewClosure Return v0 } }""", @@ -415,9 +449,50 @@ def test_fun_id(self) -> None: """\ fn1 { bb0 { - v0 = Param<0; a> + v0 = Param<0; $clo> + v1 = Param<1; a> + Return v1 + } +}""", + ) + + def test_fun_closure(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("a -> b -> a + b")) + self.assertEqual(len(compiler.fns), 3) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure Return v0 } +}""", + ) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; a> + v2 = NewClosure v1 + Return v2 + } +}""", + ) + self.assertEqual( + compiler.fns[2].to_string(InstrId()), + """\ +fn2 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; b> + v2 = ClosureRef<0; a> v0 + v3 = IntAdd v2, v1 + Return v3 + } }""", ) @@ -429,7 +504,7 @@ def test_match_no_cases(self) -> None: """\ fn0 { bb0 { - v0 = NewClosure fn1 + v0 = NewClosure Return v0 } }""", @@ -439,11 +514,12 @@ def test_match_no_cases(self) -> None: """\ fn1 { bb0 { - v0 = Param<0; arg_0> + v0 = Param<0; $clo> + v1 = Param<1; arg_0> Jump bb1 } bb1 { - v1 = MatchFail + v2 = MatchFail } }""", ) @@ -456,21 +532,22 @@ def test_match_one_case(self) -> None: """\ fn1 { bb0 { - v0 = Param<0; arg_0> + v0 = Param<0; $clo> + v1 = Param<1; arg_0> Jump bb2 } bb1 { - v1 = MatchFail + v2 = MatchFail } bb2 { - v2 = IsNumEqualWord v0, 1 - CondBranch v2, bb3, bb1 + v3 = IsNumEqualWord v1, 1 + CondBranch v3, bb3, bb1 } bb3 { - v3 = Const<2> - v4 = Const<3> - v5 = IntAdd v3, v4 - Return v5 + v4 = Const<2> + v5 = Const<3> + v6 = IntAdd v4, v5 + Return v6 } }""", ) @@ -483,27 +560,28 @@ def test_match_two_cases(self) -> None: """\ fn1 { bb0 { - v0 = Param<0; arg_0> + v0 = Param<0; $clo> + v1 = Param<1; arg_0> Jump bb2 } bb1 { - v1 = MatchFail + v2 = MatchFail } bb2 { - v2 = IsNumEqualWord v0, 1 - CondBranch v2, bb4, bb3 + v3 = IsNumEqualWord v1, 1 + CondBranch v3, bb4, bb3 } bb3 { - v3 = IsNumEqualWord v0, 3 - CondBranch v3, bb5, bb1 + v4 = IsNumEqualWord v1, 3 + CondBranch v4, bb5, bb1 } bb4 { - v4 = Const<2> - Return v4 + v5 = Const<2> + Return v5 } bb5 { - v5 = Const<4> - Return v5 + v6 = Const<4> + Return v6 } }""", ) @@ -516,19 +594,20 @@ def test_match_var(self) -> None: """\ fn1 { bb0 { - v0 = Param<0; arg_0> + v0 = Param<0; $clo> + v1 = Param<1; arg_0> Jump bb2 } bb1 { - v1 = MatchFail + v2 = MatchFail } bb2 { Jump bb3 } bb3 { - v2 = Const<1> - v3 = IntAdd v0, v2 - Return v3 + v3 = Const<1> + v4 = IntAdd v1, v3 + Return v4 } }""", ) From 3724850710a208733816e7824cb35d604e42081e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 10:41:49 -0500 Subject: [PATCH 12/88] Add list pattern matching --- compiler.py | 10 ++-- ir.py | 162 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 6 deletions(-) diff --git a/compiler.py b/compiler.py index 81a9be32..19a31c7a 100644 --- a/compiler.py +++ b/compiler.py @@ -220,21 +220,19 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}") updates = {} the_list = arg - use_spread = False for i, pattern_item in enumerate(pattern.items): if isinstance(pattern_item, Spread): - use_spread = True if pattern_item.name: updates[pattern_item.name] = the_list - break + return updates # Not enough elements self._emit(f"if (is_empty_list({the_list})) {{ goto {fallthrough}; }}") list_item = self._mktemp(f"list_first({the_list})") + # Recursive pattern match updates.update(self.try_match(env, list_item, pattern_item, fallthrough)) the_list = self._mktemp(f"list_rest({the_list})") - if not use_spread: - # Too many elements - self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}") + # Too many elements + self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}") return updates if isinstance(pattern, Record): self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") diff --git a/ir.py b/ir.py index a8170534..d05f2f57 100644 --- a/ir.py +++ b/ir.py @@ -134,6 +134,26 @@ def to_string(self, gvn: InstrId) -> str: return f"{type(self).__name__}<{self.idx}; {self.name}> v{gvn[self.operands[0]]}" +@dataclasses.dataclass(init=False, eq=False) +class IsList(HasOperands): + pass + + +@dataclasses.dataclass(init=False, eq=False) +class IsEmptyList(HasOperands): + pass + + +@dataclasses.dataclass(init=False, eq=False) +class ListFirst(HasOperands): + pass + + +@dataclasses.dataclass(init=False, eq=False) +class ListRest(HasOperands): + pass + + @dataclasses.dataclass(eq=False) class Control(Instr): pass @@ -276,6 +296,30 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success if isinstance(pattern, Var): self.emit(Jump(success)) return {pattern.name: param} + if isinstance(pattern, List): + is_list = self.emit(IsList(param)) + is_list_block = self.fn.cfg.new_block() + self.emit(CondBranch(is_list, is_list_block, fallthrough)) + self.block = is_list_block + updates = {} + the_list = param + for i, pattern_item in enumerate(pattern.items): + assert not isinstance(pattern_item, Spread) + # Not enough elements + is_empty = self.emit(IsEmptyList(the_list)) + is_nonempty_block = self.fn.cfg.new_block() + self.emit(CondBranch(is_empty, fallthrough, is_nonempty_block)) + self.block = is_nonempty_block + list_item = self.emit(ListFirst(the_list)) + pattern_success = self.fn.cfg.new_block() + # Recursive pattern match + updates.update(self.compile_match_pattern(env, list_item, pattern_item, pattern_success, fallthrough)) + self.block = pattern_success + the_list = self.emit(ListRest(the_list)) + # Too many elements + is_empty = self.emit(IsEmptyList(the_list)) + self.emit(CondBranch(is_empty, success, fallthrough)) + return updates raise NotImplementedError(f"pattern {type(pattern)} {pattern}") def compile_body(self, env: Env, exp: Object) -> None: @@ -612,6 +656,124 @@ def test_match_var(self) -> None: }""", ) + def test_match_empty_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("| [] -> 1")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb1 { + v2 = MatchFail + } + bb2 { + v3 = IsList v1 + CondBranch v3, bb4, bb1 + } + bb3 { + v4 = Const<1> + Return v4 + } + bb4 { + v5 = IsEmptyList v1 + CondBranch v5, bb3, bb1 + } +}""", + ) + + def test_match_one_item_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("| [a] -> a + 1")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb1 { + v2 = MatchFail + } + bb2 { + v3 = IsList v1 + CondBranch v3, bb4, bb1 + } + bb3 { + v4 = Const<1> + v5 = IntAdd v6, v4 + Return v5 + } + bb4 { + v7 = IsEmptyList v1 + CondBranch v7, bb1, bb5 + } + bb5 { + v6 = ListFirst v1 + Jump bb6 + } + bb6 { + v8 = ListRest v1 + v9 = IsEmptyList v8 + CondBranch v9, bb3, bb1 + } +}""", + ) + + def test_match_two_item_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("| [a, b] -> a + b")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb1 { + v2 = MatchFail + } + bb2 { + v3 = IsList v1 + CondBranch v3, bb4, bb1 + } + bb3 { + v4 = IntAdd v5, v6 + Return v4 + } + bb4 { + v7 = IsEmptyList v1 + CondBranch v7, bb1, bb5 + } + bb5 { + v5 = ListFirst v1 + Jump bb6 + } + bb6 { + v8 = ListRest v1 + v9 = IsEmptyList v8 + CondBranch v9, bb1, bb7 + } + bb7 { + v6 = ListFirst v8 + Jump bb8 + } + bb8 { + v10 = ListRest v8 + v11 = IsEmptyList v10 + CondBranch v11, bb3, bb1 + } +}""", + ) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 2f72953217f0684dd8318682fe5b1980008ade2a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 11:38:28 -0500 Subject: [PATCH 13/88] . --- ir.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ir.py b/ir.py index d05f2f57..70ca486c 100644 --- a/ir.py +++ b/ir.py @@ -103,6 +103,11 @@ class IntAdd(HasOperands): pass +@dataclasses.dataclass(init=False, eq=False) +class IntSub(HasOperands): + pass + + @dataclasses.dataclass(init=False, eq=False) class IntLess(HasOperands): pass @@ -335,6 +340,8 @@ def compile(self, env: Env, exp: Object) -> Instr: right = self.compile(env, exp.right) if exp.op == BinopKind.ADD: return self.emit(IntAdd(left, right)) + if exp.op == BinopKind.SUB: + return self.emit(IntSub(left, right)) if exp.op == BinopKind.LESS: return self.emit(IntLess(left, right)) if isinstance(exp, Where): @@ -445,6 +452,22 @@ def test_add_int(self) -> None: }""", ) + def test_sub_int(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("1 - 2")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = Const<2> + v2 = IntSub v0, v1 + Return v2 + } +}""", + ) + def test_less_int(self) -> None: compiler = Compiler() compiler.compile_body({}, self._parse("1 < 2")) From 6953478a345b60b4d437637286daf4708a78b2f1 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 11:44:00 -0500 Subject: [PATCH 14/88] List cons --- ir.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/ir.py b/ir.py index 70ca486c..8f980f92 100644 --- a/ir.py +++ b/ir.py @@ -149,6 +149,11 @@ class IsEmptyList(HasOperands): pass +@dataclasses.dataclass(init=False, eq=False) +class ListCons(HasOperands): + pass + + @dataclasses.dataclass(init=False, eq=False) class ListFirst(HasOperands): pass @@ -344,6 +349,14 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.emit(IntSub(left, right)) if exp.op == BinopKind.LESS: return self.emit(IntLess(left, right)) + if isinstance(exp, List): + result = self.emit(Const(List([]))) + if not exp.items: + return result + for elt_exp in reversed(exp.items): + elt = self.compile(env, elt_exp) + result = self.emit(ListCons(elt, result)) + return result if isinstance(exp, Where): assert isinstance(exp.binding, Assign) name, value_exp, body_exp = exp.binding.name.name, exp.binding.value, exp.body @@ -484,6 +497,38 @@ def test_less_int(self) -> None: }""", ) + def test_empty_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("[]")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<[]> + Return v0 + } +}""", + ) + + def test_const_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("[1, 2]")) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<[]> + v1 = Const<2> + v2 = ListCons v1, v0 + v3 = Const<1> + v4 = ListCons v3, v2 + Return v4 + } +}""", + ) + def test_let(self) -> None: compiler = Compiler() compiler.compile_body({}, self._parse("a . a = 1")) From 1d3c948d84f2b625b3ba021dedefd7e08c70e163 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 11:45:03 -0500 Subject: [PATCH 15/88] Test non-const list --- ir.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/ir.py b/ir.py index 8f980f92..f8179a05 100644 --- a/ir.py +++ b/ir.py @@ -529,6 +529,23 @@ def test_const_list(self) -> None: }""", ) + def test_non_const_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("a -> [a]")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; a> + v2 = Const<[]> + v3 = ListCons v1, v2 + Return v3 + } +}""", + ) + def test_let(self) -> None: compiler = Compiler() compiler.compile_body({}, self._parse("a . a = 1")) From 85067e8445418d5bac1a6884b1f294999faf621e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 11:50:12 -0500 Subject: [PATCH 16/88] Support call --- ir.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ir.py b/ir.py index f8179a05..4f3075e1 100644 --- a/ir.py +++ b/ir.py @@ -164,6 +164,11 @@ class ListRest(HasOperands): pass +@dataclasses.dataclass(init=False, eq=False) +class Call(HasOperands): + pass + + @dataclasses.dataclass(eq=False) class Control(Instr): pass @@ -362,6 +367,10 @@ def compile(self, env: Env, exp: Object) -> Instr: name, value_exp, body_exp = exp.binding.name.name, exp.binding.value, exp.body value = self.compile(env, value_exp) return self.compile({**env, name: value}, body_exp) + if isinstance(exp, Apply): + fn = self.compile(env, exp.func) + arg = self.compile(env, exp.arg) + return self.emit(Call(fn, arg)) if isinstance(exp, MatchFunction): param = self.gensym("arg") clo = "$clo" @@ -859,6 +868,21 @@ def test_match_two_item_list(self) -> None: }""", ) + def test_apply_fn(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("f 1 . f = x -> x + 1")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure + v1 = Const<1> + v2 = Call v0, v1 + Return v2 + } +}""") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 05fc0d2282748d0270141a019240fcd76bd92677 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 11:53:12 -0500 Subject: [PATCH 17/88] . --- ir.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ir.py b/ir.py index 4f3075e1..1c8310b6 100644 --- a/ir.py +++ b/ir.py @@ -883,6 +883,21 @@ def test_apply_fn(self) -> None: } }""") + def test_apply_anonymous_function(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("((x -> x + 1) 1)")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure + v1 = Const<1> + v2 = Call v0, v1 + Return v2 + } +}""") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 3eefe3a8312c269eb5afba655f4c31426ba28009 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 15:25:28 -0500 Subject: [PATCH 18/88] Add int mul --- ir.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ir.py b/ir.py index 1c8310b6..bf23a507 100644 --- a/ir.py +++ b/ir.py @@ -108,6 +108,11 @@ class IntSub(HasOperands): pass +@dataclasses.dataclass(init=False, eq=False) +class IntMul(HasOperands): + pass + + @dataclasses.dataclass(init=False, eq=False) class IntLess(HasOperands): pass @@ -352,6 +357,8 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.emit(IntAdd(left, right)) if exp.op == BinopKind.SUB: return self.emit(IntSub(left, right)) + if exp.op == BinopKind.MUL: + return self.emit(IntMul(left, right)) if exp.op == BinopKind.LESS: return self.emit(IntLess(left, right)) if isinstance(exp, List): From ee96fcbacc79e9461fe179d1f06b7678772ee056 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 23 Jan 2025 15:35:31 -0500 Subject: [PATCH 19/88] Let functions refer to themselves; support recursion --- ir.py | 160 ++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 105 insertions(+), 55 deletions(-) diff --git a/ir.py b/ir.py index bf23a507..53132cbc 100644 --- a/ir.py +++ b/ir.py @@ -345,6 +345,52 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success def compile_body(self, env: Env, exp: Object) -> None: self.emit(Return(self.compile(env, exp))) + def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: Optional[str]) -> Instr: + if isinstance(exp, Function): + assert isinstance(exp.arg, Var) + param = exp.arg.name + else: + param = self.gensym("arg") + clo = "$clo" + fn = self.new_function([clo, param]) + freevars = free_in(exp) + if func_name is not None and func_name in freevars: + # Functions can refer to themselves; we close the loop below in the + # funcenv + freevars.remove(func_name) + freevars = sorted(freevars) + bound = [env[name] for name in freevars] + result = self.emit(NewClosure(fn, bound)) + prev_fn = self.push_fn(fn) + self.block = fn.cfg.entry + # + funcenv = {} + for idx, name in enumerate(fn.params): + funcenv[name] = self.emit(Param(idx, name)) + closure = funcenv[clo] + if func_name is not None: + funcenv[func_name] = closure + for idx, name in enumerate(freevars): + funcenv[name] = self.emit(ClosureRef(closure, idx, name)) + # + if isinstance(exp, Function): + self.compile_body(funcenv, exp.body) + else: + no_match = self.fn.cfg.new_block() + no_match.append(MatchFail()) + case_blocks = [self.fn.cfg.new_block() for case in exp.cases] + case_blocks.append(no_match) + self.emit(Jump(case_blocks[0])) + for i, case in enumerate(exp.cases): + self.block = case_blocks[i] + fallthrough = case_blocks[i + 1] + body_block = self.fn.cfg.new_block() + env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) + self.block = body_block + self.compile_body({**funcenv, **env_updates}, case.body) + self.restore_fn(prev_fn) + return result + def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, (Int, String)): return self.emit(Const(exp)) @@ -372,64 +418,18 @@ def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, Where): assert isinstance(exp.binding, Assign) name, value_exp, body_exp = exp.binding.name.name, exp.binding.value, exp.body - value = self.compile(env, value_exp) + if isinstance(value_exp, (Function, MatchFunction)): + value = self.compile_function(env, value_exp, func_name=name) + else: + value = self.compile(env, value_exp) return self.compile({**env, name: value}, body_exp) if isinstance(exp, Apply): fn = self.compile(env, exp.func) arg = self.compile(env, exp.arg) return self.emit(Call(fn, arg)) - if isinstance(exp, MatchFunction): - param = self.gensym("arg") - clo = "$clo" - fn = self.new_function([clo, param]) - prev_fn = self.push_fn(fn) - self.block = fn.cfg.entry - # - funcenv = {} - for idx, name in enumerate(fn.params): - funcenv[name] = self.emit(Param(idx, name)) - closure = funcenv[clo] - freevars = sorted(free_in(exp)) - for idx, name in enumerate(freevars): - funcenv[name] = self.emit(ClosureRef(closure, idx, name)) - # - no_match = self.fn.cfg.new_block() - no_match.append(MatchFail()) - case_blocks = [self.fn.cfg.new_block() for case in exp.cases] - case_blocks.append(no_match) - self.emit(Jump(case_blocks[0])) - for i, case in enumerate(exp.cases): - self.block = case_blocks[i] - fallthrough = case_blocks[i + 1] - body_block = self.fn.cfg.new_block() - env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) - self.block = body_block - self.compile_body({**funcenv, **env_updates}, case.body) - # - self.restore_fn(prev_fn) - bound = [env[name] for name in freevars] - return self.emit(NewClosure(fn, bound)) - if isinstance(exp, Function): - assert isinstance(exp.arg, Var) - param = exp.arg.name - clo = "$clo" - fn = self.new_function([clo, param]) - prev_fn = self.push_fn(fn) - self.block = fn.cfg.entry - # - funcenv = {} - for idx, name in enumerate(fn.params): - funcenv[name] = self.emit(Param(idx, name)) - closure = funcenv[clo] - freevars = sorted(free_in(exp)) - for idx, name in enumerate(freevars): - funcenv[name] = self.emit(ClosureRef(closure, idx, name)) - # - self.compile_body(funcenv, exp.body) - # - self.restore_fn(prev_fn) - bound = [env[name] for name in freevars] - return self.emit(NewClosure(fn, bound)) + if isinstance(exp, (Function, MatchFunction)): + # Anonymous function + return self.compile_function(env, exp, func_name=None) raise NotImplementedError(f"exp {type(exp)} {exp}") @@ -888,7 +888,56 @@ def test_apply_fn(self) -> None: v2 = Call v0, v1 Return v2 } -}""") +}""", + ) + + def test_recursive_call(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("fact 5 . fact = | 0 -> 1 | n -> n * fact (n - 1)")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewClosure + v1 = Const<5> + v2 = Call v0, v1 + Return v2 + } +}""", + ) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb1 { + v2 = MatchFail + } + bb2 { + v3 = IsNumEqualWord v1, 0 + CondBranch v3, bb4, bb3 + } + bb3 { + Jump bb5 + } + bb4 { + v4 = Const<1> + Return v4 + } + bb5 { + v5 = Const<1> + v6 = IntSub v1, v5 + v7 = Call v0, v6 + v8 = IntMul v1, v7 + Return v8 + } +}""", + ) def test_apply_anonymous_function(self) -> None: compiler = Compiler() @@ -903,7 +952,8 @@ def test_apply_anonymous_function(self) -> None: v2 = Call v0, v1 Return v2 } -}""") +}""", + ) if __name__ == "__main__": From 270a983a191ef03bfa9b71801b29cb2eec3e8676 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 14:06:17 -0500 Subject: [PATCH 20/88] . --- ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 53132cbc..ade593f1 100644 --- a/ir.py +++ b/ir.py @@ -359,8 +359,6 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O # funcenv freevars.remove(func_name) freevars = sorted(freevars) - bound = [env[name] for name in freevars] - result = self.emit(NewClosure(fn, bound)) prev_fn = self.push_fn(fn) self.block = fn.cfg.entry # @@ -389,6 +387,8 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O self.block = body_block self.compile_body({**funcenv, **env_updates}, case.body) self.restore_fn(prev_fn) + bound = [env[name] for name in freevars] + result = self.emit(NewClosure(fn, bound)) return result def compile(self, env: Env, exp: Object) -> Instr: From 71fb8870c209b7ceccb0f534f207d38ff39dfd73 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 14:06:44 -0500 Subject: [PATCH 21/88] . --- ir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ir.py b/ir.py index ade593f1..81fc7926 100644 --- a/ir.py +++ b/ir.py @@ -360,7 +360,6 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O freevars.remove(func_name) freevars = sorted(freevars) prev_fn = self.push_fn(fn) - self.block = fn.cfg.entry # funcenv = {} for idx, name in enumerate(fn.params): From 528557ab4d9118fe2527e0e57c4798956edf2276 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 14:08:25 -0500 Subject: [PATCH 22/88] . --- ir.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ir.py b/ir.py index 81fc7926..21dcf672 100644 --- a/ir.py +++ b/ir.py @@ -297,12 +297,13 @@ def gensym(self, stem: str = "tmp") -> str: def push_fn(self, fn: IRFunction) -> IRFunction: prev_fn = self.fn - self.restore_fn(fn) - return prev_fn + prev_block = self.block + self.restore_fn(fn, fn.cfg.entry) + return prev_fn, prev_block - def restore_fn(self, fn: IRFunction) -> None: + def restore_fn(self, fn: IRFunction, block: Block) -> None: self.fn = fn - self.block = fn.cfg.entry + self.block = block def emit(self, instr: Instr) -> Instr: self.block.append(instr) @@ -359,7 +360,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O # funcenv freevars.remove(func_name) freevars = sorted(freevars) - prev_fn = self.push_fn(fn) + prev_fn, prev_block = self.push_fn(fn) # funcenv = {} for idx, name in enumerate(fn.params): @@ -385,7 +386,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O env_updates = self.compile_match_pattern(funcenv, funcenv[param], case.pattern, body_block, fallthrough) self.block = body_block self.compile_body({**funcenv, **env_updates}, case.body) - self.restore_fn(prev_fn) + self.restore_fn(prev_fn, prev_block) bound = [env[name] for name in freevars] result = self.emit(NewClosure(fn, bound)) return result From 528698fc05cffa50f5a9b5f9420805e02f8067eb Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 14:11:49 -0500 Subject: [PATCH 23/88] . --- ir.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ir.py b/ir.py index 21dcf672..ae83ca9a 100644 --- a/ir.py +++ b/ir.py @@ -641,6 +641,35 @@ def test_fun_closure(self) -> None: }""", ) + def test_fun_const_closure(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("(a -> a + b) . b = 1")) + self.assertEqual(len(compiler.fns), 2) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = NewClosure v0 + Return v1 + } +}""", + ) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; a> + v2 = ClosureRef<0; b> v0 + v3 = IntAdd v1, v2 + Return v3 + } +}""", + ) + def test_match_no_cases(self) -> None: compiler = Compiler() compiler.compile_body({}, MatchFunction([])) From 4d25aebea2d8400dcad9dd6e32fe1517e5b2fdd5 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 17:30:20 -0500 Subject: [PATCH 24/88] Add silly dominator analysis --- ir.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/ir.py b/ir.py index ae83ca9a..80632ff6 100644 --- a/ir.py +++ b/ir.py @@ -433,6 +433,23 @@ def compile(self, env: Env, exp: Object) -> Instr: raise NotImplementedError(f"exp {type(exp)} {exp}") +def compute_doms(preds: dict[str, set[str]]) -> dict[str, set[str]]: + entry = [block for block, block_preds in preds.items() if not block_preds][0] + other_blocks = set(preds.keys()) - {entry} + result = {entry: {entry}} + for block in other_blocks: + result[block] = set(preds.keys()) + change = True + while change: + change = False + for block in other_blocks: + tmp = {block} | set.intersection(*(result[pred] for pred in preds[block])) + if tmp != result[block]: + result[block] = tmp + change = True + return result + + class IRTests(unittest.TestCase): def _parse(self, source: str) -> Object: return parse(tokenize(source)) @@ -985,6 +1002,38 @@ def test_apply_anonymous_function(self) -> None: ) +class DominatorTests(unittest.TestCase): + def test_dom(self) -> None: + entry = "entry" + blocks = ["entry", *(f"bb{n+1}" for n in range(7)), "exit"] + preds = { + blocks[0]: set(), + blocks[1]: {entry}, + blocks[2]: {blocks[1]}, + blocks[3]: {blocks[1]}, + blocks[4]: {blocks[2], blocks[3], blocks[7]}, + blocks[5]: {blocks[4]}, + blocks[6]: {blocks[4]}, + blocks[7]: {blocks[5], blocks[6]}, + blocks[-1]: {blocks[7]}, + } + doms = compute_doms(preds) + self.assertEqual( + doms, + { + "entry": {"entry"}, + "bb1": {"bb1", "entry"}, + "bb2": {"bb1", "entry", "bb2"}, + "bb3": {"bb3", "bb1", "entry"}, + "bb4": {"bb4", "bb1", "entry"}, + "bb5": {"bb4", "bb1", "bb5", "entry"}, + "bb6": {"bb4", "bb1", "bb6", "entry"}, + "bb7": {"bb4", "bb1", "entry", "bb7"}, + "exit": {"bb4", "bb1", "entry", "exit", "bb7"}, + }, + ) + + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() From c7f438b1af2dbe7e3a0443bb379e522923a1193e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 17:41:19 -0500 Subject: [PATCH 25/88] Add RPO traversal tests --- ir.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 80632ff6..7b8b53e2 100644 --- a/ir.py +++ b/ir.py @@ -202,12 +202,18 @@ class Block: id: int instrs: list[Instr] = dataclasses.field(init=False, default_factory=list) - def append(self, instr: Instr) -> None: + def append(self, instr: Instr) -> Instr: self.instrs.append(instr) + return instr def name(self) -> str: return f"bb{self.id}" + def terminator(self) -> Control: + result = self.instrs[-1] + assert isinstance(result, Control) + return result + @dataclasses.dataclass(eq=False) class Jump(Control): @@ -262,6 +268,29 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: result += " }\n" return result + def rpo(self) -> list[Block]: + result = [] + self.po_from(self.entry, result, set()) + result.reverse() + return result + + def po_from(self, block: Block, result: list[Block], visited: set[Block]): + visited.add(block) + terminator = block.terminator() + if isinstance(terminator, Jump): + if terminator.target not in visited: + self.po_from(terminator.target, result, visited) + elif isinstance(terminator, CondBranch): + if terminator.conseq not in visited: + self.po_from(terminator.conseq, result, visited) + if terminator.alt not in visited: + self.po_from(terminator.alt, result, visited) + elif isinstance(terminator, Return): + pass + else: + raise NotImplementedError(f"unexpected terminator {terminator}") + result.append(block) + @dataclasses.dataclass(eq=False) class IRFunction: @@ -1002,6 +1031,35 @@ def test_apply_anonymous_function(self) -> None: ) +class RPOTests(unittest.TestCase): + def test_one_block(self) -> None: + fn = IRFunction(0, []) + entry = fn.cfg.entry + one = entry.append(Const(1)) + entry.append(Return(one)) + self.assertEqual(fn.cfg.rpo(), [entry]) + + def test_jump(self) -> None: + fn = IRFunction(0, []) + entry = fn.cfg.entry + one = entry.append(Const(1)) + exit = fn.cfg.new_block() + entry.append(Jump(exit)) + exit.append(Return(one)) + self.assertEqual(fn.cfg.rpo(), [entry, exit]) + + def test_cond_branch(self) -> None: + fn = IRFunction(0, []) + entry = fn.cfg.entry + one = entry.append(Const(1)) + left = fn.cfg.new_block() + right = fn.cfg.new_block() + entry.append(CondBranch(one, left, right)) + left.append(Return(one)) + right.append(Return(one)) + self.assertEqual(fn.cfg.rpo(), [entry, right, left]) + + class DominatorTests(unittest.TestCase): def test_dom(self) -> None: entry = "entry" From c227a8503ff9c8eafb81c69a183135d7ae05302d Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 17:51:54 -0500 Subject: [PATCH 26/88] Compute preds --- ir.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/ir.py b/ir.py index 7b8b53e2..402aea76 100644 --- a/ir.py +++ b/ir.py @@ -176,7 +176,8 @@ class Call(HasOperands): @dataclasses.dataclass(eq=False) class Control(Instr): - pass + def succs(self) -> tuple[Block, ...]: + raise NotImplementedError("succs") @dataclasses.dataclass(eq=False) @@ -222,11 +223,17 @@ class Jump(Control): def to_string(self, gvn: InstrId) -> str: return super().to_string(gvn) + f" {self.target.name()}" + def succs(self) -> tuple[Block, ...]: + return (self.target,) + @dataclasses.dataclass(init=False, eq=False) class Return(HasOperands, Control): pass + def succs(self) -> tuple[Block, ...]: + return () + @dataclasses.dataclass(init=False, eq=False) class CondBranch(Control, HasOperands): @@ -241,6 +248,9 @@ def __init__(self, cond: Instr, conseq: Block, alt: Block) -> None: def to_string(self, gvn: InstrId) -> str: return super().to_string(gvn) + f", {self.conseq.name()}, {self.alt.name()}" + def succs(self) -> tuple[Block, ...]: + return (self.conseq, self.alt) + @dataclasses.dataclass class CFG: @@ -277,20 +287,19 @@ def rpo(self) -> list[Block]: def po_from(self, block: Block, result: list[Block], visited: set[Block]): visited.add(block) terminator = block.terminator() - if isinstance(terminator, Jump): - if terminator.target not in visited: - self.po_from(terminator.target, result, visited) - elif isinstance(terminator, CondBranch): - if terminator.conseq not in visited: - self.po_from(terminator.conseq, result, visited) - if terminator.alt not in visited: - self.po_from(terminator.alt, result, visited) - elif isinstance(terminator, Return): - pass - else: - raise NotImplementedError(f"unexpected terminator {terminator}") + for succ in terminator.succs(): + if succ not in visited: + self.po_from(succ, result, visited) result.append(block) + def preds(self) -> dict[Block, set[Block]]: + rpo = self.rpo() + result = {block: set() for block in rpo} + for block in rpo: + for succ in block.terminator().succs(): + result[succ].add(block) + return result + @dataclasses.dataclass(eq=False) class IRFunction: @@ -1060,6 +1069,49 @@ def test_cond_branch(self) -> None: self.assertEqual(fn.cfg.rpo(), [entry, right, left]) +class PredTests(unittest.TestCase): + def test_preds(self) -> None: + fn = IRFunction(0, []) + entry = fn.cfg.entry + one = entry.append(Const(1)) + bb1 = fn.cfg.new_block() + entry.append(Jump(bb1)) + two = bb1.append(Const(2)) + bb2 = fn.cfg.new_block() + bb3 = fn.cfg.new_block() + bb1.append(CondBranch(two, bb2, bb3)) + bb4 = fn.cfg.new_block() + bb2.append(Jump(bb4)) + bb3.append(Jump(bb4)) + three = bb4.append(Const(3)) + bb5 = fn.cfg.new_block() + bb6 = fn.cfg.new_block() + bb4.append(CondBranch(three, bb5, bb6)) + bb7 = fn.cfg.new_block() + bb5.append(Jump(bb7)) + bb6.append(Jump(bb7)) + four = bb7.append(Const(4)) + exit = fn.cfg.new_block() + bb7.append(CondBranch(four, exit, bb4)) + five = exit.append(Const(5)) + exit.append(Return(five)) + preds = fn.cfg.preds() + self.assertEqual( + preds, + { + entry: set(), + bb1: {entry}, + bb2: {bb1}, + bb3: {bb1}, + bb4: {bb2, bb3, bb7}, + bb5: {bb4}, + bb6: {bb4}, + bb7: {bb5, bb6}, + exit: {bb7}, + }, + ) + + class DominatorTests(unittest.TestCase): def test_dom(self) -> None: entry = "entry" From 2c156802bb5052e39c065debd6fb0847a1309b3b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 24 Jan 2025 17:54:39 -0500 Subject: [PATCH 27/88] Run dominator tests on IR --- ir.py | 91 +++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/ir.py b/ir.py index 402aea76..b38dd3e1 100644 --- a/ir.py +++ b/ir.py @@ -300,6 +300,23 @@ def preds(self) -> dict[Block, set[Block]]: result[succ].add(block) return result + def doms(self) -> dict[Block, set[Block]]: + preds = self.preds() + entry = [block for block, block_preds in preds.items() if not block_preds][0] + other_blocks = set(preds.keys()) - {entry} + result = {entry: {entry}} + for block in other_blocks: + result[block] = set(preds.keys()) + change = True + while change: + change = False + for block in other_blocks: + tmp = {block} | set.intersection(*(result[pred] for pred in preds[block])) + if tmp != result[block]: + result[block] = tmp + change = True + return result + @dataclasses.dataclass(eq=False) class IRFunction: @@ -471,23 +488,6 @@ def compile(self, env: Env, exp: Object) -> Instr: raise NotImplementedError(f"exp {type(exp)} {exp}") -def compute_doms(preds: dict[str, set[str]]) -> dict[str, set[str]]: - entry = [block for block, block_preds in preds.items() if not block_preds][0] - other_blocks = set(preds.keys()) - {entry} - result = {entry: {entry}} - for block in other_blocks: - result[block] = set(preds.keys()) - change = True - while change: - change = False - for block in other_blocks: - tmp = {block} | set.intersection(*(result[pred] for pred in preds[block])) - if tmp != result[block]: - result[block] = tmp - change = True - return result - - class IRTests(unittest.TestCase): def _parse(self, source: str) -> Object: return parse(tokenize(source)) @@ -1114,32 +1114,43 @@ def test_preds(self) -> None: class DominatorTests(unittest.TestCase): def test_dom(self) -> None: - entry = "entry" - blocks = ["entry", *(f"bb{n+1}" for n in range(7)), "exit"] - preds = { - blocks[0]: set(), - blocks[1]: {entry}, - blocks[2]: {blocks[1]}, - blocks[3]: {blocks[1]}, - blocks[4]: {blocks[2], blocks[3], blocks[7]}, - blocks[5]: {blocks[4]}, - blocks[6]: {blocks[4]}, - blocks[7]: {blocks[5], blocks[6]}, - blocks[-1]: {blocks[7]}, - } - doms = compute_doms(preds) + fn = IRFunction(0, []) + entry = fn.cfg.entry + one = entry.append(Const(1)) + bb1 = fn.cfg.new_block() + entry.append(Jump(bb1)) + two = bb1.append(Const(2)) + bb2 = fn.cfg.new_block() + bb3 = fn.cfg.new_block() + bb1.append(CondBranch(two, bb2, bb3)) + bb4 = fn.cfg.new_block() + bb2.append(Jump(bb4)) + bb3.append(Jump(bb4)) + three = bb4.append(Const(3)) + bb5 = fn.cfg.new_block() + bb6 = fn.cfg.new_block() + bb4.append(CondBranch(three, bb5, bb6)) + bb7 = fn.cfg.new_block() + bb5.append(Jump(bb7)) + bb6.append(Jump(bb7)) + four = bb7.append(Const(4)) + exit = fn.cfg.new_block() + bb7.append(CondBranch(four, exit, bb4)) + five = exit.append(Const(5)) + exit.append(Return(five)) + doms = fn.cfg.doms() self.assertEqual( doms, { - "entry": {"entry"}, - "bb1": {"bb1", "entry"}, - "bb2": {"bb1", "entry", "bb2"}, - "bb3": {"bb3", "bb1", "entry"}, - "bb4": {"bb4", "bb1", "entry"}, - "bb5": {"bb4", "bb1", "bb5", "entry"}, - "bb6": {"bb4", "bb1", "bb6", "entry"}, - "bb7": {"bb4", "bb1", "entry", "bb7"}, - "exit": {"bb4", "bb1", "entry", "exit", "bb7"}, + entry: {entry}, + bb1: {bb1, entry}, + bb2: {bb1, entry, bb2}, + bb3: {bb3, bb1, entry}, + bb4: {bb4, bb1, entry}, + bb5: {bb4, bb1, bb5, entry}, + bb6: {bb4, bb1, bb6, entry}, + bb7: {bb4, bb1, entry, bb7}, + exit: {bb4, bb1, entry, exit, bb7}, }, ) From 075dc2d6e8efc9ad6166cd4d904914201c3df155 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 12:01:43 -0500 Subject: [PATCH 28/88] Fix mypy --- ir.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/ir.py b/ir.py index b38dd3e1..cb19d527 100644 --- a/ir.py +++ b/ir.py @@ -279,12 +279,12 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: return result def rpo(self) -> list[Block]: - result = [] + result: list[Block] = [] self.po_from(self.entry, result, set()) result.reverse() return result - def po_from(self, block: Block, result: list[Block], visited: set[Block]): + def po_from(self, block: Block, result: list[Block], visited: set[Block]) -> None: visited.add(block) terminator = block.terminator() for succ in terminator.succs(): @@ -294,7 +294,7 @@ def po_from(self, block: Block, result: list[Block], visited: set[Block]): def preds(self) -> dict[Block, set[Block]]: rpo = self.rpo() - result = {block: set() for block in rpo} + result: dict[Block, set[Block]] = {block: set() for block in rpo} for block in rpo: for succ in block.terminator().succs(): result[succ].add(block) @@ -350,7 +350,7 @@ def gensym(self, stem: str = "tmp") -> str: self.gensym_counter += 1 return f"{stem}_{self.gensym_counter-1}" - def push_fn(self, fn: IRFunction) -> IRFunction: + def push_fn(self, fn: IRFunction) -> tuple[IRFunction, Block]: prev_fn = self.fn prev_block = self.block self.restore_fn(fn, fn.cfg.entry) @@ -414,7 +414,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O # Functions can refer to themselves; we close the loop below in the # funcenv freevars.remove(func_name) - freevars = sorted(freevars) + ordered_freevars = sorted(freevars) prev_fn, prev_block = self.push_fn(fn) # funcenv = {} @@ -423,7 +423,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O closure = funcenv[clo] if func_name is not None: funcenv[func_name] = closure - for idx, name in enumerate(freevars): + for idx, name in enumerate(ordered_freevars): funcenv[name] = self.emit(ClosureRef(closure, idx, name)) # if isinstance(exp, Function): @@ -442,7 +442,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O self.block = body_block self.compile_body({**funcenv, **env_updates}, case.body) self.restore_fn(prev_fn, prev_block) - bound = [env[name] for name in freevars] + bound = [env[name] for name in ordered_freevars] result = self.emit(NewClosure(fn, bound)) return result @@ -1044,14 +1044,14 @@ class RPOTests(unittest.TestCase): def test_one_block(self) -> None: fn = IRFunction(0, []) entry = fn.cfg.entry - one = entry.append(Const(1)) + one = entry.append(Const(Int(1))) entry.append(Return(one)) self.assertEqual(fn.cfg.rpo(), [entry]) def test_jump(self) -> None: fn = IRFunction(0, []) entry = fn.cfg.entry - one = entry.append(Const(1)) + one = entry.append(Const(Int(1))) exit = fn.cfg.new_block() entry.append(Jump(exit)) exit.append(Return(one)) @@ -1060,7 +1060,7 @@ def test_jump(self) -> None: def test_cond_branch(self) -> None: fn = IRFunction(0, []) entry = fn.cfg.entry - one = entry.append(Const(1)) + one = entry.append(Const(Int(1))) left = fn.cfg.new_block() right = fn.cfg.new_block() entry.append(CondBranch(one, left, right)) @@ -1073,27 +1073,27 @@ class PredTests(unittest.TestCase): def test_preds(self) -> None: fn = IRFunction(0, []) entry = fn.cfg.entry - one = entry.append(Const(1)) + one = entry.append(Const(Int(1))) bb1 = fn.cfg.new_block() entry.append(Jump(bb1)) - two = bb1.append(Const(2)) + two = bb1.append(Const(Int(2))) bb2 = fn.cfg.new_block() bb3 = fn.cfg.new_block() bb1.append(CondBranch(two, bb2, bb3)) bb4 = fn.cfg.new_block() bb2.append(Jump(bb4)) bb3.append(Jump(bb4)) - three = bb4.append(Const(3)) + three = bb4.append(Const(Int(3))) bb5 = fn.cfg.new_block() bb6 = fn.cfg.new_block() bb4.append(CondBranch(three, bb5, bb6)) bb7 = fn.cfg.new_block() bb5.append(Jump(bb7)) bb6.append(Jump(bb7)) - four = bb7.append(Const(4)) + four = bb7.append(Const(Int(4))) exit = fn.cfg.new_block() bb7.append(CondBranch(four, exit, bb4)) - five = exit.append(Const(5)) + five = exit.append(Const(Int(5))) exit.append(Return(five)) preds = fn.cfg.preds() self.assertEqual( @@ -1116,27 +1116,27 @@ class DominatorTests(unittest.TestCase): def test_dom(self) -> None: fn = IRFunction(0, []) entry = fn.cfg.entry - one = entry.append(Const(1)) + one = entry.append(Const(Int(1))) bb1 = fn.cfg.new_block() entry.append(Jump(bb1)) - two = bb1.append(Const(2)) + two = bb1.append(Const(Int(2))) bb2 = fn.cfg.new_block() bb3 = fn.cfg.new_block() bb1.append(CondBranch(two, bb2, bb3)) bb4 = fn.cfg.new_block() bb2.append(Jump(bb4)) bb3.append(Jump(bb4)) - three = bb4.append(Const(3)) + three = bb4.append(Const(Int(3))) bb5 = fn.cfg.new_block() bb6 = fn.cfg.new_block() bb4.append(CondBranch(three, bb5, bb6)) bb7 = fn.cfg.new_block() bb5.append(Jump(bb7)) bb6.append(Jump(bb7)) - four = bb7.append(Const(4)) + four = bb7.append(Const(Int(4))) exit = fn.cfg.new_block() bb7.append(CondBranch(four, exit, bb4)) - five = exit.append(Const(5)) + five = exit.append(Const(Int(5))) exit.append(Return(five)) doms = fn.cfg.doms() self.assertEqual( From f7fee26029a385cff4cca530c00c9bd32e7d11f9 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 12:02:22 -0500 Subject: [PATCH 29/88] Start SCCP --- ir.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/ir.py b/ir.py index cb19d527..1556a3b1 100644 --- a/ir.py +++ b/ir.py @@ -488,6 +488,91 @@ def compile(self, env: Env, exp: Object) -> Instr: raise NotImplementedError(f"exp {type(exp)} {exp}") +@dataclasses.dataclass +class ConstantLattice: + pass + + +@dataclasses.dataclass +class CBottom(ConstantLattice): + pass + + +@dataclasses.dataclass +class CTop(ConstantLattice): + pass + + +@dataclasses.dataclass +class CInt(ConstantLattice): + value: Optional[int] = None + + def has_value(self) -> bool: + return self.value is not None + + +def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: + if isinstance(self, CBottom): + return other + if isinstance(self, CTop): + return self + if isinstance(self, CInt) and isinstance(other, CInt): + return self if self.value == other.value else CInt() + return CBottom() + + +@dataclasses.dataclass +class SCCP: + fn: IRFunction + instr_type: dict[Instr, ConstantLattice] = dataclasses.field(init=False, default_factory=dict) + block_executable: set[Block] = dataclasses.field(init=False, default_factory=set) + instr_uses: dict[Instr, set[Instr]] = dataclasses.field(init=False, default_factory=dict) + + def type_of(self, instr: Instr) -> ConstantLattice: + result = self.instr_type.get(instr) + if result is not None: + return result + result = self.instr_type[instr] = CBottom() + return result + + def run(self) -> dict[Instr, ConstantLattice]: + block_worklist: list[Block] = [self.fn.cfg.entry] + instr_worklist: list[Instr] = [] + + while block_worklist or instr_worklist: + if instr_worklist and (instr := instr_worklist.pop(0)): + if isinstance(instr, HasOperands): + for operand in instr.operands: + if operand not in self.instr_uses: + self.instr_uses[operand] = set() + self.instr_uses[operand].add(instr) + new_type: ConstantLattice = CBottom() + if isinstance(instr, Const): + if isinstance(instr.value, Int): + new_type = CInt(instr.value.value) + elif isinstance(instr, Return): + pass + elif isinstance(instr, IntAdd): + match (self.type_of(instr.operands[0]), self.type_of(instr.operands[1])): + case (CInt(int(l)), CInt(int(r))): + new_type = CInt(l + r) + case (CInt(_), CInt(_)): + new_type = CInt() + else: + raise NotImplementedError(f"SCCP {instr}") + old_type = self.type_of(instr) + if union(old_type, new_type) != old_type: + self.instr_type[instr] = new_type + for use in self.instr_uses.get(instr, set()): + instr_worklist.append(use) + if block_worklist and (block := block_worklist.pop(0)): + if block not in self.block_executable: + self.block_executable.add(block) + instr_worklist.extend(block.instrs) + + return self.instr_type + + class IRTests(unittest.TestCase): def _parse(self, source: str) -> Object: return parse(tokenize(source)) @@ -1155,6 +1240,35 @@ def test_dom(self) -> None: ) +class SCCPTests(unittest.TestCase): + def _parse(self, source: str) -> Object: + return parse(tokenize(source)) + + def test_int(self) -> None: + compiler = Compiler() + compiler.compile_body({}, Int(1)) + analysis = SCCP(compiler.fn) + result = analysis.run() + entry = compiler.fn.cfg.entry + self.assertEqual(result, {entry.instrs[0]: CInt(1), entry.instrs[1]: CBottom()}) + + def test_int_add(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("1 + 2")) + analysis = SCCP(compiler.fn) + result = analysis.run() + entry = compiler.fn.cfg.entry + self.assertEqual( + result, + { + entry.instrs[0]: CInt(1), + entry.instrs[1]: CInt(2), + entry.instrs[2]: CInt(3), + entry.instrs[3]: CBottom(), + }, + ) + + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() From bec1be1e5d389b142d2f63bba35b99dffca4f8b3 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 12:18:22 -0500 Subject: [PATCH 30/88] SCCP list --- ir.py | 43 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/ir.py b/ir.py index 1556a3b1..78fa6da0 100644 --- a/ir.py +++ b/ir.py @@ -503,6 +503,11 @@ class CTop(ConstantLattice): pass +@dataclasses.dataclass +class CList(ConstantLattice): + pass + + @dataclasses.dataclass class CInt(ConstantLattice): value: Optional[int] = None @@ -514,6 +519,8 @@ def has_value(self) -> bool: def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: if isinstance(self, CBottom): return other + if isinstance(other, Bottom): + return self if isinstance(self, CTop): return self if isinstance(self, CInt) and isinstance(other, CInt): @@ -548,8 +555,11 @@ def run(self) -> dict[Instr, ConstantLattice]: self.instr_uses[operand].add(instr) new_type: ConstantLattice = CBottom() if isinstance(instr, Const): - if isinstance(instr.value, Int): - new_type = CInt(instr.value.value) + value = instr.value + if isinstance(value, Int): + new_type = CInt(value.value) + if isinstance(value, List): + new_type = CList() elif isinstance(instr, Return): pass elif isinstance(instr, IntAdd): @@ -558,6 +568,9 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CInt(l + r) case (CInt(_), CInt(_)): new_type = CInt() + elif isinstance(instr, ListCons): + if isinstance(self.type_of(instr.operands[1]), CList): + new_type = CList() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) @@ -1254,7 +1267,7 @@ def test_int(self) -> None: def test_int_add(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("1 + 2")) + compiler.compile_body({}, self._parse("1 + 2 + 3")) analysis = SCCP(compiler.fn) result = analysis.run() entry = compiler.fn.cfg.entry @@ -1264,10 +1277,32 @@ def test_int_add(self) -> None: entry.instrs[0]: CInt(1), entry.instrs[1]: CInt(2), entry.instrs[2]: CInt(3), - entry.instrs[3]: CBottom(), + entry.instrs[3]: CInt(5), + entry.instrs[4]: CInt(6), + entry.instrs[5]: CBottom(), }, ) + def test_empty_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("[]")) + analysis = SCCP(compiler.fn) + analysis.run() + return_instr = compiler.fn.cfg.entry.instrs[-1] + self.assertIsInstance(return_instr, Return) + returned = return_instr.operands[0] + self.assertEqual(analysis.instr_type[returned], CList()) + + def test_const_list(self) -> None: + compiler = Compiler() + compiler.compile_body({}, self._parse("[1, 2]")) + analysis = SCCP(compiler.fn) + analysis.run() + return_instr = compiler.fn.cfg.entry.instrs[-1] + self.assertIsInstance(return_instr, Return) + returned = return_instr.operands[0] + self.assertEqual(analysis.instr_type[returned], CList()) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 0b66b8bd052ce3bb37c2198de0ccaed839e94f01 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 12:19:04 -0500 Subject: [PATCH 31/88] Fix mypy --- ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 78fa6da0..342b3f85 100644 --- a/ir.py +++ b/ir.py @@ -519,7 +519,7 @@ def has_value(self) -> bool: def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: if isinstance(self, CBottom): return other - if isinstance(other, Bottom): + if isinstance(other, CBottom): return self if isinstance(self, CTop): return self @@ -1290,6 +1290,7 @@ def test_empty_list(self) -> None: analysis.run() return_instr = compiler.fn.cfg.entry.instrs[-1] self.assertIsInstance(return_instr, Return) + assert isinstance(return_instr, Return) returned = return_instr.operands[0] self.assertEqual(analysis.instr_type[returned], CList()) @@ -1300,6 +1301,7 @@ def test_const_list(self) -> None: analysis.run() return_instr = compiler.fn.cfg.entry.instrs[-1] self.assertIsInstance(return_instr, Return) + assert isinstance(return_instr, Return) returned = return_instr.operands[0] self.assertEqual(analysis.instr_type[returned], CList()) From 528c01a838341747b5e1ae7d488c6c75c75501b9 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:20:15 -0500 Subject: [PATCH 32/88] Start thinking about RefineType --- ir.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ir.py b/ir.py index 342b3f85..4e18851d 100644 --- a/ir.py +++ b/ir.py @@ -118,6 +118,16 @@ class IntLess(HasOperands): pass +@dataclasses.dataclass(init=False, eq=False) +class RefineType(HasOperands): + def __init__(self, value: Instr, ty: ConstantLattice) -> None: + self.operands = [value] + self.ty = ty + + def to_string(self, gvn: InstrId) -> str: + return f"{type(self).__name__}<{self.ty.__class__.__name__}> " + ", ".join(f"v{gvn[op]}" for op in self.operands) + + @dataclasses.dataclass(init=False, eq=False) class IsNumEqualWord(HasOperands): expected: int @@ -378,6 +388,7 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success self.emit(CondBranch(is_list, is_list_block, fallthrough)) self.block = is_list_block updates = {} + # the_list = self.emit(RefineType(param, CList())) the_list = param for i, pattern_item in enumerate(pattern.items): assert not isinstance(pattern_item, Spread) From d9d0e72e94600ea12181f524fbf43c6ad82920c7 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:20:22 -0500 Subject: [PATCH 33/88] Add union-find infra --- ir.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ir.py b/ir.py index 4e18851d..401c65c4 100644 --- a/ir.py +++ b/ir.py @@ -52,6 +52,19 @@ def __getitem__(self, instr: Instr) -> int: @dataclasses.dataclass(eq=False) class Instr: + forwarded: Optional[Instr] = dataclasses.field(init=False, default=None) + + def find(self) -> Instr: + result = self + while True: + it = result.forwarded + if it is None: + return result + result = it + + def make_equal_to(self, other: Instr) -> None: + self.find().forwarded = other + def __hash__(self) -> int: return id(self) From cf22dc16f38f5e6d0f8bee08be4ac6737360f1ec Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:31:13 -0500 Subject: [PATCH 34/88] Make MatchFail a terminator --- ir.py | 141 +++++++++++++++++++++++++++++----------------------------- 1 file changed, 71 insertions(+), 70 deletions(-) diff --git a/ir.py b/ir.py index 401c65c4..cdcc7f59 100644 --- a/ir.py +++ b/ir.py @@ -92,11 +92,6 @@ def to_string(self, gvn: InstrId) -> str: return f"{type(self).__name__}<{self.idx}; {self.name}>" -@dataclasses.dataclass(eq=False) -class MatchFail(Instr): - pass - - @dataclasses.dataclass(eq=False) class HasOperands(Instr): operands: list[Instr] = dataclasses.field(init=False, default_factory=list) @@ -203,6 +198,12 @@ def succs(self) -> tuple[Block, ...]: raise NotImplementedError("succs") +@dataclasses.dataclass(eq=False) +class MatchFail(Control): + def succs(self) -> tuple[Block, ...]: + return () + + @dataclasses.dataclass(eq=False) class NewClosure(HasOperands): fn: IRFunction @@ -870,7 +871,7 @@ def test_match_no_cases(self) -> None: Jump bb1 } bb1 { - v2 = MatchFail + MatchFail } }""", ) @@ -888,17 +889,17 @@ def test_match_one_case(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsNumEqualWord v1, 1 - CondBranch v3, bb3, bb1 + v2 = IsNumEqualWord v1, 1 + CondBranch v2, bb3, bb1 } bb3 { - v4 = Const<2> - v5 = Const<3> - v6 = IntAdd v4, v5 - Return v6 + v3 = Const<2> + v4 = Const<3> + v5 = IntAdd v3, v4 + Return v5 } }""", ) @@ -916,23 +917,23 @@ def test_match_two_cases(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsNumEqualWord v1, 1 - CondBranch v3, bb4, bb3 + v2 = IsNumEqualWord v1, 1 + CondBranch v2, bb4, bb3 } bb3 { - v4 = IsNumEqualWord v1, 3 - CondBranch v4, bb5, bb1 + v3 = IsNumEqualWord v1, 3 + CondBranch v3, bb5, bb1 } bb4 { - v5 = Const<2> - Return v5 + v4 = Const<2> + Return v4 } bb5 { - v6 = Const<4> - Return v6 + v5 = Const<4> + Return v5 } }""", ) @@ -950,15 +951,15 @@ def test_match_var(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { Jump bb3 } bb3 { - v3 = Const<1> - v4 = IntAdd v1, v3 - Return v4 + v2 = Const<1> + v3 = IntAdd v1, v2 + Return v3 } }""", ) @@ -976,19 +977,19 @@ def test_match_empty_list(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsList v1 - CondBranch v3, bb4, bb1 + v2 = IsList v1 + CondBranch v2, bb4, bb1 } bb3 { - v4 = Const<1> - Return v4 + v3 = Const<1> + Return v3 } bb4 { - v5 = IsEmptyList v1 - CondBranch v5, bb3, bb1 + v4 = IsEmptyList v1 + CondBranch v4, bb3, bb1 } }""", ) @@ -1006,29 +1007,29 @@ def test_match_one_item_list(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsList v1 - CondBranch v3, bb4, bb1 + v2 = IsList v1 + CondBranch v2, bb4, bb1 } bb3 { - v4 = Const<1> - v5 = IntAdd v6, v4 - Return v5 + v3 = Const<1> + v4 = IntAdd v5, v3 + Return v4 } bb4 { - v7 = IsEmptyList v1 - CondBranch v7, bb1, bb5 + v6 = IsEmptyList v1 + CondBranch v6, bb1, bb5 } bb5 { - v6 = ListFirst v1 + v5 = ListFirst v1 Jump bb6 } bb6 { - v8 = ListRest v1 - v9 = IsEmptyList v8 - CondBranch v9, bb3, bb1 + v7 = ListRest v1 + v8 = IsEmptyList v7 + CondBranch v8, bb3, bb1 } }""", ) @@ -1046,37 +1047,37 @@ def test_match_two_item_list(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsList v1 - CondBranch v3, bb4, bb1 + v2 = IsList v1 + CondBranch v2, bb4, bb1 } bb3 { - v4 = IntAdd v5, v6 - Return v4 + v3 = IntAdd v4, v5 + Return v3 } bb4 { - v7 = IsEmptyList v1 - CondBranch v7, bb1, bb5 + v6 = IsEmptyList v1 + CondBranch v6, bb1, bb5 } bb5 { - v5 = ListFirst v1 + v4 = ListFirst v1 Jump bb6 } bb6 { - v8 = ListRest v1 - v9 = IsEmptyList v8 - CondBranch v9, bb1, bb7 + v7 = ListRest v1 + v8 = IsEmptyList v7 + CondBranch v8, bb1, bb7 } bb7 { - v6 = ListFirst v8 + v5 = ListFirst v7 Jump bb8 } bb8 { - v10 = ListRest v8 - v11 = IsEmptyList v10 - CondBranch v11, bb3, bb1 + v9 = ListRest v7 + v10 = IsEmptyList v9 + CondBranch v10, bb3, bb1 } }""", ) @@ -1122,25 +1123,25 @@ def test_recursive_call(self) -> None: Jump bb2 } bb1 { - v2 = MatchFail + MatchFail } bb2 { - v3 = IsNumEqualWord v1, 0 - CondBranch v3, bb4, bb3 + v2 = IsNumEqualWord v1, 0 + CondBranch v2, bb4, bb3 } bb3 { Jump bb5 } bb4 { - v4 = Const<1> - Return v4 + v3 = Const<1> + Return v3 } bb5 { - v5 = Const<1> - v6 = IntSub v1, v5 - v7 = Call v0, v6 - v8 = IntMul v1, v7 - Return v8 + v4 = Const<1> + v5 = IntSub v1, v4 + v6 = Call v0, v5 + v7 = IntMul v1, v6 + Return v7 } }""", ) From 1fd47fd64312c2d0cebb3480f0d2b17b9a1d08da Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:50:38 -0500 Subject: [PATCH 35/88] Print blocks in RPO --- ir.py | 125 ++++++++++++++++++++++++++++------------------------------ 1 file changed, 61 insertions(+), 64 deletions(-) diff --git a/ir.py b/ir.py index cdcc7f59..4d9e949d 100644 --- a/ir.py +++ b/ir.py @@ -280,19 +280,22 @@ def succs(self) -> tuple[Block, ...]: class CFG: blocks: list[Block] = dataclasses.field(init=False, default_factory=list) entry: Block = dataclasses.field(init=False) + next_block_id: int = 0 def __init__(self) -> None: self.blocks = [] + self.next_block_id = 0 self.entry = self.new_block() def new_block(self) -> Block: - result = Block(len(self.blocks)) + result = Block(self.next_block_id) + self.next_block_id += 1 self.blocks.append(result) return result def to_string(self, fn: IRFunction, gvn: InstrId) -> str: result = "" - for block in self.blocks: + for block in self.rpo(): result += f" {block.name()} {{\n" for instr in block.instrs: if isinstance(instr, Control): @@ -888,13 +891,13 @@ def test_match_one_case(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsNumEqualWord v1, 1 CondBranch v2, bb3, bb1 } + bb1 { + MatchFail + } bb3 { v3 = Const<2> v4 = Const<3> @@ -916,9 +919,6 @@ def test_match_two_cases(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsNumEqualWord v1, 1 CondBranch v2, bb4, bb3 @@ -927,12 +927,15 @@ def test_match_two_cases(self) -> None: v3 = IsNumEqualWord v1, 3 CondBranch v3, bb5, bb1 } - bb4 { - v4 = Const<2> - Return v4 + bb1 { + MatchFail } bb5 { - v5 = Const<4> + v4 = Const<4> + Return v4 + } + bb4 { + v5 = Const<2> Return v5 } }""", @@ -950,9 +953,6 @@ def test_match_var(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { Jump bb3 } @@ -976,20 +976,20 @@ def test_match_empty_list(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsList v1 CondBranch v2, bb4, bb1 } - bb3 { - v3 = Const<1> - Return v3 - } bb4 { - v4 = IsEmptyList v1 - CondBranch v4, bb3, bb1 + v3 = IsEmptyList v1 + CondBranch v3, bb3, bb1 + } + bb1 { + MatchFail + } + bb3 { + v4 = Const<1> + Return v4 } }""", ) @@ -1006,30 +1006,30 @@ def test_match_one_item_list(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsList v1 CondBranch v2, bb4, bb1 } - bb3 { - v3 = Const<1> - v4 = IntAdd v5, v3 - Return v4 - } bb4 { - v6 = IsEmptyList v1 - CondBranch v6, bb1, bb5 + v3 = IsEmptyList v1 + CondBranch v3, bb1, bb5 } bb5 { - v5 = ListFirst v1 + v4 = ListFirst v1 Jump bb6 } bb6 { - v7 = ListRest v1 - v8 = IsEmptyList v7 - CondBranch v8, bb3, bb1 + v5 = ListRest v1 + v6 = IsEmptyList v5 + CondBranch v6, bb3, bb1 + } + bb3 { + v7 = Const<1> + v8 = IntAdd v4, v7 + Return v8 + } + bb1 { + MatchFail } }""", ) @@ -1046,38 +1046,38 @@ def test_match_two_item_list(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsList v1 CondBranch v2, bb4, bb1 } - bb3 { - v3 = IntAdd v4, v5 - Return v3 - } bb4 { - v6 = IsEmptyList v1 - CondBranch v6, bb1, bb5 + v3 = IsEmptyList v1 + CondBranch v3, bb1, bb5 } bb5 { v4 = ListFirst v1 Jump bb6 } bb6 { - v7 = ListRest v1 - v8 = IsEmptyList v7 - CondBranch v8, bb1, bb7 + v5 = ListRest v1 + v6 = IsEmptyList v5 + CondBranch v6, bb1, bb7 } bb7 { - v5 = ListFirst v7 + v7 = ListFirst v5 Jump bb8 } bb8 { - v9 = ListRest v7 - v10 = IsEmptyList v9 - CondBranch v10, bb3, bb1 + v8 = ListRest v5 + v9 = IsEmptyList v8 + CondBranch v9, bb3, bb1 + } + bb3 { + v10 = IntAdd v4, v7 + Return v10 + } + bb1 { + MatchFail } }""", ) @@ -1122,9 +1122,6 @@ def test_recursive_call(self) -> None: v1 = Param<1; arg_0> Jump bb2 } - bb1 { - MatchFail - } bb2 { v2 = IsNumEqualWord v1, 0 CondBranch v2, bb4, bb3 @@ -1132,15 +1129,15 @@ def test_recursive_call(self) -> None: bb3 { Jump bb5 } - bb4 { + bb5 { v3 = Const<1> - Return v3 + v4 = IntSub v1, v3 + v5 = Call v0, v4 + v6 = IntMul v1, v5 + Return v6 } - bb5 { - v4 = Const<1> - v5 = IntSub v1, v4 - v6 = Call v0, v5 - v7 = IntMul v1, v6 + bb4 { + v7 = Const<1> Return v7 } }""", From d0a1b1494f81ad803c65e049260991f082b78a92 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:52:22 -0500 Subject: [PATCH 36/88] Add CleanCFG --- ir.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/ir.py b/ir.py index 4d9e949d..20f6ab8b 100644 --- a/ir.py +++ b/ir.py @@ -614,6 +614,43 @@ def run(self) -> dict[Instr, ConstantLattice]: return self.instr_type +@dataclasses.dataclass +class CleanCFG: + fn: IRFunction + + def run(self) -> None: + changed = True + while changed: + changed = False + for block in self.fn.cfg.rpo(): + if not block.instrs: + # Ignore transient empty blocks. + continue + # Keep working on the current block until no further changes are made. + while self.absorb_dst_block(block): + pass + changed = self.remove_unreachable_blocks() + + def absorb_dst_block(self, block: Block) -> bool: + terminator = block.terminator() + if not isinstance(terminator, Jump): + return False + target = terminator.target + if target == block: + return False + preds = self.fn.cfg.preds() + if len(preds[target]) > 1: + return False + block.instrs.pop(-1) + block.instrs.extend(target.instrs) + target.instrs.clear() + # No Phi to fix up + return True + + def remove_unreachable_blocks(self) -> bool: + self.fn.cfg.blocks = self.fn.cfg.rpo() + + class IRTests(unittest.TestCase): def _parse(self, source: str) -> Object: return parse(tokenize(source)) @@ -876,6 +913,18 @@ def test_match_no_cases(self) -> None: bb1 { MatchFail } +}""", + ) + CleanCFG(compiler.fns[1]).run() + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + MatchFail + } }""", ) @@ -1140,6 +1189,30 @@ def test_recursive_call(self) -> None: v7 = Const<1> Return v7 } +}""", + ) + CleanCFG(compiler.fns[1]).run() + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + v2 = IsNumEqualWord v1, 0 + CondBranch v2, bb4, bb3 + } + bb3 { + v3 = Const<1> + v4 = IntSub v1, v3 + v5 = Call v0, v4 + v6 = IntMul v1, v5 + Return v6 + } + bb4 { + v7 = Const<1> + Return v7 + } }""", ) From a77b71e13fdc009037ab65d8e7e4da4fa152cbf5 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:56:48 -0500 Subject: [PATCH 37/88] Return true if remove_unreachable_blocks changed --- ir.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ir.py b/ir.py index 20f6ab8b..c516c090 100644 --- a/ir.py +++ b/ir.py @@ -648,7 +648,9 @@ def absorb_dst_block(self, block: Block) -> bool: return True def remove_unreachable_blocks(self) -> bool: + num_blocks = len(self.fn.cfg.blocks) self.fn.cfg.blocks = self.fn.cfg.rpo() + return len(self.fn.cfg.blocks) != num_blocks class IRTests(unittest.TestCase): From a5517dab0500910d56e81cdbac76930e30134612 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 14:58:31 -0500 Subject: [PATCH 38/88] . --- ir.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ir.py b/ir.py index c516c090..1d57cbdb 100644 --- a/ir.py +++ b/ir.py @@ -1130,6 +1130,42 @@ def test_match_two_item_list(self) -> None: bb1 { MatchFail } +}""", + ) + CleanCFG(compiler.fns[1]).run() + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + v2 = IsList v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = IsEmptyList v1 + CondBranch v3, bb1, bb5 + } + bb5 { + v4 = ListFirst v1 + v5 = ListRest v1 + v6 = IsEmptyList v5 + CondBranch v6, bb1, bb7 + } + bb7 { + v7 = ListFirst v5 + v8 = ListRest v5 + v9 = IsEmptyList v8 + CondBranch v9, bb3, bb1 + } + bb3 { + v10 = IntAdd v4, v7 + Return v10 + } + bb1 { + MatchFail + } }""", ) From 2a21cb37154e48ade8e5fcd0da1207636eff7960 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 15:09:46 -0500 Subject: [PATCH 39/88] . --- ir.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ir.py b/ir.py index 1d57cbdb..c07facb3 100644 --- a/ir.py +++ b/ir.py @@ -137,7 +137,7 @@ def to_string(self, gvn: InstrId) -> str: @dataclasses.dataclass(init=False, eq=False) -class IsNumEqualWord(HasOperands): +class IsIntEqualWord(HasOperands): expected: int def __init__(self, value: Instr, expected: int) -> None: @@ -393,7 +393,7 @@ def emit(self, instr: Instr) -> Instr: def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success: Block, fallthrough: Block) -> Env: if isinstance(pattern, Int): - cond = self.emit(IsNumEqualWord(param, pattern.value)) + cond = self.emit(IsIntEqualWord(param, pattern.value)) self.emit(CondBranch(cond, success, fallthrough)) return {} if isinstance(pattern, Var): @@ -943,7 +943,7 @@ def test_match_one_case(self) -> None: Jump bb2 } bb2 { - v2 = IsNumEqualWord v1, 1 + v2 = IsIntEqualWord v1, 1 CondBranch v2, bb3, bb1 } bb1 { @@ -971,11 +971,11 @@ def test_match_two_cases(self) -> None: Jump bb2 } bb2 { - v2 = IsNumEqualWord v1, 1 + v2 = IsIntEqualWord v1, 1 CondBranch v2, bb4, bb3 } bb3 { - v3 = IsNumEqualWord v1, 3 + v3 = IsIntEqualWord v1, 3 CondBranch v3, bb5, bb1 } bb1 { @@ -1210,7 +1210,7 @@ def test_recursive_call(self) -> None: Jump bb2 } bb2 { - v2 = IsNumEqualWord v1, 0 + v2 = IsIntEqualWord v1, 0 CondBranch v2, bb4, bb3 } bb3 { @@ -1237,7 +1237,7 @@ def test_recursive_call(self) -> None: bb0 { v0 = Param<0; $clo> v1 = Param<1; arg_0> - v2 = IsNumEqualWord v1, 0 + v2 = IsIntEqualWord v1, 0 CondBranch v2, bb4, bb3 } bb3 { From 4b56aaec0f6340c09c6842d7920dc47ea6b3ea8a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 15:17:49 -0500 Subject: [PATCH 40/88] Support NewClosure and Call in SCCP --- ir.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ir.py b/ir.py index c07facb3..f7bd1a26 100644 --- a/ir.py +++ b/ir.py @@ -544,6 +544,11 @@ def has_value(self) -> bool: return self.value is not None +@dataclasses.dataclass +class CClo(ConstantLattice): + value: Optional[IRFunction] = None + + def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: if isinstance(self, CBottom): return other @@ -599,6 +604,10 @@ def run(self) -> dict[Instr, ConstantLattice]: elif isinstance(instr, ListCons): if isinstance(self.type_of(instr.operands[1]), CList): new_type = CList() + elif isinstance(instr, NewClosure): + new_type = CClo(instr.fn) + elif isinstance(instr, Call): + new_type = CTop() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) @@ -1269,6 +1278,10 @@ def test_apply_anonymous_function(self) -> None: } }""", ) + analysis = SCCP(compiler.fns[0]) + analysis.run() + entry = compiler.fns[0].cfg.entry + self.assertEqual(analysis.instr_type[entry.instrs[0]], CClo(compiler.fns[1])) class RPOTests(unittest.TestCase): From 5408f74c95db6ea192ae0d0f8608e5481a46b076 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 15:17:53 -0500 Subject: [PATCH 41/88] ruff format --- ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index f7bd1a26..76a01bc3 100644 --- a/ir.py +++ b/ir.py @@ -133,7 +133,9 @@ def __init__(self, value: Instr, ty: ConstantLattice) -> None: self.ty = ty def to_string(self, gvn: InstrId) -> str: - return f"{type(self).__name__}<{self.ty.__class__.__name__}> " + ", ".join(f"v{gvn[op]}" for op in self.operands) + return f"{type(self).__name__}<{self.ty.__class__.__name__}> " + ", ".join( + f"v{gvn[op]}" for op in self.operands + ) @dataclasses.dataclass(init=False, eq=False) From 7cf06421dcb568011eccabb07b495f3d2565efc9 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 10:54:42 -0500 Subject: [PATCH 42/88] . --- ir.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ir.py b/ir.py index 76a01bc3..de3c83c6 100644 --- a/ir.py +++ b/ir.py @@ -49,6 +49,9 @@ def __getitem__(self, instr: Instr) -> int: self.data[instr] = id return id + def name(self, instr: Instr) -> str: + return f"v{self[instr]}" + @dataclasses.dataclass(eq=False) class Instr: @@ -361,6 +364,36 @@ def to_string(self, gvn: InstrId) -> str: result += self.cfg.to_string(self, gvn) return result + "}" + def to_c(self) -> str: + gvn = InstrId() + return self._to_c(self.cfg.entry, gvn, self.cfg.doms()) + + def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: + if isinstance(instr, Const): + if isinstance(instr.value, Int): + return f"new_int({instr.value.value})" + if isinstance(instr, IntAdd): + operands = ", ".join(gvn.name(op) for op in instr.operands) + return f"int_add({operands})" + if isinstance(instr, Param): + return f"param{instr.idx}" + if isinstance(instr, NewClosure): + operands = ", ".join([f"fn{instr.fn.id}", *(gvn.name(op) for op in instr.operands)]) + return f"new_closure({operands})" + raise NotImplementedError(type(instr)) + + def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: + result = f"Object *fn{self.id}() {{\n" + for instr in block.instrs: + if isinstance(instr, Control): break + rhs = self._instr_to_c(instr, gvn, doms) + result += f"Object *{gvn.name(instr)} = {rhs};\n" + assert isinstance(instr, Control) + if isinstance(instr, Return): + result += f"return {gvn.name(instr.operands[0])};\n" + result += "}" + return result + class Compiler: def __init__(self) -> None: @@ -517,6 +550,9 @@ def compile(self, env: Env, exp: Object) -> Instr: return self.compile_function(env, exp, func_name=None) raise NotImplementedError(f"exp {type(exp)} {exp}") + def to_c(self) -> str: + return "\n".join(fn.to_c() for fn in self.fns) + @dataclasses.dataclass class ConstantLattice: From 518dde2f23c91585d8e2b6e605c3af70bff81d69 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 17:40:05 -0500 Subject: [PATCH 43/88] Use InstrId.name --- ir.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ir.py b/ir.py index de3c83c6..481a6b7c 100644 --- a/ir.py +++ b/ir.py @@ -106,7 +106,7 @@ def to_string(self, gvn: InstrId) -> str: stem = f"{type(self).__name__}" if not self.operands: return stem - return stem + " " + ", ".join(f"v{gvn[op]}" for op in self.operands) + return stem + " " + ", ".join(f"{gvn.name(op)}" for op in self.operands) @dataclasses.dataclass(init=False, eq=False) @@ -137,7 +137,7 @@ def __init__(self, value: Instr, ty: ConstantLattice) -> None: def to_string(self, gvn: InstrId) -> str: return f"{type(self).__name__}<{self.ty.__class__.__name__}> " + ", ".join( - f"v{gvn[op]}" for op in self.operands + f"{gvn.name(op)}" for op in self.operands ) @@ -164,7 +164,7 @@ def __init__(self, closure: Instr, idx: int, name: str) -> None: self.name = name def to_string(self, gvn: InstrId) -> str: - return f"{type(self).__name__}<{self.idx}; {self.name}> v{gvn[self.operands[0]]}" + return f"{type(self).__name__}<{self.idx}; {self.name}> {gvn.name(self.operands[0])}" @dataclasses.dataclass(init=False, eq=False) @@ -221,7 +221,7 @@ def to_string(self, gvn: InstrId) -> str: stem = f"{type(self).__name__}<{self.fn.name()}>" if not self.operands: return stem - return f"{stem} " + ", ".join(f"v{gvn[op]}" for op in self.operands) + return f"{stem} " + ", ".join(f"{gvn.name(op)}" for op in self.operands) Env = Dict[str, Instr] @@ -306,7 +306,7 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: if isinstance(instr, Control): result += f" {instr.to_string(gvn)}\n" else: - result += f" v{gvn[instr]} = {instr.to_string(gvn)}\n" + result += f" {gvn.name(instr)} = {instr.to_string(gvn)}\n" result += " }\n" return result From b0e90da3f3a828e69598be3b4916b4a0b58a698b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 17:40:22 -0500 Subject: [PATCH 44/88] ruff format --- ir.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 481a6b7c..23beb268 100644 --- a/ir.py +++ b/ir.py @@ -385,7 +385,8 @@ def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: result = f"Object *fn{self.id}() {{\n" for instr in block.instrs: - if isinstance(instr, Control): break + if isinstance(instr, Control): + break rhs = self._instr_to_c(instr, gvn, doms) result += f"Object *{gvn.name(instr)} = {rhs};\n" assert isinstance(instr, Control) From da3fb540afb8611614e1b69a7d0769862edad49a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:18:16 -0500 Subject: [PATCH 45/88] Well, very silly simple cases work... --- ir.py | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 101 insertions(+), 6 deletions(-) diff --git a/ir.py b/ir.py index 23beb268..7bd02290 100644 --- a/ir.py +++ b/ir.py @@ -371,24 +371,24 @@ def to_c(self) -> str: def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: if isinstance(instr, Const): if isinstance(instr.value, Int): - return f"new_int({instr.value.value})" + return f"mksmallint({instr.value.value})" if isinstance(instr, IntAdd): operands = ", ".join(gvn.name(op) for op in instr.operands) - return f"int_add({operands})" + return f"num_add({operands})" if isinstance(instr, Param): return f"param{instr.idx}" if isinstance(instr, NewClosure): operands = ", ".join([f"fn{instr.fn.id}", *(gvn.name(op) for op in instr.operands)]) - return f"new_closure({operands})" + return f"mkclosure(heap, {operands})" raise NotImplementedError(type(instr)) def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: - result = f"Object *fn{self.id}() {{\n" + result = f"struct object *fn{self.id}() {{\n" for instr in block.instrs: if isinstance(instr, Control): break rhs = self._instr_to_c(instr, gvn, doms) - result += f"Object *{gvn.name(instr)} = {rhs};\n" + result += f"struct object *{gvn.name(instr)} = {rhs};\n" assert isinstance(instr, Control) if isinstance(instr, Return): result += f"return {gvn.name(instr.operands[0])};\n" @@ -399,7 +399,7 @@ def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> st class Compiler: def __init__(self) -> None: self.fns: list[IRFunction] = [] - entry = self.new_function([]) + self.entry = entry = self.new_function([]) self.gensym_counter: int = 0 self.fn: IRFunction = entry self.block: Block = entry.cfg.entry @@ -1491,6 +1491,101 @@ def test_const_list(self) -> None: self.assertEqual(analysis.instr_type[returned], CList()) +def compile_to_binary(source: str, memory: int, debug: bool) -> str: + import shlex + import subprocess + import sysconfig + import tempfile + + program = parse(tokenize(source)) + compiler = Compiler() + compiler.compile_body({}, program) + c_code = compiler.to_c() + dirname = os.path.dirname(__file__) + with tempfile.NamedTemporaryFile(mode="w", suffix=".c", delete=False) as c_file: + constants = [ + ("uword", "kKiB", 1024), + ("uword", "kMiB", "kKiB * kKiB"), + ("uword", "kGiB", "kKiB * kKiB * kKiB"), + ("uword", "kPageSize", "4 * kKiB"), + ("uword", "kSmallIntTagBits", 1), + ("uword", "kPrimaryTagBits", 3), + ("uword", "kObjectAlignmentLog2", 3), # bits + ("uword", "kObjectAlignment", "1ULL << kObjectAlignmentLog2"), + ("uword", "kImmediateTagBits", 5), + ("uword", "kSmallIntTagMask", "(1ULL << kSmallIntTagBits) - 1"), + ("uword", "kPrimaryTagMask", "(1ULL << kPrimaryTagBits) - 1"), + ("uword", "kImmediateTagMask", "(1ULL << kImmediateTagBits) - 1"), + ("uword", "kWordSize", "sizeof(word)"), + ("uword", "kMaxSmallStringLength", "kWordSize - 1"), + ("uword", "kBitsPerByte", 8), + # Up to the five least significant bits are used to tag the object's layout. + # The three low bits make up a primary tag, used to differentiate gc_obj + # from immediate objects. All even tags map to SmallInt, which is + # optimized by checking only the lowest bit for parity. + ("uword", "kSmallIntTag", 0), # 0b****0 + ("uword", "kHeapObjectTag", 1), # 0b**001 + ("uword", "kEmptyListTag", 5), # 0b00101 + ("uword", "kHoleTag", 7), # 0b00111 + ("uword", "kSmallStringTag", 13), # 0b01101 + ("uword", "kVariantTag", 15), # 0b01111 + # TODO(max): Fill in 21 + # TODO(max): Fill in 23 + # TODO(max): Fill in 29 + # TODO(max): Fill in 31 + ("uword", "kBitsPerPointer", "kBitsPerByte * kWordSize"), + ("word", "kSmallIntBits", "kBitsPerPointer - kSmallIntTagBits"), + ("word", "kSmallIntMinValue", "-(((word)1) << (kSmallIntBits - 1))"), + ("word", "kSmallIntMaxValue", "(((word)1) << (kSmallIntBits - 1)) - 1"), + ] + for type_, name, value in constants: + print(f"#define {name} ({type_})({value})", file=c_file) + # The runtime is in the same directory as this file + with open(os.path.join(dirname, "runtime.c"), "r") as runtime: + c_file.write(runtime.read()) + c_file.write("\n") + c_file.write(c_code) + c_file.write("\n") + # The platform is in the same directory as this file + print( + f""" + +const char* variant_names[] = {{ + "UNDEF", +}}; +const char* record_keys[] = {{ + "UNDEF", +}}; +int main() {{ + struct space space = make_space(MEMORY_SIZE); + init_heap(heap, space); + HANDLES(); + GC_HANDLE(struct object*, result, {compiler.entry.name()}()); + println(result); + destroy_space(space); + return 0; +}} +""", + file=c_file, + ) + cc = os.environ.get("CC", "tcc") + with tempfile.NamedTemporaryFile(mode="w", suffix=".out", delete=False) as out_file: + subprocess.run([cc, "-o", out_file.name, c_file.name], check=True) + return out_file.name + + +class CompilerEndToEndTests(unittest.TestCase): + def _run(self, code: str) -> str: + import subprocess + + binary = compile_to_binary(code, memory=4096, debug=True) + result = subprocess.run([binary], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) + return result.stdout + + def test_int(self) -> None: + self.assertEqual(self._run("1"), "1\n") + + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() From 9274f0066541f78b0c7f4d5169fde9cf899d327f Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:18:40 -0500 Subject: [PATCH 46/88] . --- ir.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ir.py b/ir.py index 7bd02290..f7e472cb 100644 --- a/ir.py +++ b/ir.py @@ -1585,6 +1585,9 @@ def _run(self, code: str) -> str: def test_int(self) -> None: self.assertEqual(self._run("1"), "1\n") + def test_int(self) -> None: + self.assertEqual(self._run("1 + 2"), "3\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From da28a4c866a81f3277b59f2183dab8e4a4c65677 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:20:23 -0500 Subject: [PATCH 47/88] Use handles (remove with analysis later) --- ir.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index f7e472cb..5098430c 100644 --- a/ir.py +++ b/ir.py @@ -384,11 +384,12 @@ def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: result = f"struct object *fn{self.id}() {{\n" + result += "HANDLES();\n" for instr in block.instrs: if isinstance(instr, Control): break rhs = self._instr_to_c(instr, gvn, doms) - result += f"struct object *{gvn.name(instr)} = {rhs};\n" + result += f"GC_HANDLE(struct object *, {gvn.name(instr)}, {rhs});\n" assert isinstance(instr, Control) if isinstance(instr, Return): result += f"return {gvn.name(instr.operands[0])};\n" @@ -1585,7 +1586,7 @@ def _run(self, code: str) -> str: def test_int(self) -> None: self.assertEqual(self._run("1"), "1\n") - def test_int(self) -> None: + def test_int_add(self) -> None: self.assertEqual(self._run("1 + 2"), "3\n") From 33ad68e7e0e4e1b624edaf5b437046333deccb79 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:28:23 -0500 Subject: [PATCH 48/88] . --- ir.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ir.py b/ir.py index 5098430c..54ae8bb8 100644 --- a/ir.py +++ b/ir.py @@ -366,7 +366,13 @@ def to_string(self, gvn: InstrId) -> str: def to_c(self) -> str: gvn = InstrId() - return self._to_c(self.cfg.entry, gvn, self.cfg.doms()) + params = ", ".join(f"struct object *{param}" for param in self.params) + result = f"struct object *fn{self.id}({params}) {{\n" + result += "HANDLES();\n" + for param in self.params: + result += f"GC_PROTECT({param});\n" + result += self._to_c(self.cfg.entry, gvn, self.cfg.doms()) + return result + "}" def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: if isinstance(instr, Const): @@ -376,15 +382,14 @@ def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) operands = ", ".join(gvn.name(op) for op in instr.operands) return f"num_add({operands})" if isinstance(instr, Param): - return f"param{instr.idx}" + return self.params[instr.idx] if isinstance(instr, NewClosure): operands = ", ".join([f"fn{instr.fn.id}", *(gvn.name(op) for op in instr.operands)]) return f"mkclosure(heap, {operands})" raise NotImplementedError(type(instr)) def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: - result = f"struct object *fn{self.id}() {{\n" - result += "HANDLES();\n" + result = "" for instr in block.instrs: if isinstance(instr, Control): break @@ -393,7 +398,8 @@ def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> st assert isinstance(instr, Control) if isinstance(instr, Return): result += f"return {gvn.name(instr.operands[0])};\n" - result += "}" + else: + raise NotImplementedError(instr) return result From 56cc31b072aa8bd80a664d2570e19513502ff6f6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:45:21 -0500 Subject: [PATCH 49/88] Get a little closer to pattern matching and closures --- ir.py | 54 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/ir.py b/ir.py index 54ae8bb8..51ea10be 100644 --- a/ir.py +++ b/ir.py @@ -365,42 +365,60 @@ def to_string(self, gvn: InstrId) -> str: return result + "}" def to_c(self) -> str: - gvn = InstrId() - params = ", ".join(f"struct object *{param}" for param in self.params) - result = f"struct object *fn{self.id}({params}) {{\n" - result += "HANDLES();\n" - for param in self.params: - result += f"GC_PROTECT({param});\n" - result += self._to_c(self.cfg.entry, gvn, self.cfg.doms()) - return result + "}" + with io.StringIO() as f: + params = ", ".join(f"struct object *{param}" for param in self.params) + f.write(f"struct object *fn{self.id}({params}) {{\n") + f.write("HANDLES();\n") + for param in self.params: + f.write(f"GC_PROTECT({param});\n") + self._to_c(f, self.cfg.entry, InstrId(), self.cfg.doms()) + f.write("}") + return f.getvalue() + return def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: + def _handle(rhs: str) -> str: + return f"GC_HANDLE(struct object *, {gvn.name(instr)}, {rhs});\n" + + def _decl(ty: str, rhs: str) -> str: + return f"{ty} {gvn.name(instr)} = {rhs};\n" + if isinstance(instr, Const): if isinstance(instr.value, Int): - return f"mksmallint({instr.value.value})" + return _handle(f"mksmallint({instr.value.value})") if isinstance(instr, IntAdd): operands = ", ".join(gvn.name(op) for op in instr.operands) - return f"num_add({operands})" + return _handle(f"num_add({operands})") if isinstance(instr, Param): - return self.params[instr.idx] + return _handle(self.params[instr.idx]) if isinstance(instr, NewClosure): operands = ", ".join([f"fn{instr.fn.id}", *(gvn.name(op) for op in instr.operands)]) - return f"mkclosure(heap, {operands})" + return _handle(f"mkclosure(heap, {operands})") + if isinstance(instr, IsIntEqualWord): + return _decl("bool", f"{gvn.name(instr.operands[0])} == mksmallint({instr.expected})") raise NotImplementedError(type(instr)) - def _to_c(self, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: - result = "" + def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: + f.write(f"{block.name()}:;\n") for instr in block.instrs: if isinstance(instr, Control): break - rhs = self._instr_to_c(instr, gvn, doms) - result += f"GC_HANDLE(struct object *, {gvn.name(instr)}, {rhs});\n" + f.write(self._instr_to_c(instr, gvn, doms)) assert isinstance(instr, Control) if isinstance(instr, Return): - result += f"return {gvn.name(instr.operands[0])};\n" + f.write(f"return {gvn.name(instr.operands[0])};\n") + elif isinstance(instr, Jump): + f.write(f"goto {instr.target.name()};\n") + self._to_c(f, instr.target, gvn, doms) + elif isinstance(instr, CondBranch): + f.write(f"if ({gvn.name(instr.operands[0])}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n") + self._to_c(f, instr.conseq, gvn, doms) + self._to_c(f, instr.alt, gvn, doms) + elif isinstance(instr, MatchFail): + f.write("""fprintf(stderr, "no matching cases\\n");\n""") + f.write("abort();\n") else: raise NotImplementedError(instr) - return result class Compiler: From 47ddb7def8a3024b85ba4fb6c04f65c101235335 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:48:57 -0500 Subject: [PATCH 50/88] Forward declare functions --- ir.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 51ea10be..5e20d478 100644 --- a/ir.py +++ b/ir.py @@ -366,8 +366,7 @@ def to_string(self, gvn: InstrId) -> str: def to_c(self) -> str: with io.StringIO() as f: - params = ", ".join(f"struct object *{param}" for param in self.params) - f.write(f"struct object *fn{self.id}({params}) {{\n") + f.write(f"{self.c_decl()} {{\n") f.write("HANDLES();\n") for param in self.params: f.write(f"GC_PROTECT({param});\n") @@ -376,6 +375,10 @@ def to_c(self) -> str: return f.getvalue() return + def c_decl(self) -> str: + params = ", ".join(f"struct object *{param}" for param in self.params) + return f"struct object *fn{self.id}({params})\n" + def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: def _handle(rhs: str) -> str: return f"GC_HANDLE(struct object *, {gvn.name(instr)}, {rhs});\n" @@ -1569,6 +1572,8 @@ def compile_to_binary(source: str, memory: int, debug: bool) -> str: with open(os.path.join(dirname, "runtime.c"), "r") as runtime: c_file.write(runtime.read()) c_file.write("\n") + for fn in compiler.fns: + c_file.write(fn.c_decl() + ";\n") c_file.write(c_code) c_file.write("\n") # The platform is in the same directory as this file From fb97c504791dc909bed81537307e2a4d815e7e3e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:49:15 -0500 Subject: [PATCH 51/88] . --- ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 5e20d478..0b5d2688 100644 --- a/ir.py +++ b/ir.py @@ -414,7 +414,9 @@ def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, se f.write(f"goto {instr.target.name()};\n") self._to_c(f, instr.target, gvn, doms) elif isinstance(instr, CondBranch): - f.write(f"if ({gvn.name(instr.operands[0])}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n") + f.write( + f"if ({gvn.name(instr.operands[0])}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n" + ) self._to_c(f, instr.conseq, gvn, doms) self._to_c(f, instr.alt, gvn, doms) elif isinstance(instr, MatchFail): From 318d1dc8807e71bac7a236fb65ebfbe7bbdc7fc6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:53:17 -0500 Subject: [PATCH 52/88] Compile closures --- ir.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 0b5d2688..e242a6eb 100644 --- a/ir.py +++ b/ir.py @@ -395,8 +395,10 @@ def _decl(ty: str, rhs: str) -> str: if isinstance(instr, Param): return _handle(self.params[instr.idx]) if isinstance(instr, NewClosure): - operands = ", ".join([f"fn{instr.fn.id}", *(gvn.name(op) for op in instr.operands)]) - return _handle(f"mkclosure(heap, {operands})") + result = _handle(f"mkclosure(heap, {instr.fn.name()}, {len(instr.operands)})") + for idx, op in enumerate(instr.operands): + result += f"closure_set({gvn.name(op)}, {idx}, {gvn.name(op)});\n" + return result if isinstance(instr, IsIntEqualWord): return _decl("bool", f"{gvn.name(instr.operands[0])} == mksmallint({instr.expected})") raise NotImplementedError(type(instr)) @@ -1620,6 +1622,12 @@ def test_int(self) -> None: def test_int_add(self) -> None: self.assertEqual(self._run("1 + 2"), "3\n") + def test_fun_id(self) -> None: + self.assertEqual(self._run("a -> a"), "\n") + + def test_match_int(self) -> None: + self.assertEqual(self._run("| 1 -> 2"), "\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 427f32e6f6a98d29683d02154f4994a3207e21af Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:56:46 -0500 Subject: [PATCH 53/88] Use OBJECT_HANDLE --- compiler.py | 1 - ir.py | 2 +- runtime.c | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler.py b/compiler.py index 19a31c7a..bbc91d7c 100644 --- a/compiler.py +++ b/compiler.py @@ -494,7 +494,6 @@ def compile_to_string(program: Object, debug: bool) -> str: dirname = os.path.dirname(__file__) with open(os.path.join(dirname, "runtime.c"), "r") as runtime: print(runtime.read(), file=f) - print("#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp)", file=f) if compiler.record_keys: print("const char* record_keys[] = {", file=f) for key in compiler.record_keys: diff --git a/ir.py b/ir.py index e242a6eb..7c44790b 100644 --- a/ir.py +++ b/ir.py @@ -381,7 +381,7 @@ def c_decl(self) -> str: def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: def _handle(rhs: str) -> str: - return f"GC_HANDLE(struct object *, {gvn.name(instr)}, {rhs});\n" + return f"OBJECT_HANDLE({gvn.name(instr)}, {rhs});\n" def _decl(ty: str, rhs: str) -> str: return f"{ty} {gvn.name(instr)} = {rhs};\n" diff --git a/runtime.c b/runtime.c index 29af2a16..cbf8c9d6 100644 --- a/runtime.c +++ b/runtime.c @@ -706,6 +706,7 @@ void pop_handles(void* local_handles) { #define GC_HANDLE(type, name, val) \ type name = val; \ GC_PROTECT(name) +#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp) void trace_roots(struct gc_heap* heap, VisitFn visit) { for (struct object*** h = handle_stack; h != handles; h++) { From adc3350e252cb01a4038c2b522145475fa70e640 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 22:57:21 -0500 Subject: [PATCH 54/88] . --- ir.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ir.py b/ir.py index 7c44790b..f710a391 100644 --- a/ir.py +++ b/ir.py @@ -424,6 +424,7 @@ def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, se elif isinstance(instr, MatchFail): f.write("""fprintf(stderr, "no matching cases\\n");\n""") f.write("abort();\n") + f.write("return NULL;\n") # Pacify the C compiler else: raise NotImplementedError(instr) From a2ae5f11a58b5b35c77596c06970e3ed11732d73 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:00:06 -0500 Subject: [PATCH 55/88] Support calls --- ir.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ir.py b/ir.py index f710a391..c777e2ad 100644 --- a/ir.py +++ b/ir.py @@ -386,6 +386,9 @@ def _handle(rhs: str) -> str: def _decl(ty: str, rhs: str) -> str: return f"{ty} {gvn.name(instr)} = {rhs};\n" + def op(idx: int) -> str: + return gvn.name(instr.operands[idx]) + if isinstance(instr, Const): if isinstance(instr.value, Int): return _handle(f"mksmallint({instr.value.value})") @@ -401,6 +404,8 @@ def _decl(ty: str, rhs: str) -> str: return result if isinstance(instr, IsIntEqualWord): return _decl("bool", f"{gvn.name(instr.operands[0])} == mksmallint({instr.expected})") + if isinstance(instr, Call): + return _handle(f"closure_call({op(0)}, {op(1)})") raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: @@ -1629,6 +1634,9 @@ def test_fun_id(self) -> None: def test_match_int(self) -> None: self.assertEqual(self._run("| 1 -> 2"), "\n") + def test_call_match_int(self) -> None: + self.assertEqual(self._run("(| 1 -> 2) 1"), "2\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 367f0004957db86f49ea6b5a2a7ce88544607f8c Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:00:39 -0500 Subject: [PATCH 56/88] . --- ir.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ir.py b/ir.py index c777e2ad..da219f39 100644 --- a/ir.py +++ b/ir.py @@ -193,7 +193,7 @@ class ListRest(HasOperands): @dataclasses.dataclass(init=False, eq=False) -class Call(HasOperands): +class ClosureCall(HasOperands): pass @@ -404,7 +404,7 @@ def op(idx: int) -> str: return result if isinstance(instr, IsIntEqualWord): return _decl("bool", f"{gvn.name(instr.operands[0])} == mksmallint({instr.expected})") - if isinstance(instr, Call): + if isinstance(instr, ClosureCall): return _handle(f"closure_call({op(0)}, {op(1)})") raise NotImplementedError(type(instr)) @@ -583,7 +583,7 @@ def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, Apply): fn = self.compile(env, exp.func) arg = self.compile(env, exp.arg) - return self.emit(Call(fn, arg)) + return self.emit(ClosureCall(fn, arg)) if isinstance(exp, (Function, MatchFunction)): # Anonymous function return self.compile_function(env, exp, func_name=None) @@ -683,7 +683,7 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CList() elif isinstance(instr, NewClosure): new_type = CClo(instr.fn) - elif isinstance(instr, Call): + elif isinstance(instr, ClosureCall): new_type = CTop() else: raise NotImplementedError(f"SCCP {instr}") @@ -1265,7 +1265,7 @@ def test_apply_fn(self) -> None: bb0 { v0 = NewClosure v1 = Const<1> - v2 = Call v0, v1 + v2 = ClosureCall v0, v1 Return v2 } }""", @@ -1281,7 +1281,7 @@ def test_recursive_call(self) -> None: bb0 { v0 = NewClosure v1 = Const<5> - v2 = Call v0, v1 + v2 = ClosureCall v0, v1 Return v2 } }""", @@ -1305,7 +1305,7 @@ def test_recursive_call(self) -> None: bb5 { v3 = Const<1> v4 = IntSub v1, v3 - v5 = Call v0, v4 + v5 = ClosureCall v0, v4 v6 = IntMul v1, v5 Return v6 } @@ -1329,7 +1329,7 @@ def test_recursive_call(self) -> None: bb3 { v3 = Const<1> v4 = IntSub v1, v3 - v5 = Call v0, v4 + v5 = ClosureCall v0, v4 v6 = IntMul v1, v5 Return v6 } @@ -1350,7 +1350,7 @@ def test_apply_anonymous_function(self) -> None: bb0 { v0 = NewClosure v1 = Const<1> - v2 = Call v0, v1 + v2 = ClosureCall v0, v1 Return v2 } }""", From 46519817f28a59506a7f653309f1fb61e6a156c6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:01:31 -0500 Subject: [PATCH 57/88] . --- ir.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/ir.py b/ir.py index da219f39..49037a2c 100644 --- a/ir.py +++ b/ir.py @@ -1614,28 +1614,29 @@ def compile_to_binary(source: str, memory: int, debug: bool) -> str: return out_file.name -class CompilerEndToEndTests(unittest.TestCase): - def _run(self, code: str) -> str: - import subprocess +def _run(code: str) -> str: + import subprocess - binary = compile_to_binary(code, memory=4096, debug=True) - result = subprocess.run([binary], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) - return result.stdout + binary = compile_to_binary(code, memory=4096, debug=True) + result = subprocess.run([binary], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) + return result.stdout + +class CompilerEndToEndTests(unittest.TestCase): def test_int(self) -> None: - self.assertEqual(self._run("1"), "1\n") + self.assertEqual(_run("1"), "1\n") def test_int_add(self) -> None: - self.assertEqual(self._run("1 + 2"), "3\n") + self.assertEqual(_run("1 + 2"), "3\n") def test_fun_id(self) -> None: - self.assertEqual(self._run("a -> a"), "\n") + self.assertEqual(_run("a -> a"), "\n") def test_match_int(self) -> None: - self.assertEqual(self._run("| 1 -> 2"), "\n") + self.assertEqual(_run("| 1 -> 2"), "\n") def test_call_match_int(self) -> None: - self.assertEqual(self._run("(| 1 -> 2) 1"), "2\n") + self.assertEqual(_run("(| 1 -> 2) 1"), "2\n") if __name__ == "__main__": From b5d18a6c650cb40606c3e7215b07218a2facff17 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:05:11 -0500 Subject: [PATCH 58/88] . --- ir.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ir.py b/ir.py index 49037a2c..d588a6f3 100644 --- a/ir.py +++ b/ir.py @@ -1529,10 +1529,8 @@ def test_const_list(self) -> None: self.assertEqual(analysis.instr_type[returned], CList()) -def compile_to_binary(source: str, memory: int, debug: bool) -> str: - import shlex +def compile_to_c(source: str) -> str: import subprocess - import sysconfig import tempfile program = parse(tokenize(source)) @@ -1608,17 +1606,26 @@ def compile_to_binary(source: str, memory: int, debug: bool) -> str: """, file=c_file, ) + return c_file.name + + +def compile_to_binary(c_name: str) -> str: + import subprocess + import tempfile + cc = os.environ.get("CC", "tcc") with tempfile.NamedTemporaryFile(mode="w", suffix=".out", delete=False) as out_file: - subprocess.run([cc, "-o", out_file.name, c_file.name], check=True) + subprocess.run([cc, "-o", out_file.name, c_name], check=True) return out_file.name def _run(code: str) -> str: import subprocess + import tempfile - binary = compile_to_binary(code, memory=4096, debug=True) - result = subprocess.run([binary], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) + c_name = compile_to_c(code) + binary_name = compile_to_binary(c_name) + result = subprocess.run([binary_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True) return result.stdout From 92e00e7ffca7acbc80bdb21d39a045b84db41b44 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:06:17 -0500 Subject: [PATCH 59/88] . --- ir.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ir.py b/ir.py index d588a6f3..46135d82 100644 --- a/ir.py +++ b/ir.py @@ -1639,6 +1639,9 @@ def test_int_add(self) -> None: def test_fun_id(self) -> None: self.assertEqual(_run("a -> a"), "\n") + def test_call_fun_id(self) -> None: + self.assertEqual(_run("(a -> a) 3"), "3\n") + def test_match_int(self) -> None: self.assertEqual(_run("| 1 -> 2"), "\n") From a257baff8085cf02565ba13baa78f6b403e7b25b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:07:49 -0500 Subject: [PATCH 60/88] Port some tests over --- ir.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ir.py b/ir.py index 46135d82..92c6af6a 100644 --- a/ir.py +++ b/ir.py @@ -1648,6 +1648,15 @@ def test_match_int(self) -> None: def test_call_match_int(self) -> None: self.assertEqual(_run("(| 1 -> 2) 1"), "2\n") + def test_var(self) -> None: + self.assertEqual(_run("a . a = 1"), "1\n") + + def test_function(self) -> None: + self.assertEqual(_run("f 1 . f = x -> x + 1"), "2\n") + + def test_match_int_fallthrough(self) -> None: + self.assertEqual(_run("f 3 . f = | 1 -> 2 | 3 -> 4"), "4\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 782eaadfdc3faba1cab3082e18f504e7602003f0 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:11:12 -0500 Subject: [PATCH 61/88] . --- ir.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ir.py b/ir.py index 92c6af6a..e3f00ee2 100644 --- a/ir.py +++ b/ir.py @@ -387,6 +387,7 @@ def _decl(ty: str, rhs: str) -> str: return f"{ty} {gvn.name(instr)} = {rhs};\n" def op(idx: int) -> str: + assert isinstance(instr, HasOperands) return gvn.name(instr.operands[idx]) if isinstance(instr, Const): @@ -399,11 +400,11 @@ def op(idx: int) -> str: return _handle(self.params[instr.idx]) if isinstance(instr, NewClosure): result = _handle(f"mkclosure(heap, {instr.fn.name()}, {len(instr.operands)})") - for idx, op in enumerate(instr.operands): - result += f"closure_set({gvn.name(op)}, {idx}, {gvn.name(op)});\n" + for idx, opnd in enumerate(instr.operands): + result += f"closure_set({gvn.name(instr)}, {idx}, {gvn.name(opnd)});\n" return result if isinstance(instr, IsIntEqualWord): - return _decl("bool", f"{gvn.name(instr.operands[0])} == mksmallint({instr.expected})") + return _decl("bool", f"{op(0)} == mksmallint({instr.expected})") if isinstance(instr, ClosureCall): return _handle(f"closure_call({op(0)}, {op(1)})") raise NotImplementedError(type(instr)) From 121e9d56b56f11b7254be7909e1aeb653669eefa Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:13:13 -0500 Subject: [PATCH 62/88] Test closure vars --- ir.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index e3f00ee2..a7599e1c 100644 --- a/ir.py +++ b/ir.py @@ -403,10 +403,12 @@ def op(idx: int) -> str: for idx, opnd in enumerate(instr.operands): result += f"closure_set({gvn.name(instr)}, {idx}, {gvn.name(opnd)});\n" return result - if isinstance(instr, IsIntEqualWord): - return _decl("bool", f"{op(0)} == mksmallint({instr.expected})") + if isinstance(instr, ClosureRef): + return _handle(f"closure_get({op(0)}, {instr.idx})") if isinstance(instr, ClosureCall): return _handle(f"closure_call({op(0)}, {op(1)})") + if isinstance(instr, IsIntEqualWord): + return _decl("bool", f"{op(0)} == mksmallint({instr.expected})") raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: @@ -1640,6 +1642,9 @@ def test_int_add(self) -> None: def test_fun_id(self) -> None: self.assertEqual(_run("a -> a"), "\n") + def test_closed_vars(self) -> None: + self.assertEqual(_run("((a -> a + b) 3) . b = 4"), "7\n") + def test_call_fun_id(self) -> None: self.assertEqual(_run("(a -> a) 3"), "3\n") From 315819b6ed8b51d51eb3495bff6dac3cffc1a841 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:17:16 -0500 Subject: [PATCH 63/88] . --- ir.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ir.py b/ir.py index a7599e1c..9c7c3650 100644 --- a/ir.py +++ b/ir.py @@ -586,6 +586,9 @@ def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, Apply): fn = self.compile(env, exp.func) arg = self.compile(env, exp.arg) + # TODO(max): Separate out into ClosureFn and DirectCall and then we + # can later replace the ClosureFn with known C function pointer in + # an optimization pass return self.emit(ClosureCall(fn, arg)) if isinstance(exp, (Function, MatchFunction)): # Anonymous function From 19b8b04c53e4feed0081af5a4ee40c39301c2df8 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:23:36 -0500 Subject: [PATCH 64/88] Optimize functions in to_c tests --- ir.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/ir.py b/ir.py index 9c7c3650..7199b6a5 100644 --- a/ir.py +++ b/ir.py @@ -627,6 +627,11 @@ def has_value(self) -> bool: return self.value is not None +@dataclasses.dataclass +class CBool(ConstantLattice): + value: Optional[bool] = None + + @dataclasses.dataclass class CClo(ConstantLattice): value: Optional[IRFunction] = None @@ -641,6 +646,8 @@ def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: return self if isinstance(self, CInt) and isinstance(other, CInt): return self if self.value == other.value else CInt() + if isinstance(self, CBool) and isinstance(other, CBool): + return self if self.value == other.value else CBool() return CBottom() @@ -678,6 +685,19 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CList() elif isinstance(instr, Return): pass + elif isinstance(instr, MatchFail): + pass + elif isinstance(instr, CondBranch): + match self.type_of(instr.operands[0]): + case CBool(True): + block_worklist.append(instr.conseq) + case CBool(False): + block_worklist.append(instr.alt) + case CBottom(): + pass + case _: + block_worklist.append(instr.conseq) + block_worklist.append(instr.alt) elif isinstance(instr, IntAdd): match (self.type_of(instr.operands[0]), self.type_of(instr.operands[1])): case (CInt(int(l)), CInt(int(r))): @@ -691,6 +711,16 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CClo(instr.fn) elif isinstance(instr, ClosureCall): new_type = CTop() + elif isinstance(instr, Param): + new_type = CTop() + elif isinstance(instr, ClosureRef): + new_type = CTop() + elif isinstance(instr, IsIntEqualWord): + match self.type_of(instr.operands[0]): + case CInt(int(i)) if i == instr.expected: + new_type = CBool(True) + case _: + new_type = CBool() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) @@ -1535,6 +1565,11 @@ def test_const_list(self) -> None: self.assertEqual(analysis.instr_type[returned], CList()) +def opt(fn: IRFunction) -> None: + CleanCFG(fn).run() + SCCP(fn).run() + + def compile_to_c(source: str) -> str: import subprocess import tempfile @@ -1542,6 +1577,8 @@ def compile_to_c(source: str) -> str: program = parse(tokenize(source)) compiler = Compiler() compiler.compile_body({}, program) + for fn in compiler.fns: + opt(fn) c_code = compiler.to_c() dirname = os.path.dirname(__file__) with tempfile.NamedTemporaryFile(mode="w", suffix=".c", delete=False) as c_file: From 84eacde358c27ceecc7f812d1bb01cf337d5157f Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:28:09 -0500 Subject: [PATCH 65/88] Constant fold --- ir.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ir.py b/ir.py index 7199b6a5..238ce7cd 100644 --- a/ir.py +++ b/ir.py @@ -42,6 +42,7 @@ class InstrId: data: dict[Instr, int] = dataclasses.field(default_factory=dict) def __getitem__(self, instr: Instr) -> int: + instr = instr.find() id = self.data.get(instr) if id is not None: return id @@ -416,7 +417,7 @@ def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, se for instr in block.instrs: if isinstance(instr, Control): break - f.write(self._instr_to_c(instr, gvn, doms)) + f.write(self._instr_to_c(instr.find(), gvn, doms)) assert isinstance(instr, Control) if isinstance(instr, Return): f.write(f"return {gvn.name(instr.operands[0])};\n") @@ -1567,7 +1568,12 @@ def test_const_list(self) -> None: def opt(fn: IRFunction) -> None: CleanCFG(fn).run() - SCCP(fn).run() + instr_type = SCCP(fn).run() + for block in fn.cfg.rpo(): + for instr in block.instrs: + match instr_type[instr]: + case CInt(int(i)): + instr.make_equal_to(Const(Int(i))) def compile_to_c(source: str) -> str: From db79a5bcb52448b7c80e89a03864dd9ccb957631 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 25 Jan 2025 23:32:30 -0500 Subject: [PATCH 66/88] Tests opts on CFG --- ir.py | 75 +++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/ir.py b/ir.py index 238ce7cd..1abf37f2 100644 --- a/ir.py +++ b/ir.py @@ -304,6 +304,7 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: for block in self.rpo(): result += f" {block.name()} {{\n" for instr in block.instrs: + instr = instr.find() if isinstance(instr, Control): result += f" {instr.to_string(gvn)}\n" else: @@ -775,11 +776,11 @@ def remove_unreachable_blocks(self) -> bool: self.fn.cfg.blocks = self.fn.cfg.rpo() return len(self.fn.cfg.blocks) != num_blocks +def _parse(source: str) -> Object: + return parse(tokenize(source)) -class IRTests(unittest.TestCase): - def _parse(self, source: str) -> Object: - return parse(tokenize(source)) +class IRTests(unittest.TestCase): def test_int(self) -> None: compiler = Compiler() compiler.compile_body({}, Int(1)) @@ -810,7 +811,7 @@ def test_str(self) -> None: def test_add_int(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("1 + 2")) + compiler.compile_body({}, _parse("1 + 2")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -826,7 +827,7 @@ def test_add_int(self) -> None: def test_sub_int(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("1 - 2")) + compiler.compile_body({}, _parse("1 - 2")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -842,7 +843,7 @@ def test_sub_int(self) -> None: def test_less_int(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("1 < 2")) + compiler.compile_body({}, _parse("1 < 2")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -858,7 +859,7 @@ def test_less_int(self) -> None: def test_empty_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("[]")) + compiler.compile_body({}, _parse("[]")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -872,7 +873,7 @@ def test_empty_list(self) -> None: def test_const_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("[1, 2]")) + compiler.compile_body({}, _parse("[1, 2]")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -890,7 +891,7 @@ def test_const_list(self) -> None: def test_non_const_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("a -> [a]")) + compiler.compile_body({}, _parse("a -> [a]")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -907,7 +908,7 @@ def test_non_const_list(self) -> None: def test_let(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("a . a = 1")) + compiler.compile_body({}, _parse("a . a = 1")) self.assertEqual( compiler.fn.to_string(InstrId()), """\ @@ -921,7 +922,7 @@ def test_let(self) -> None: def test_fun_id(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("a -> a")) + compiler.compile_body({}, _parse("a -> a")) self.assertEqual( compiler.fns[0].to_string(InstrId()), """\ @@ -946,7 +947,7 @@ def test_fun_id(self) -> None: def test_fun_closure(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("a -> b -> a + b")) + compiler.compile_body({}, _parse("a -> b -> a + b")) self.assertEqual(len(compiler.fns), 3) self.assertEqual( compiler.fns[0].to_string(InstrId()), @@ -986,7 +987,7 @@ def test_fun_closure(self) -> None: def test_fun_const_closure(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("(a -> a + b) . b = 1")) + compiler.compile_body({}, _parse("(a -> a + b) . b = 1")) self.assertEqual(len(compiler.fns), 2) self.assertEqual( compiler.fns[0].to_string(InstrId()), @@ -1055,7 +1056,7 @@ def test_match_no_cases(self) -> None: def test_match_one_case(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| 1 -> 2 + 3")) + compiler.compile_body({}, _parse("| 1 -> 2 + 3")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1083,7 +1084,7 @@ def test_match_one_case(self) -> None: def test_match_two_cases(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| 1 -> 2 | 3 -> 4")) + compiler.compile_body({}, _parse("| 1 -> 2 | 3 -> 4")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1117,7 +1118,7 @@ def test_match_two_cases(self) -> None: def test_match_var(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| a -> a + 1")) + compiler.compile_body({}, _parse("| a -> a + 1")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1140,7 +1141,7 @@ def test_match_var(self) -> None: def test_match_empty_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| [] -> 1")) + compiler.compile_body({}, _parse("| [] -> 1")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1170,7 +1171,7 @@ def test_match_empty_list(self) -> None: def test_match_one_item_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| [a] -> a + 1")) + compiler.compile_body({}, _parse("| [a] -> a + 1")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1210,7 +1211,7 @@ def test_match_one_item_list(self) -> None: def test_match_two_item_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("| [a, b] -> a + b")) + compiler.compile_body({}, _parse("| [a, b] -> a + b")) self.assertEqual( compiler.fns[1].to_string(InstrId()), """\ @@ -1294,7 +1295,7 @@ def test_match_two_item_list(self) -> None: def test_apply_fn(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("f 1 . f = x -> x + 1")) + compiler.compile_body({}, _parse("f 1 . f = x -> x + 1")) self.assertEqual( compiler.fns[0].to_string(InstrId()), """\ @@ -1310,7 +1311,7 @@ def test_apply_fn(self) -> None: def test_recursive_call(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("fact 5 . fact = | 0 -> 1 | n -> n * fact (n - 1)")) + compiler.compile_body({}, _parse("fact 5 . fact = | 0 -> 1 | n -> n * fact (n - 1)")) self.assertEqual( compiler.fns[0].to_string(InstrId()), """\ @@ -1379,7 +1380,7 @@ def test_recursive_call(self) -> None: def test_apply_anonymous_function(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("((x -> x + 1) 1)")) + compiler.compile_body({}, _parse("((x -> x + 1) 1)")) self.assertEqual( compiler.fns[0].to_string(InstrId()), """\ @@ -1514,9 +1515,6 @@ def test_dom(self) -> None: class SCCPTests(unittest.TestCase): - def _parse(self, source: str) -> Object: - return parse(tokenize(source)) - def test_int(self) -> None: compiler = Compiler() compiler.compile_body({}, Int(1)) @@ -1527,7 +1525,7 @@ def test_int(self) -> None: def test_int_add(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("1 + 2 + 3")) + compiler.compile_body({}, _parse("1 + 2 + 3")) analysis = SCCP(compiler.fn) result = analysis.run() entry = compiler.fn.cfg.entry @@ -1545,7 +1543,7 @@ def test_int_add(self) -> None: def test_empty_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("[]")) + compiler.compile_body({}, _parse("[]")) analysis = SCCP(compiler.fn) analysis.run() return_instr = compiler.fn.cfg.entry.instrs[-1] @@ -1556,7 +1554,7 @@ def test_empty_list(self) -> None: def test_const_list(self) -> None: compiler = Compiler() - compiler.compile_body({}, self._parse("[1, 2]")) + compiler.compile_body({}, _parse("[1, 2]")) analysis = SCCP(compiler.fn) analysis.run() return_instr = compiler.fn.cfg.entry.instrs[-1] @@ -1576,6 +1574,27 @@ def opt(fn: IRFunction) -> None: instr.make_equal_to(Const(Int(i))) +class OptTests(unittest.TestCase): + def test_int_add(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("1 + 2 + 3")) + opt(compiler.fn) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = Const<2> + v2 = Const<3> + v3 = Const<5> + v4 = Const<6> + Return v4 + } +}""", + ) + + def compile_to_c(source: str) -> str: import subprocess import tempfile From c0a8e9a10e12d3324a2758f0b80ee3fa381814f7 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 15:22:17 -0500 Subject: [PATCH 67/88] Support allocating empty records --- ir.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/ir.py b/ir.py index 1abf37f2..5b0d1272 100644 --- a/ir.py +++ b/ir.py @@ -225,6 +225,11 @@ def to_string(self, gvn: InstrId) -> str: return f"{stem} " + ", ".join(f"{gvn.name(op)}" for op in self.operands) +@dataclasses.dataclass(eq=False) +class NewRecord(Instr): + num_fields: int + + Env = Dict[str, Instr] @@ -411,6 +416,8 @@ def op(idx: int) -> str: return _handle(f"closure_call({op(0)}, {op(1)})") if isinstance(instr, IsIntEqualWord): return _decl("bool", f"{op(0)} == mksmallint({instr.expected})") + if isinstance(instr, NewRecord): + return _handle(f"mkrecord(heap, {instr.num_fields})") raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: @@ -595,6 +602,13 @@ def compile(self, env: Env, exp: Object) -> Instr: if isinstance(exp, (Function, MatchFunction)): # Anonymous function return self.compile_function(env, exp, func_name=None) + if isinstance(exp, Record): + num_fields = len(exp.data) + result = self.emit(NewRecord(num_fields)) + for idx, (key, value_exp) in enumerate(exp.data.items()): + value = self.compile(env, value_exp) + self.emit(RecordSet(result, idx, key, value)) + return result raise NotImplementedError(f"exp {type(exp)} {exp}") def to_c(self) -> str: @@ -717,6 +731,8 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CTop() elif isinstance(instr, ClosureRef): new_type = CTop() + elif isinstance(instr, NewRecord): + new_type = CTop() elif isinstance(instr, IsIntEqualWord): match self.type_of(instr.operands[0]): case CInt(int(i)) if i == instr.expected: @@ -776,6 +792,7 @@ def remove_unreachable_blocks(self) -> bool: self.fn.cfg.blocks = self.fn.cfg.rpo() return len(self.fn.cfg.blocks) != num_blocks + def _parse(source: str) -> Object: return parse(tokenize(source)) @@ -1398,6 +1415,20 @@ def test_apply_anonymous_function(self) -> None: entry = compiler.fns[0].cfg.entry self.assertEqual(analysis.instr_type[entry.instrs[0]], CClo(compiler.fns[1])) + def test_empty_record(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("{}")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewRecord + Return v0 + } +}""", + ) + class RPOTests(unittest.TestCase): def test_one_block(self) -> None: @@ -1728,6 +1759,9 @@ def test_function(self) -> None: def test_match_int_fallthrough(self) -> None: self.assertEqual(_run("f 3 . f = | 1 -> 2 | 3 -> 4"), "4\n") + def test_empty_record(self) -> None: + self.assertEqual(_run("{}"), "{}\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From c71f89939083f4ee905ff12d6d60ebf19316d1c6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 15:47:42 -0500 Subject: [PATCH 68/88] Add RecordSet, record_keys --- ir.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/ir.py b/ir.py index 5b0d1272..7273b6bc 100644 --- a/ir.py +++ b/ir.py @@ -230,6 +230,21 @@ class NewRecord(Instr): num_fields: int +@dataclasses.dataclass(eq=False) +class RecordSet(HasOperands): + idx: int + name: str + + def __init__(self, rec: Instr, idx: int, name: str, value: Instr) -> None: + self.operands = [rec, value] + self.idx = idx + self.name = name + + def to_string(self, gvn: InstrId) -> str: + stem = f"{type(self).__name__}<{self.idx}; {self.name}> " + return stem + ", ".join(f"{gvn.name(op)}" for op in self.operands) + + Env = Dict[str, Instr] @@ -418,6 +433,8 @@ def op(idx: int) -> str: return _decl("bool", f"{op(0)} == mksmallint({instr.expected})") if isinstance(instr, NewRecord): return _handle(f"mkrecord(heap, {instr.num_fields})") + if isinstance(instr, RecordSet): + return f"record_set({op(0)}, {instr.idx}, (struct record_field){{.key={instr.name}, .value={op(1)}}});\n" raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: @@ -453,6 +470,7 @@ def __init__(self) -> None: self.gensym_counter: int = 0 self.fn: IRFunction = entry self.block: Block = entry.cfg.entry + self.record_keys: Dict[str, int] = {} def new_function(self, params: list[str]) -> IRFunction: result = IRFunction(len(self.fns), params) @@ -607,12 +625,32 @@ def compile(self, env: Env, exp: Object) -> Instr: result = self.emit(NewRecord(num_fields)) for idx, (key, value_exp) in enumerate(exp.data.items()): value = self.compile(env, value_exp) - self.emit(RecordSet(result, idx, key, value)) + self.emit(RecordSet(result, idx, self.record_key(key), value)) return result raise NotImplementedError(f"exp {type(exp)} {exp}") + def record_key(self, key: str) -> str: + if key not in self.record_keys: + self.record_keys[key] = len(self.record_keys) + return f"Record_{key}" + def to_c(self) -> str: - return "\n".join(fn.to_c() for fn in self.fns) + with io.StringIO() as f: + if self.record_keys: + print("const char* record_keys[] = {", file=f) + for key in self.record_keys: + print(f'"{key}",', file=f) + print("};", file=f) + print("enum {", file=f) + for key, idx in self.record_keys.items(): + print(f"Record_{key} = {idx},", file=f) + print("};", file=f) + else: + # Pacify the C compiler + print("const char* record_keys[] = { NULL };", file=f) + for fn in self.fns: + print(fn.to_c(), file=f) + return f.getvalue() @dataclasses.dataclass @@ -733,6 +771,8 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CTop() elif isinstance(instr, NewRecord): new_type = CTop() + elif isinstance(instr, RecordSet): + new_type = CTop() elif isinstance(instr, IsIntEqualWord): match self.type_of(instr.operands[0]): case CInt(int(i)) if i == instr.expected: @@ -1429,6 +1469,40 @@ def test_empty_record(self) -> None: }""", ) + def test_record_with_one_field(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("{a=1}")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewRecord + v1 = Const<1> + v2 = RecordSet<0; Record_a> v0, v1 + Return v0 + } +}""", + ) + + def test_record_with_two_fields(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("{a=1, b=2}")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewRecord + v1 = Const<1> + v2 = RecordSet<0; Record_a> v0, v1 + v3 = Const<2> + v4 = RecordSet<1; Record_b> v0, v3 + Return v0 + } +}""", + ) + class RPOTests(unittest.TestCase): def test_one_block(self) -> None: @@ -1690,9 +1764,6 @@ def compile_to_c(source: str) -> str: const char* variant_names[] = {{ "UNDEF", }}; -const char* record_keys[] = {{ - "UNDEF", -}}; int main() {{ struct space space = make_space(MEMORY_SIZE); init_heap(heap, space); @@ -1762,6 +1833,12 @@ def test_match_int_fallthrough(self) -> None: def test_empty_record(self) -> None: self.assertEqual(_run("{}"), "{}\n") + def test_record_with_one_field(self) -> None: + self.assertEqual(_run("{a=1}"), "{a = 1}\n") + + def test_record_with_two_fields(self) -> None: + self.assertEqual(_run("{a=1, b=2}"), "{a = 1, b = 2}\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From a7f0cfd4198f54399213783fd0a6f9b5e10c21b6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 16:09:44 -0500 Subject: [PATCH 69/88] Add RecordGet --- ir.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/ir.py b/ir.py index 7273b6bc..b7f91df7 100644 --- a/ir.py +++ b/ir.py @@ -245,6 +245,19 @@ def to_string(self, gvn: InstrId) -> str: return stem + ", ".join(f"{gvn.name(op)}" for op in self.operands) +@dataclasses.dataclass(eq=False) +class RecordGet(HasOperands): + name: str + + def __init__(self, rec: Instr, name: str) -> None: + self.operands = [rec] + self.name = name + + def to_string(self, gvn: InstrId) -> str: + stem = f"{type(self).__name__}<{self.name}> " + return stem + ", ".join(f"{gvn.name(op)}" for op in self.operands) + + Env = Dict[str, Instr] @@ -302,6 +315,16 @@ def succs(self) -> tuple[Block, ...]: return (self.conseq, self.alt) +@dataclasses.dataclass(init=False, eq=False) +class Guard(HasOperands): + pass + + +@dataclasses.dataclass(init=False, eq=False) +class GuardNonNull(Guard): + pass + + @dataclasses.dataclass class CFG: blocks: list[Block] = dataclasses.field(init=False, default_factory=list) @@ -435,6 +458,10 @@ def op(idx: int) -> str: return _handle(f"mkrecord(heap, {instr.num_fields})") if isinstance(instr, RecordSet): return f"record_set({op(0)}, {instr.idx}, (struct record_field){{.key={instr.name}, .value={op(1)}}});\n" + if isinstance(instr, RecordGet): + return _handle(f"record_get({op(0)}, {instr.name})") + if isinstance(instr, GuardNonNull): + return f"if ({op(0)} == NULL) {{ abort(); }}\n" + _handle(op(0)) raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: @@ -627,6 +654,13 @@ def compile(self, env: Env, exp: Object) -> Instr: value = self.compile(env, value_exp) self.emit(RecordSet(result, idx, self.record_key(key), value)) return result + if isinstance(exp, Access): + assert isinstance(exp.at, Var), f"List access not supported" + record = self.compile(env, exp.obj) + key_idx = self.record_key(exp.at.name) + # TODO(max): Guard that it's a Record + value = self.emit(RecordGet(record, key_idx)) + return self.emit(GuardNonNull(value)) raise NotImplementedError(f"exp {type(exp)} {exp}") def record_key(self, key: str) -> str: @@ -773,6 +807,10 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CTop() elif isinstance(instr, RecordSet): new_type = CTop() + elif isinstance(instr, RecordGet): + new_type = CTop() + elif isinstance(instr, GuardNonNull): + new_type = CTop() elif isinstance(instr, IsIntEqualWord): match self.type_of(instr.operands[0]): case CInt(int(i)) if i == instr.expected: @@ -1839,6 +1877,15 @@ def test_record_with_one_field(self) -> None: def test_record_with_two_fields(self) -> None: self.assertEqual(_run("{a=1, b=2}"), "{a = 1, b = 2}\n") + def test_record_builder(self) -> None: + self.assertEqual(_run("f 1 2 . f = x -> y -> {a = x, b = y}"), "{a = 1, b = 2}\n") + + def test_record_access(self) -> None: + self.assertEqual(_run("rec@a . rec = {a = 1, b = 2}"), "1\n") + + def test_record_builder_access(self) -> None: + self.assertEqual(_run("(f 1 2)@a . f = x -> y -> {a = x, b = y}"), "1\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 134cc1fbf42a548aefce9d93da07b8a52f8ab8ff Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 16:10:31 -0500 Subject: [PATCH 70/88] Add IR test for RecordGet --- ir.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ir.py b/ir.py index b7f91df7..147e8f7e 100644 --- a/ir.py +++ b/ir.py @@ -1541,6 +1541,26 @@ def test_record_with_two_fields(self) -> None: }""", ) + def test_record_access(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("{a=1, b=2}@a")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = NewRecord + v1 = Const<1> + v2 = RecordSet<0; Record_a> v0, v1 + v3 = Const<2> + v4 = RecordSet<1; Record_b> v0, v3 + v5 = RecordGet v0 + v6 = GuardNonNull v5 + Return v6 + } +}""", + ) + class RPOTests(unittest.TestCase): def test_one_block(self) -> None: From 9358bfd6ec456d39df9eb453c0867bd7384c23ba Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 16:25:29 -0500 Subject: [PATCH 71/88] Support Hole --- ir.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/ir.py b/ir.py index 147e8f7e..11ba7c6d 100644 --- a/ir.py +++ b/ir.py @@ -436,8 +436,11 @@ def op(idx: int) -> str: return gvn.name(instr.operands[idx]) if isinstance(instr, Const): - if isinstance(instr.value, Int): - return _handle(f"mksmallint({instr.value.value})") + value = instr.value + if isinstance(value, Int): + return _handle(f"mksmallint({value.value})") + if isinstance(value, Hole): + return _decl("struct object*", "hole()") if isinstance(instr, IntAdd): operands = ", ".join(gvn.name(op) for op in instr.operands) return _handle(f"num_add({operands})") @@ -606,7 +609,7 @@ def compile_function(self, env: Env, exp: Function | MatchFunction, func_name: O return result def compile(self, env: Env, exp: Object) -> Instr: - if isinstance(exp, (Int, String)): + if isinstance(exp, (Int, String, Hole)): return self.emit(Const(exp)) if isinstance(exp, Var): return env[exp.name] @@ -1561,6 +1564,20 @@ def test_record_access(self) -> None: }""", ) + def test_hole(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("()")) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<()> + Return v0 + } +}""", + ) + class RPOTests(unittest.TestCase): def test_one_block(self) -> None: @@ -1906,6 +1923,9 @@ def test_record_access(self) -> None: def test_record_builder_access(self) -> None: self.assertEqual(_run("(f 1 2)@a . f = x -> y -> {a = x, b = y}"), "1\n") + def test_hole(self) -> None: + self.assertEqual(_run("()"), "()\n") + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From fb830e8612e381b06018d56936cea5ebb3a1d73f Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:25:26 -0500 Subject: [PATCH 72/88] Add string --- ir.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ir.py b/ir.py index 11ba7c6d..cd271c2b 100644 --- a/ir.py +++ b/ir.py @@ -441,6 +441,9 @@ def op(idx: int) -> str: return _handle(f"mksmallint({value.value})") if isinstance(value, Hole): return _decl("struct object*", "hole()") + if isinstance(value, String): + string_repr = json.dumps(value.value) + return _handle(f"mkstring(heap, {string_repr}, {len(value.value)})") if isinstance(instr, IntAdd): operands = ", ".join(gvn.name(op) for op in instr.operands) return _handle(f"num_add({operands})") @@ -1578,6 +1581,20 @@ def test_hole(self) -> None: }""", ) + def test_string(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse('"hello"')) + self.assertEqual( + compiler.fns[0].to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<"hello"> + Return v0 + } +}""", + ) + class RPOTests(unittest.TestCase): def test_one_block(self) -> None: @@ -1926,6 +1943,9 @@ def test_record_builder_access(self) -> None: def test_hole(self) -> None: self.assertEqual(_run("()"), "()\n") + def test_string(self) -> None: + self.assertEqual(_run('"hello"'), '"hello"\n') + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 2b4cd06cc0c9925657d3d58063094ccdaf91479b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:26:36 -0500 Subject: [PATCH 73/88] Compile int sub --- ir.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ir.py b/ir.py index cd271c2b..170e584d 100644 --- a/ir.py +++ b/ir.py @@ -798,6 +798,12 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CInt(l + r) case (CInt(_), CInt(_)): new_type = CInt() + elif isinstance(instr, IntSub): + match (self.type_of(instr.operands[0]), self.type_of(instr.operands[1])): + case (CInt(int(l)), CInt(int(r))): + new_type = CInt(l - r) + case (CInt(_), CInt(_)): + new_type = CInt() elif isinstance(instr, ListCons): if isinstance(self.type_of(instr.operands[1]), CList): new_type = CList() @@ -1898,6 +1904,9 @@ def test_int(self) -> None: def test_int_add(self) -> None: self.assertEqual(_run("1 + 2"), "3\n") + def test_int_sub(self) -> None: + self.assertEqual(_run("1 - 2"), "-1\n") + def test_fun_id(self) -> None: self.assertEqual(_run("a -> a"), "\n") From 171e3d9086a28e67ef2d142ccf13554d0774d60a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:36:26 -0500 Subject: [PATCH 74/88] Just use RPO in to_c --- ir.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ir.py b/ir.py index 170e584d..eb686547 100644 --- a/ir.py +++ b/ir.py @@ -415,7 +415,9 @@ def to_c(self) -> str: f.write("HANDLES();\n") for param in self.params: f.write(f"GC_PROTECT({param});\n") - self._to_c(f, self.cfg.entry, InstrId(), self.cfg.doms()) + gvn = InstrId() + for block in self.cfg.rpo(): + self._to_c(f, block, gvn) f.write("}") return f.getvalue() return @@ -424,7 +426,7 @@ def c_decl(self) -> str: params = ", ".join(f"struct object *{param}" for param in self.params) return f"struct object *fn{self.id}({params})\n" - def _instr_to_c(self, instr: Instr, gvn: InstrId, doms: dict[Block, set[Block]]) -> str: + def _instr_to_c(self, instr: Instr, gvn: InstrId) -> str: def _handle(rhs: str) -> str: return f"OBJECT_HANDLE({gvn.name(instr)}, {rhs});\n" @@ -470,24 +472,21 @@ def op(idx: int) -> str: return f"if ({op(0)} == NULL) {{ abort(); }}\n" + _handle(op(0)) raise NotImplementedError(type(instr)) - def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId, doms: dict[Block, set[Block]]) -> None: + def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId) -> None: f.write(f"{block.name()}:;\n") for instr in block.instrs: if isinstance(instr, Control): break - f.write(self._instr_to_c(instr.find(), gvn, doms)) + f.write(self._instr_to_c(instr.find(), gvn)) assert isinstance(instr, Control) if isinstance(instr, Return): f.write(f"return {gvn.name(instr.operands[0])};\n") elif isinstance(instr, Jump): f.write(f"goto {instr.target.name()};\n") - self._to_c(f, instr.target, gvn, doms) elif isinstance(instr, CondBranch): f.write( f"if ({gvn.name(instr.operands[0])}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n" ) - self._to_c(f, instr.conseq, gvn, doms) - self._to_c(f, instr.alt, gvn, doms) elif isinstance(instr, MatchFail): f.write("""fprintf(stderr, "no matching cases\\n");\n""") f.write("abort();\n") From 1bcff9784dd43de0d89903031bf39b23c3c6bd4a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:36:44 -0500 Subject: [PATCH 75/88] Add instructions for list matching --- ir.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ir.py b/ir.py index eb686547..c3decfb6 100644 --- a/ir.py +++ b/ir.py @@ -446,6 +446,10 @@ def op(idx: int) -> str: if isinstance(value, String): string_repr = json.dumps(value.value) return _handle(f"mkstring(heap, {string_repr}, {len(value.value)})") + if isinstance(value, List): + if not value.items: + return _decl("struct object*", "empty_list()") + raise NotImplementedError("const", type(value)) if isinstance(instr, IntAdd): operands = ", ".join(gvn.name(op) for op in instr.operands) return _handle(f"num_add({operands})") @@ -470,6 +474,16 @@ def op(idx: int) -> str: return _handle(f"record_get({op(0)}, {instr.name})") if isinstance(instr, GuardNonNull): return f"if ({op(0)} == NULL) {{ abort(); }}\n" + _handle(op(0)) + if isinstance(instr, ListCons): + return _handle(f"list_cons({op(0)}, {op(1)})") + if isinstance(instr, ListFirst): + return _handle(f"list_first({op(0)})") + if isinstance(instr, ListRest): + return _handle(f"list_rest({op(0)})") + if isinstance(instr, IsList): + return _decl("bool", f"is_list({op(0)})") + if isinstance(instr, IsEmptyList): + return _decl("bool", f"{op(0)} == empty_list()") raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId) -> None: @@ -828,6 +842,18 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CBool(True) case _: new_type = CBool() + elif isinstance(instr, IsList): + match self.type_of(instr.operands[0]): + case CList(_): + new_type = CBool(True) + case _: + new_type = CBool() + elif isinstance(instr, IsEmptyList): + new_type = CBool() + elif isinstance(instr, ListFirst): + new_type = CTop() + elif isinstance(instr, ListRest): + new_type = CTop() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) @@ -1921,6 +1947,10 @@ def test_match_int(self) -> None: def test_call_match_int(self) -> None: self.assertEqual(_run("(| 1 -> 2) 1"), "2\n") + def test_match_list(self) -> None: + self.assertEqual(_run("f [1, 2] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "3\n") + self.assertEqual(_run("f [4, 5] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "6\n") + def test_var(self) -> None: self.assertEqual(_run("a . a = 1"), "1\n") From 869fc4414b1d3e15f766226e0b10dbda77ce3f27 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:40:11 -0500 Subject: [PATCH 76/88] Simplify --- ir.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/ir.py b/ir.py index c3decfb6..7b7121f3 100644 --- a/ir.py +++ b/ir.py @@ -484,29 +484,26 @@ def op(idx: int) -> str: return _decl("bool", f"is_list({op(0)})") if isinstance(instr, IsEmptyList): return _decl("bool", f"{op(0)} == empty_list()") + if isinstance(instr, Return): + return f"return {op(0)};\n" + if isinstance(instr, Jump): + return f"goto {instr.target.name()};\n" + if isinstance(instr, CondBranch): + return f"if ({op(0)}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n" + if isinstance(instr, MatchFail): + return "\n".join( + [ + """fprintf(stderr, "no matching cases\\n");""", + "abort();", + "return NULL;\n", # Pacify the C compiler + ] + ) raise NotImplementedError(type(instr)) def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId) -> None: f.write(f"{block.name()}:;\n") for instr in block.instrs: - if isinstance(instr, Control): - break f.write(self._instr_to_c(instr.find(), gvn)) - assert isinstance(instr, Control) - if isinstance(instr, Return): - f.write(f"return {gvn.name(instr.operands[0])};\n") - elif isinstance(instr, Jump): - f.write(f"goto {instr.target.name()};\n") - elif isinstance(instr, CondBranch): - f.write( - f"if ({gvn.name(instr.operands[0])}) {{ goto {instr.conseq.name()}; }} else {{ goto {instr.alt.name()}; }}\n" - ) - elif isinstance(instr, MatchFail): - f.write("""fprintf(stderr, "no matching cases\\n");\n""") - f.write("abort();\n") - f.write("return NULL;\n") # Pacify the C compiler - else: - raise NotImplementedError(instr) class Compiler: From e89aec4c075c3571d269e8807b4ea60399b873b3 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:42:01 -0500 Subject: [PATCH 77/88] . --- ir.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ir.py b/ir.py index 7b7121f3..48fd48f9 100644 --- a/ir.py +++ b/ir.py @@ -1943,6 +1943,7 @@ def test_match_int(self) -> None: def test_call_match_int(self) -> None: self.assertEqual(_run("(| 1 -> 2) 1"), "2\n") + self.assertEqual(_run("(| 1 -> 2 | 3 -> 4) 3"), "4\n") def test_match_list(self) -> None: self.assertEqual(_run("f [1, 2] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "3\n") From 94d2699352f818581ca21bf8074dd12779d50b44 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:42:02 -0500 Subject: [PATCH 78/88] . --- ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 48fd48f9..f46f6a79 100644 --- a/ir.py +++ b/ir.py @@ -841,7 +841,7 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CBool() elif isinstance(instr, IsList): match self.type_of(instr.operands[0]): - case CList(_): + case CList(): new_type = CBool(True) case _: new_type = CBool() From 9498b66c5ff31f9210a994e687f22c66e889ac39 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 26 Jan 2025 19:50:17 -0500 Subject: [PATCH 79/88] Support matching list spread --- ir.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/ir.py b/ir.py index f46f6a79..252dcf7e 100644 --- a/ir.py +++ b/ir.py @@ -555,7 +555,11 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success # the_list = self.emit(RefineType(param, CList())) the_list = param for i, pattern_item in enumerate(pattern.items): - assert not isinstance(pattern_item, Spread) + if isinstance(pattern_item, Spread): + if pattern_item.name: + updates[pattern_item.name] = the_list + self.emit(Jump(success)) + return updates # Not enough elements is_empty = self.emit(IsEmptyList(the_list)) is_nonempty_block = self.fn.cfg.new_block() @@ -1422,6 +1426,68 @@ def test_match_two_item_list(self) -> None: }""", ) + def test_match_list_spread(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("| [_, ...xs] -> xs")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb2 { + v2 = IsList v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = IsEmptyList v1 + CondBranch v3, bb1, bb5 + } + bb5 { + v4 = ListFirst v1 + Jump bb6 + } + bb6 { + v5 = ListRest v1 + Jump bb3 + } + bb3 { + Return v5 + } + bb1 { + MatchFail + } +}""", + ) + CleanCFG(compiler.fns[1]).run() + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + v2 = IsList v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = IsEmptyList v1 + CondBranch v3, bb1, bb5 + } + bb5 { + v4 = ListFirst v1 + v5 = ListRest v1 + Return v5 + } + bb1 { + MatchFail + } +}""", + ) + def test_apply_fn(self) -> None: compiler = Compiler() compiler.compile_body({}, _parse("f 1 . f = x -> x + 1")) @@ -1949,6 +2015,9 @@ def test_match_list(self) -> None: self.assertEqual(_run("f [1, 2] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "3\n") self.assertEqual(_run("f [4, 5] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "6\n") + def test_match_list_spread(self) -> None: + self.assertEqual(_run("f [4, 5] . f = | [_, ...xs] -> xs"), "[5]\n") + def test_var(self) -> None: self.assertEqual(_run("a . a = 1"), "1\n") From 41792440c1a12f65443e13126b64393072b3d62c Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 27 Jan 2025 17:02:23 -0500 Subject: [PATCH 80/88] Support record pattern matching --- ir.py | 267 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 267 insertions(+) diff --git a/ir.py b/ir.py index 252dcf7e..dee8eff8 100644 --- a/ir.py +++ b/ir.py @@ -87,6 +87,15 @@ def to_string(self, gvn: InstrId) -> str: return f"{type(self).__name__}<{self.value}>" +@dataclasses.dataclass(eq=False) +class CConst(Instr): + type: str + value: str + + def to_string(self, gvn: InstrId) -> str: + return f"{type(self).__name__}<{self.type}; {self.value}>" + + @dataclasses.dataclass(eq=False) class Param(Instr): idx: int @@ -130,6 +139,12 @@ class IntLess(HasOperands): pass +# TODO(max): Maybe start work on boxing/unboxing in the IR. +@dataclasses.dataclass(init=False, eq=False) +class CEqual(HasOperands): + pass + + @dataclasses.dataclass(init=False, eq=False) class RefineType(HasOperands): def __init__(self, value: Instr, ty: ConstantLattice) -> None: @@ -230,6 +245,11 @@ class NewRecord(Instr): num_fields: int +@dataclasses.dataclass(init=False, eq=False) +class IsRecord(HasOperands): + pass + + @dataclasses.dataclass(eq=False) class RecordSet(HasOperands): idx: int @@ -258,6 +278,11 @@ def to_string(self, gvn: InstrId) -> str: return stem + ", ".join(f"{gvn.name(op)}" for op in self.operands) +@dataclasses.dataclass(init=False, eq=False) +class RecordNumFields(HasOperands): + pass + + Env = Dict[str, Instr] @@ -484,6 +509,14 @@ def op(idx: int) -> str: return _decl("bool", f"is_list({op(0)})") if isinstance(instr, IsEmptyList): return _decl("bool", f"{op(0)} == empty_list()") + if isinstance(instr, IsRecord): + return _decl("bool", f"is_record({op(0)})") + if isinstance(instr, RecordNumFields): + return _decl("uword", f"record_num_fields({op(0)})") + if isinstance(instr, CConst): + return _decl(instr.type, instr.value) + if isinstance(instr, CEqual): + return _decl("bool", f"{op(0)} == {op(1)}") if isinstance(instr, Return): return f"return {op(0)};\n" if isinstance(instr, Jump): @@ -575,6 +608,35 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success is_empty = self.emit(IsEmptyList(the_list)) self.emit(CondBranch(is_empty, success, fallthrough)) return updates + if isinstance(pattern, Record): + is_record = self.emit(IsRecord(param)) + updates = {} + is_record_block = self.fn.cfg.new_block() + self.emit(CondBranch(is_record, is_record_block, fallthrough)) + self.block = is_record_block + for key, pattern_value in pattern.data.items(): + if isinstance(pattern_value, Spread): + if pattern_value.name: + raise NotImplementedError("named record spread not yet supported") + self.emit(Jump(success)) + return updates + key_idx = self.record_key(key) + record_value = self.emit(RecordGet(param, key_idx)) + is_null = self.emit(CEqual(record_value, self.emit(CConst("struct object*", "NULL")))) + recursive_block = self.fn.cfg.new_block() + self.emit(CondBranch(is_null, fallthrough, recursive_block)) + self.block = recursive_block + pattern_success = self.fn.cfg.new_block() + # Recursive pattern match + updates.update( + self.compile_match_pattern(env, record_value, pattern_value, pattern_success, fallthrough) + ) + self.block = pattern_success + # Too many fields + num_fields = self.emit(RecordNumFields(param)) + cmp = self.emit(CEqual(num_fields, self.emit(CConst("uword", str(len(pattern.data)))))) + self.emit(CondBranch(cmp, success, fallthrough)) + return updates raise NotImplementedError(f"pattern {type(pattern)} {pattern}") def compile_body(self, env: Env, exp: Object) -> None: @@ -855,6 +917,14 @@ def run(self) -> dict[Instr, ConstantLattice]: new_type = CTop() elif isinstance(instr, ListRest): new_type = CTop() + elif isinstance(instr, IsRecord): + new_type = CTop() + elif isinstance(instr, CConst): + new_type = CTop() + elif isinstance(instr, CEqual): + new_type = CTop() + elif isinstance(instr, RecordNumFields): + new_type = CTop() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) @@ -1488,6 +1558,194 @@ def test_match_list_spread(self) -> None: }""", ) + def test_match_empty_record(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("| {} -> 1")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb2 { + v2 = IsRecord v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = RecordNumFields v1 + v4 = CConst + v5 = CEqual v3, v4 + CondBranch v5, bb3, bb1 + } + bb1 { + MatchFail + } + bb3 { + v6 = Const<1> + Return v6 + } +}""", + ) + + def test_match_one_item_record(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("| {a=1} -> 1")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb2 { + v2 = IsRecord v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = RecordGet v1 + v4 = CConst + v5 = CEqual v3, v4 + CondBranch v5, bb1, bb5 + } + bb5 { + v6 = IsIntEqualWord v3, 1 + CondBranch v6, bb6, bb1 + } + bb6 { + v7 = RecordNumFields v1 + v8 = CConst + v9 = CEqual v7, v8 + CondBranch v9, bb3, bb1 + } + bb3 { + v10 = Const<1> + Return v10 + } + bb1 { + MatchFail + } +}""", + ) + + def test_match_two_item_record(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("| {a=1, b=2} -> 3")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb2 { + v2 = IsRecord v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = RecordGet v1 + v4 = CConst + v5 = CEqual v3, v4 + CondBranch v5, bb1, bb5 + } + bb5 { + v6 = IsIntEqualWord v3, 1 + CondBranch v6, bb6, bb1 + } + bb6 { + v7 = RecordGet v1 + v8 = CConst + v9 = CEqual v7, v8 + CondBranch v9, bb1, bb7 + } + bb7 { + v10 = IsIntEqualWord v7, 2 + CondBranch v10, bb8, bb1 + } + bb8 { + v11 = RecordNumFields v1 + v12 = CConst + v13 = CEqual v11, v12 + CondBranch v13, bb3, bb1 + } + bb3 { + v14 = Const<3> + Return v14 + } + bb1 { + MatchFail + } +}""", + ) + + def test_match_record_spread(self) -> None: + compiler = Compiler() + compiler.compile_body({}, _parse("| {a=a, ...} -> a")) + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + Jump bb2 + } + bb2 { + v2 = IsRecord v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = RecordGet v1 + v4 = CConst + v5 = CEqual v3, v4 + CondBranch v5, bb1, bb5 + } + bb5 { + Jump bb6 + } + bb6 { + Jump bb3 + } + bb3 { + Return v3 + } + bb1 { + MatchFail + } +}""", + ) + CleanCFG(compiler.fns[1]).run() + self.assertEqual( + compiler.fns[1].to_string(InstrId()), + """\ +fn1 { + bb0 { + v0 = Param<0; $clo> + v1 = Param<1; arg_0> + v2 = IsRecord v1 + CondBranch v2, bb4, bb1 + } + bb4 { + v3 = RecordGet v1 + v4 = CConst + v5 = CEqual v3, v4 + CondBranch v5, bb1, bb5 + } + bb5 { + Return v3 + } + bb1 { + MatchFail + } +}""", + ) + def test_apply_fn(self) -> None: compiler = Compiler() compiler.compile_body({}, _parse("f 1 . f = x -> x + 1")) @@ -2045,6 +2303,15 @@ def test_record_access(self) -> None: def test_record_builder_access(self) -> None: self.assertEqual(_run("(f 1 2)@a . f = x -> y -> {a = x, b = y}"), "1\n") + def test_match_record(self) -> None: + self.assertEqual(_run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n") + + def test_match_record_too_few_keys(self) -> None: + self.assertEqual(_run("f {a = 4, b = 5} . f = | {a = _} -> 3 | {a = _, b = _} -> 6"), "6\n") + + def test_match_record_spread(self) -> None: + self.assertEqual(_run("f {a=1, b=2, c=3} . f = | {a=a, ...} -> a"), "1\n") + def test_hole(self) -> None: self.assertEqual(_run("()"), "()\n") From e2b9876a6f7226740ba6630b4cb8bb873e2eae46 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 27 Jan 2025 17:02:31 -0500 Subject: [PATCH 81/88] Debugging cleanup --- ir.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ir.py b/ir.py index dee8eff8..7f67abef 100644 --- a/ir.py +++ b/ir.py @@ -300,9 +300,14 @@ def name(self) -> str: def terminator(self) -> Control: result = self.instrs[-1] - assert isinstance(result, Control) + assert isinstance(result, Control), f"Expected Control but found {result}" return result + def succs(self) -> tuple[Block, ...]: + if not self.instrs: + return () + return self.terminator().succs() + @dataclasses.dataclass(eq=False) class Jump(Control): @@ -388,8 +393,7 @@ def rpo(self) -> list[Block]: def po_from(self, block: Block, result: list[Block], visited: set[Block]) -> None: visited.add(block) - terminator = block.terminator() - for succ in terminator.succs(): + for succ in block.succs(): if succ not in visited: self.po_from(succ, result, visited) result.append(block) @@ -398,7 +402,7 @@ def preds(self) -> dict[Block, set[Block]]: rpo = self.rpo() result: dict[Block, set[Block]] = {block: set() for block in rpo} for block in rpo: - for succ in block.terminator().succs(): + for succ in block.succs(): result[succ].add(block) return result From d4bc1e56bf3e00a63d7021072e5099e55ca8e386 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 27 Jan 2025 17:02:38 -0500 Subject: [PATCH 82/88] . --- ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ir.py b/ir.py index 7f67abef..e3511abe 100644 --- a/ir.py +++ b/ir.py @@ -8,7 +8,7 @@ import typing import unittest -from typing import Dict, Optional, Tuple +from typing import Dict, Optional from scrapscript import ( Access, From 257ff853052a71dbe9195506a18e9e1a2513c58f Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 1 Feb 2025 20:48:54 -0500 Subject: [PATCH 83/88] . --- ir.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/ir.py b/ir.py index e3511abe..c72046a0 100644 --- a/ir.py +++ b/ir.py @@ -802,7 +802,15 @@ def has_value(self) -> bool: @dataclasses.dataclass -class CBool(ConstantLattice): +class CCInt(ConstantLattice): + value: Optional[int] = None + + def has_value(self) -> bool: + return self.value is not None + + +@dataclasses.dataclass +class CCBool(ConstantLattice): value: Optional[bool] = None @@ -820,8 +828,8 @@ def union(self: ConstantLattice, other: ConstantLattice) -> ConstantLattice: return self if isinstance(self, CInt) and isinstance(other, CInt): return self if self.value == other.value else CInt() - if isinstance(self, CBool) and isinstance(other, CBool): - return self if self.value == other.value else CBool() + if isinstance(self, CCBool) and isinstance(other, CCBool): + return self if self.value == other.value else CCBool() return CBottom() @@ -863,9 +871,9 @@ def run(self) -> dict[Instr, ConstantLattice]: pass elif isinstance(instr, CondBranch): match self.type_of(instr.operands[0]): - case CBool(True): + case CCBool(True): block_worklist.append(instr.conseq) - case CBool(False): + case CCBool(False): block_worklist.append(instr.alt) case CBottom(): pass @@ -906,29 +914,29 @@ def run(self) -> dict[Instr, ConstantLattice]: elif isinstance(instr, IsIntEqualWord): match self.type_of(instr.operands[0]): case CInt(int(i)) if i == instr.expected: - new_type = CBool(True) + new_type = CCBool(True) case _: - new_type = CBool() + new_type = CCBool() elif isinstance(instr, IsList): match self.type_of(instr.operands[0]): case CList(): - new_type = CBool(True) + new_type = CCBool(True) case _: - new_type = CBool() + new_type = CCBool() elif isinstance(instr, IsEmptyList): - new_type = CBool() + new_type = CCBool() elif isinstance(instr, ListFirst): new_type = CTop() elif isinstance(instr, ListRest): new_type = CTop() elif isinstance(instr, IsRecord): - new_type = CTop() + new_type = CCBool() elif isinstance(instr, CConst): new_type = CTop() elif isinstance(instr, CEqual): - new_type = CTop() + new_type = CCBool() elif isinstance(instr, RecordNumFields): - new_type = CTop() + new_type = CCInt() else: raise NotImplementedError(f"SCCP {instr}") old_type = self.type_of(instr) From 6d885eed92aa4691245dc887c03ae5fa4042bb46 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 1 Feb 2025 20:50:16 -0500 Subject: [PATCH 84/88] . --- ir.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ir.py b/ir.py index c72046a0..49b170dd 100644 --- a/ir.py +++ b/ir.py @@ -853,6 +853,7 @@ def run(self) -> dict[Instr, ConstantLattice]: while block_worklist or instr_worklist: if instr_worklist and (instr := instr_worklist.pop(0)): + instr = instr.find() if isinstance(instr, HasOperands): for operand in instr.operands: if operand not in self.instr_uses: @@ -872,8 +873,10 @@ def run(self) -> dict[Instr, ConstantLattice]: elif isinstance(instr, CondBranch): match self.type_of(instr.operands[0]): case CCBool(True): + instr.make_equal_to(Jump(instr.conseq)) block_worklist.append(instr.conseq) case CCBool(False): + instr.make_equal_to(Jump(instr.alt)) block_worklist.append(instr.alt) case CBottom(): pass From cbc889684fcdf2b2cb3fc7e35ccec8d13145859b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 1 Feb 2025 20:53:49 -0500 Subject: [PATCH 85/88] . --- ir.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ir.py b/ir.py index 49b170dd..d4b8b436 100644 --- a/ir.py +++ b/ir.py @@ -887,6 +887,7 @@ def run(self) -> dict[Instr, ConstantLattice]: match (self.type_of(instr.operands[0]), self.type_of(instr.operands[1])): case (CInt(int(l)), CInt(int(r))): new_type = CInt(l + r) + instr.make_equal_to(Const(Int(l+r))) case (CInt(_), CInt(_)): new_type = CInt() elif isinstance(instr, IntSub): @@ -2105,6 +2106,21 @@ def test_int_add(self) -> None: }, ) + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<1> + v1 = Const<2> + v2 = Const<3> + v3 = Const<5> + v4 = Const<6> + Return v4 + } +}""", + ) + def test_empty_list(self) -> None: compiler = Compiler() compiler.compile_body({}, _parse("[]")) From b18f832d0f0bab71e76b886ba1189e6868327aac Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 1 Feb 2025 21:14:31 -0500 Subject: [PATCH 86/88] DCE; Nop --- ir.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 93 insertions(+), 7 deletions(-) diff --git a/ir.py b/ir.py index d4b8b436..9e463f9c 100644 --- a/ir.py +++ b/ir.py @@ -79,6 +79,11 @@ def to_string(self, gvn: InstrId) -> str: return type(self).__name__ +@dataclasses.dataclass(eq=False) +class Nop(Instr): + pass + + @dataclasses.dataclass(eq=False) class Const(Instr): value: Object @@ -378,6 +383,8 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str: result += f" {block.name()} {{\n" for instr in block.instrs: instr = instr.find() + if isinstance(instr, Nop): + continue if isinstance(instr, Control): result += f" {instr.to_string(gvn)}\n" else: @@ -540,7 +547,10 @@ def op(idx: int) -> str: def _to_c(self, f: io.StringIO, block: Block, gvn: InstrId) -> None: f.write(f"{block.name()}:;\n") for instr in block.instrs: - f.write(self._instr_to_c(instr.find(), gvn)) + instr = instr.find() + if isinstance(instr, Nop): + continue + f.write(self._instr_to_c(instr, gvn)) class Compiler: @@ -995,6 +1005,45 @@ def remove_unreachable_blocks(self) -> bool: return len(self.fn.cfg.blocks) != num_blocks +@dataclasses.dataclass +class DeadCodeElimination: + fn: IRFunction + + def is_critical(self, instr: Instr) -> bool: + if isinstance(instr, Const): + return False + if isinstance(instr, IntAdd): + return False + # TODO(max): Add more. Track heap effects? + return True + + def run(self) -> None: + worklist: list[Instr] = [] + marked: set[Instr] = set() + blocks = self.fn.cfg.rpo() + # Mark + for block in blocks: + for instr in block.instrs: + instr = instr.find() + if self.is_critical(instr): + marked.add(instr) + worklist.append(instr) + while worklist: + instr = worklist.pop(0).find() + if isinstance(instr, HasOperands): + for op in instr.operands: + op = op.find() + if op not in marked: + marked.add(op) + worklist.append(op) + # Sweep + for block in blocks: + for instr in block.instrs: + instr = instr.find() + if instr not in marked: + instr.make_equal_to(Nop()) + + def _parse(source: str) -> Object: return parse(tokenize(source)) @@ -2144,6 +2193,46 @@ def test_const_list(self) -> None: self.assertEqual(analysis.instr_type[returned], CList()) +class DeadCodeEliminationTests(unittest.TestCase): + def test_remove_const(self) -> None: + compiler = Compiler() + compiler.emit(Const(1)) + compiler.emit(Const(2)) + compiler.emit(Const(3)) + four = compiler.emit(Const(4)) + compiler.emit(Return(four)) + DeadCodeElimination(compiler.fn).run() + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<4> + Return v0 + } +}""", + ) + + def test_remove_int_add(self) -> None: + compiler = Compiler() + one = compiler.emit(Const(1)) + two = compiler.emit(Const(2)) + compiler.emit(IntAdd(one, two)) + four = compiler.emit(Const(4)) + compiler.emit(Return(four)) + DeadCodeElimination(compiler.fn).run() + self.assertEqual( + compiler.fn.to_string(InstrId()), + """\ +fn0 { + bb0 { + v0 = Const<4> + Return v0 + } +}""", + ) + + def opt(fn: IRFunction) -> None: CleanCFG(fn).run() instr_type = SCCP(fn).run() @@ -2152,6 +2241,7 @@ def opt(fn: IRFunction) -> None: match instr_type[instr]: case CInt(int(i)): instr.make_equal_to(Const(Int(i))) + DeadCodeElimination(fn).run() class OptTests(unittest.TestCase): @@ -2164,12 +2254,8 @@ def test_int_add(self) -> None: """\ fn0 { bb0 { - v0 = Const<1> - v1 = Const<2> - v2 = Const<3> - v3 = Const<5> - v4 = Const<6> - Return v4 + v0 = Const<6> + Return v0 } }""", ) From 95f3b3df9e2214673d5d741a7d1742c68e941fa9 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 7 Feb 2025 18:30:51 -0500 Subject: [PATCH 87/88] . --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 94d70cfc..22b75785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,3 @@ line-length = 120 [tool.ruff.lint] ignore = ["E741"] - -[project] -name = "scrapscript" -version = "0.1.1" From 763af02b861d0fe18b5160c46de54809e71222cf Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 7 Feb 2025 18:33:20 -0500 Subject: [PATCH 88/88] . --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1235bd08..74d64bf2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,8 @@ jobs: run: CC=${{matrix.CC}} CFLAGS="-fsanitize=undefined ${{matrix.USE_STATIC_HEAP}}" uv run python compiler_tests.py - name: Run compiler tests with Valgrind run: CC=${{matrix.CC}} CFLAGS="${{matrix.USE_STATIC_HEAP}}" USE_VALGRIND=1 uv run python compiler_tests.py + - name: Run IR tests (remove when merged) + run: CC=${{matrix.CC}} CFLAGS="${{matrix.USE_STATIC_HEAP}}" uv run python ir.py run_compiler_unit_tests_other_cc: runs-on: ubuntu-latest steps: