Skip to content
239 changes: 238 additions & 1 deletion scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import code
import dataclasses
import enum
import functools
import http.server
import json
import logging
Expand Down Expand Up @@ -1254,6 +1255,91 @@ def bencode(obj: object) -> bytes:
raise NotImplementedError(f"bencode not implemented for {type(obj)}")


class JSCompiler:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏 you know what I'm thinking

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imports are our friends

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a prototype before doing it in scrap

def compile(self, env: Env, exp: Object) -> str:
if isinstance(exp, Int):
return str(exp.value)
if isinstance(exp, Binop):
left = self.compile(env, exp.left)
right = self.compile(env, exp.right)
return f"({left})" + BinopKind.to_str(exp.op) + f"({right})"
if isinstance(exp, Var):
# assert exp.name in env
return exp.name
if isinstance(exp, Where):
binding = exp.binding
assert isinstance(binding, Assign)
return self.compile_let(env, binding.name.name, binding.value, exp.body)
if isinstance(exp, Assign):
value = self.compile(env, exp.value)
return f"const {exp.name.name} = {value};\n"
if isinstance(exp, Apply):
func = self.compile(env, exp.func)
arg = self.compile(env, exp.arg)
return f"({func})({arg})"
if isinstance(exp, Function):
arg = self.compile(env, exp.arg)
body = self.compile(env, exp.body)
return f"({arg}) => ({body})"
if isinstance(exp, List):
items = [self.compile(env, item) for item in exp.items]
return "[" + ", ".join(items) + "]"
if isinstance(exp, MatchFunction):
err = "(() => {throw 'oh no'})()"
if not exp.cases:
return err
# TODO(max): Gensym arg name or something
arg = "__x"

def per_case(acc: str, case: MatchCase) -> str:
cond, body = self.compile_match_case(env, arg, case)
return f"({cond}) ? ({body}) : ({acc})"

return f"({arg}) => " + functools.reduce(
per_case,
reversed(exp.cases),
err,
)
if isinstance(exp, Symbol):
if exp.value in ("true", "false"):
return exp.value
return repr(exp.value)
if isinstance(exp, String):
return repr(exp.value)
if isinstance(exp, Access):
obj = self.compile(env, exp.obj)
if isinstance(exp.at, Int):
return f"{obj}[{exp.at}]"
assert isinstance(exp.at, Var)
return f"{obj}.{exp.at}"
if isinstance(exp, Record):
result = "{"
for key, rec_value in exp.data.items():
result += repr(key) + ":" + self.compile(env, rec_value) + ","
return result + "}"
raise NotImplementedError(type(exp), exp)

def compile_let(self, env: Env, name: str, value: Object, body: Object) -> str:
body_str = self.compile(env, body)
value_str = self.compile(env, value)
return f"(({name}) => ({body_str}))({value_str})"

def compile_match_case(self, env: Env, arg: str, case: MatchCase) -> Tuple[str, str]:
pattern = case.pattern
body = case.body
if isinstance(pattern, Int):
return f"{arg} === {pattern.value}", self.compile(env, body)
if isinstance(pattern, Var):
return "true", self.compile_let(env, pattern.name, Var(arg), body)
raise NotImplementedError(type(pattern))


def compile_exp_js(env: Env, exp: Object) -> str:
compiler = JSCompiler()
result = compiler.compile(env, exp)
return result


class Bdecoder:
def __init__(self, msg: str) -> None:
self.msg: str = msg
Expand Down Expand Up @@ -4158,6 +4244,105 @@ def test_pretty_print_symbol(self) -> None:
self.assertEqual(str(obj), "#x")


class JSCompilerTests(unittest.TestCase):
def test_compile_int(self) -> None:
exp = Int(123)
self.assertEqual(compile_exp_js({}, exp), "123")

def test_compile_binop_add(self) -> None:
exp = Binop(BinopKind.ADD, Int(3), Int(4))
self.assertEqual(compile_exp_js({}, exp), "(3)+(4)")

def test_compile_binop_rec(self) -> None:
exp = Binop(BinopKind.MUL, Binop(BinopKind.ADD, Int(3), Int(4)), Int(5))
self.assertEqual(compile_exp_js({}, exp), "((3)+(4))*(5)")

def test_compile_where(self) -> None:
exp = Where(Var("x"), Assign(Var("x"), Int(1)))
self.assertEqual(compile_exp_js({}, exp), "((x) => (x))(1)")

def test_compile_nested_where(self) -> None:
exp = parse(tokenize("x + y . x = 1 . y = 2"))
self.assertEqual(compile_exp_js({}, exp), "((y) => (((x) => ((x)+(y)))(1)))(2)")

def test_compile_apply(self) -> None:
exp = Apply(Var("f"), Var("x"))
self.assertEqual(compile_exp_js({}, exp), "(f)(x)")

def test_compile_apply_nested(self) -> None:
exp = Apply(Apply(Var("f"), Var("x")), Var("y"))
self.assertEqual(compile_exp_js({}, exp), "((f)(x))(y)")

def test_compile_function(self) -> None:
exp = Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1)))
self.assertEqual(compile_exp_js({}, exp), "(x) => ((x)+(1))")

def test_compile_function_nested(self) -> None:
exp = parse(tokenize("x -> y -> x + y"))
self.assertEqual(compile_exp_js({}, exp), "(x) => ((y) => ((x)+(y)))")

def test_compile_list(self) -> None:
exp = List([Binop(BinopKind.ADD, Int(1), Int(2)), Binop(BinopKind.MUL, Int(3), Int(4))])
self.assertEqual(compile_exp_js({}, exp), "[(1)+(2), (3)*(4)]")

def test_compile_match_function(self) -> None:
exp = parse(tokenize("| 1 -> 2 | 2 -> 3"))
self.assertEqual(
compile_exp_js({}, exp), "(__x) => (__x === 1) ? (2) : ((__x === 2) ? (3) : ((() => {throw 'oh no'})()))"
)

def test_compile_match_function_var(self) -> None:
exp = parse(tokenize("| 1 -> 2 | x -> x"))
self.assertEqual(
compile_exp_js({}, exp),
"(__x) => (__x === 1) ? (2) : ((true) ? (((x) => (x))(__x)) : ((() => {throw 'oh no'})()))",
)

def test_compile_symbol_bool_true(self) -> None:
exp = Symbol("true")
self.assertEqual(compile_exp_js({}, exp), "true")

def test_compile_symbol_bool_false(self) -> None:
exp = Symbol("false")
self.assertEqual(compile_exp_js({}, exp), "false")

def test_compile_symbol(self) -> None:
exp = Symbol("hello")
self.assertEqual(compile_exp_js({}, exp), "'hello'")

def test_compile_string(self) -> None:
exp = String("hello")
self.assertEqual(compile_exp_js({}, exp), "'hello'")

def test_compile_string_single_quotes(self) -> None:
exp = String("'hello'")
self.assertEqual(compile_exp_js({}, exp), "\"'hello'\"")

def test_compile_string_double_quotes(self) -> None:
exp = String('"hello"')
self.assertEqual(compile_exp_js({}, exp), "'\"hello\"'")

def test_compile_access_int(self) -> None:
exp = Access(Var("x"), Int(1))
self.assertEqual(compile_exp_js({}, exp), "x[1]")

def test_compile_access_field(self) -> None:
exp = Access(Var("x"), Var("y"))
self.assertEqual(compile_exp_js({}, exp), "x.y")

def test_compile_nested_access(self) -> None:
exp = Access(Access(Var("x"), Var("y")), Var("z"))
self.assertEqual(compile_exp_js({}, exp), "x.y.z")

def test_compile_empty_record(self) -> None:
exp = Record({})
self.assertEqual(compile_exp_js({}, exp), "{}")

def test_compile_record(self) -> None:
exp = Record({"a": Int(1), "b": Int(2)})
self.assertEqual(compile_exp_js({}, exp), "{'a':1,'b':2,}")


def fetch(url: Object) -> Object:
if not isinstance(url, String):
raise TypeError(f"fetch expected String, but got {type(url).__name__}")
Expand Down Expand Up @@ -4323,6 +4508,25 @@ def runsource(self, source: str, filename: str = "<input>", symbol: str = "singl
return False


class JSRepl(ScrapRepl):
def runsource(self, source: str, filename: str = "<input>", symbol: str = "single") -> bool:
try:
tokens = tokenize(source)
logger.debug("Tokens: %s", tokens)
ast = parse(tokens)
logger.debug("AST: %s", ast)
result = compile_exp_js(self.env, ast)
print(result)
except UnexpectedEOFError:
# Need to read more text
return True
except ParseError as e:
print(f"Parse error: {e}", file=sys.stderr)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return False


def eval_command(args: argparse.Namespace) -> None:
if args.debug:
logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -4352,7 +4556,7 @@ def repl_command(args: argparse.Namespace) -> None:
if args.debug:
logging.basicConfig(level=logging.DEBUG)

repl = ScrapRepl()
repl = JSRepl() if args.js else ScrapRepl()
if readline:
repl.enable_readline()
repl.interact(banner="")
Expand Down Expand Up @@ -4390,6 +4594,7 @@ def main() -> None:
repl = subparsers.add_parser("repl")
repl.set_defaults(func=repl_command)
repl.add_argument("--debug", action="store_true")
repl.add_argument("--js", action="store_true")

test = subparsers.add_parser("test")
test.set_defaults(func=test_command)
Expand Down Expand Up @@ -4420,5 +4625,37 @@ def main() -> None:
args.func(args)


print(
compile_exp_js(
{},
parse(
tokenize(
"""
rand_array (new_generator 42) 0 100 10

. rand_array = gen -> min -> max -> n -> n |>
| 0 -> []
| n -> (rand_val >+ rand_array new_gen min max (n - 1)
. rand_val = get_int new_gen
. new_gen = next gen min max)

-- from Java's java.util.Random
. new_generator = seed -> ({params = params, seed = seed, state = state}
. params = {mod = 281474976710656, mult = 25214903917, inc = 11}
. state = {min = 0, max = 0})

. get_int = gen -> $$floor (get gen)

. get = gen -> (gen@seed / gen@params@mod) * (gen@state@max - gen@state@min)

. next = gen -> min -> max -> ({params = gen@params, seed = next_seed, state = {min = min, max = max}}
. next_seed = gen@state@min + (gen@seed * gen@params@mult + gen@params@inc) % gen@params@mod)
"""
)
),
)
)


if __name__ == "__main__":
main()