diff --git a/README.md b/README.md index 8cb9946a..3412d2e0 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,10 @@ We support python3.8+. ```bash # With a file -python3 scrapscript.py eval examples/0_home/factorial.scrap +python3 scrapscript.py eval < examples/0_home/factorial.scrap # With a string literal -python3 scrapscript.py apply "1 + 2" +python3 scrapscript.py eval <<< '1 + 2' # With a REPL python3 scrapscript.py repl @@ -28,10 +28,10 @@ or with [Cosmopolitan](https://justine.lol/cosmopolitan/index.html): ./util/build-com # With a file -./scrapscript.com eval examples/0_home/factorial.scrap +./scrapscript.com eval < examples/0_home/factorial.scrap # With a string literal -./scrapscript.com apply "1 + 2" +./scrapscript.com eval <<< '1 + 2' # With a REPL ./scrapscript.com repl @@ -44,10 +44,10 @@ or with Docker: ```bash # With a file (mount your local directory) -docker run --mount type=bind,source="$(pwd)",target=/mnt -i -t ghcr.io/tekknolagi/scrapscript:trunk eval /mnt/examples/0_home/factorial.scrap +docker run --mount type=bind,source="$(pwd)",target=/mnt -i -t ghcr.io/tekknolagi/scrapscript:trunk eval < /mnt/examples/0_home/factorial.scrap # With a string literal -docker run -i -t ghcr.io/tekknolagi/scrapscript:trunk apply "1 + 2" +docker run -i -t ghcr.io/tekknolagi/scrapscript:trunk eval <<< '1 + 2' # With a REPL docker run -i -t ghcr.io/tekknolagi/scrapscript:trunk repl diff --git a/scrapscript.py b/scrapscript.py index 667e0edf..39ddbf71 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -2347,7 +2347,7 @@ def eval_command(args: argparse.Namespace) -> None: if args.debug: logging.basicConfig(level=logging.DEBUG) - program = args.program_file.read() + program = args.input.read() tokens = tokenize(program) logger.debug("Tokens: %s", tokens) ast = parse(tokens) @@ -2356,30 +2356,17 @@ def eval_command(args: argparse.Namespace) -> None: print(pretty(result)) -def check_command(args: argparse.Namespace) -> None: +def type_command(args: argparse.Namespace) -> None: if args.debug: logging.basicConfig(level=logging.DEBUG) - program = args.program_file.read() + program = args.input.read() tokens = tokenize(program) logger.debug("Tokens: %s", tokens) ast = parse(tokens) logger.debug("AST: %s", ast) result = infer_type(ast, OP_ENV) - result = minimize(result) - print(result) - - -def apply_command(args: argparse.Namespace) -> None: - if args.debug: - logging.basicConfig(level=logging.DEBUG) - - tokens = tokenize(args.program) - logger.debug("Tokens: %s", tokens) - ast = parse(tokens) - logger.debug("AST: %s", ast) - result = eval_exp(boot_env(), ast) - print(pretty(result)) + print(str(result)) # Use str() instead of pretty() since we're dealing with types def repl_command(args: argparse.Namespace) -> None: @@ -2476,11 +2463,58 @@ def compile_command(args: argparse.Namespace) -> None: def flat_command(args: argparse.Namespace) -> None: - prog = parse(tokenize(sys.stdin.read())) - serializer = Serializer() - serializer.serialize(prog) - sys.stdout.buffer.write(serializer.output) + if args.debug: + logging.basicConfig(level=logging.DEBUG) + + if args.mode == "parse": + # Read input, parse it, and serialize it + program = args.input.read() + tokens = tokenize(program) + ast = parse(tokens) + serializer = Serializer() + serializer.serialize(ast) + sys.stdout.buffer.write(serializer.output) + elif args.mode == "print": + # Read serialized input, deserialize it, and pretty print it + if hasattr(args.input, 'buffer'): + serialized_data = args.input.buffer.read() + else: + serialized_data = args.input.read() + deserializer = Deserializer(serialized_data) + ast = deserializer.parse() + print(pretty(ast)) + else: + raise ValueError(f"Unknown mode: {args.mode}") + + +def format_command(args: argparse.Namespace) -> None: + if args.debug: + logging.basicConfig(level=logging.DEBUG) + + program = args.input.read() + try: + tokens = tokenize(program) + ast = parse(tokens) + print(pretty(ast)) + except (ParseError, UnexpectedTokenError, InvalidTokenError, UnexpectedEOFError) as e: + raise Exception(f"Invalid syntax: {e}") + +def pipe_command(args: argparse.Namespace) -> None: + if args.debug: + logging.basicConfig(level=logging.DEBUG) + # Parse the input program + program = args.input.read() + tokens = tokenize(program) + ast = parse(tokens) + + # Parse the pipe expression + pipe_tokens = tokenize(args.command) + pipe_expr = parse(pipe_tokens) + + # Apply the pipe expression to the input program + result = eval_exp(boot_env(), Apply(pipe_expr, ast)) + print(pretty(result)) def main() -> None: parser = argparse.ArgumentParser(prog="scrapscript") @@ -2492,18 +2526,19 @@ def main() -> None: eval_ = subparsers.add_parser("eval") eval_.set_defaults(func=eval_command) - eval_.add_argument("program_file", type=argparse.FileType("r")) + eval_.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) eval_.add_argument("--debug", action="store_true") - check = subparsers.add_parser("check") - check.set_defaults(func=check_command) - check.add_argument("program_file", type=argparse.FileType("r")) - check.add_argument("--debug", action="store_true") + pipe = subparsers.add_parser("pipe") + pipe.set_defaults(func=pipe_command) + pipe.add_argument("command", help="Expression to apply to the input program") + pipe.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + pipe.add_argument("--debug", action="store_true") - apply = subparsers.add_parser("apply") - apply.set_defaults(func=apply_command) - apply.add_argument("program") - apply.add_argument("--debug", action="store_true") + type_ = subparsers.add_parser("type") + type_.set_defaults(func=type_command) + type_.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + type_.add_argument("--debug", action="store_true") comp = subparsers.add_parser("compile") comp.set_defaults(func=compile_command) @@ -2520,6 +2555,14 @@ def main() -> None: flat = subparsers.add_parser("flat") flat.set_defaults(func=flat_command) + flat.add_argument("mode", choices=["parse", "print"]) + flat.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + flat.add_argument("--debug", action="store_true") + + format_ = subparsers.add_parser("format") + format_.set_defaults(func=format_command) + format_.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + format_.add_argument("--debug", action="store_true") args = parser.parse_args() if not args.command: diff --git a/scrapscript_tests.py b/scrapscript_tests.py index e1626d0e..4e6840e1 100644 --- a/scrapscript_tests.py +++ b/scrapscript_tests.py @@ -1,6 +1,12 @@ import unittest import re +import io +import sys +import argparse from typing import Optional +from contextlib import contextmanager +import os +import unittest.mock # ruff: noqa: F405 # ruff: noqa: F403 @@ -4051,5 +4057,170 @@ def test_pretty_print_variant(self) -> None: self.assertEqual(pretty(obj), "#x (a -> b)") +@contextmanager +def captured_output(): + new_out, new_err = io.StringIO(), io.StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err + +@contextmanager +def captured_binary_output(): + new_out = io.BytesIO() + old_out = sys.stdout + try: + sys.stdout = type('', (), {'buffer': new_out})() + yield new_out + finally: + sys.stdout = old_out + +class CLITests(unittest.TestCase): + def setUp(self): + self.stdin = io.StringIO("42") + + def test_eval_command(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, input=self.stdin) + eval_command(args) + self.assertEqual(out.getvalue().strip(), "42") + + def test_flat_command_parse(self): + with captured_binary_output() as out: + args = argparse.Namespace(debug=False, mode="parse", input=self.stdin) + flat_command(args) + serialized_data = out.getvalue() + ast = deserialize(serialized_data) + self.assertEqual(pretty(ast), "42") + + def test_flat_command_print(self): + with captured_output() as (out, err): + # Create a serialized Int(42) object + serializer = Serializer() + serializer.serialize(Int(42)) + args = argparse.Namespace(debug=False, mode="print", input=io.BytesIO(serializer.output)) + flat_command(args) + self.assertEqual(out.getvalue().strip(), "42") + + def test_pipe_command(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, command="x -> x + 1", input=self.stdin) + pipe_command(args) + self.assertEqual(out.getvalue().strip(), "43") + + def test_type_command(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, input=self.stdin) + type_command(args) + self.assertEqual(out.getvalue().strip(), "int") + + def test_format_command(self): + input_text = "x->x+1" + with captured_output() as (out, err): + args = argparse.Namespace(input=io.StringIO(input_text), debug=False) + format_command(args) + self.assertEqual(out.getvalue().strip(), "x -> x + 1") + + @unittest.mock.patch('readline.read_history_file') + @unittest.mock.patch('readline.write_history_file') + def test_repl_command(self, mock_write_history, mock_read_history): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False) + # Mock sys.stdin to provide input and EOF + original_stdin = sys.stdin + sys.stdin = io.StringIO("42\n\x04") # 42 followed by EOF (Ctrl+D) + try: + repl_command(args) + finally: + sys.stdin = original_stdin + self.assertIn("42", out.getvalue()) + + def test_eval_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, input=io.StringIO("invalid syntax")) + with self.assertRaises(Exception): + eval_command(args) + + def test_flat_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, mode="parse", input=io.StringIO("invalid syntax")) + with self.assertRaises(Exception): + flat_command(args) + + def test_pipe_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, command="x -> x + 1", input=io.StringIO("invalid syntax")) + with self.assertRaises(Exception): + pipe_command(args) + + def test_type_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, input=io.StringIO("invalid syntax")) + with self.assertRaises(Exception): + type_command(args) + + def test_format_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(input=io.StringIO("invalid syntax"), debug=False) + with self.assertRaisesRegex(Exception, "Invalid syntax:"): + format_command(args) + + def test_compile_command(self): + with captured_output() as (out, err): + args = argparse.Namespace( + file="test.scrap", + output="output.c", + format=False, + compile=False, + memory=None, + run=False, + debug=False, + check=False, + platform=os.path.join(os.path.dirname(__file__), "cli.c") + ) + # Create a temporary test file + with open("test.scrap", "w") as f: + f.write("42") + try: + compile_command(args) + # Check if output.c was created + self.assertTrue(os.path.exists("output.c")) + finally: + # Clean up + if os.path.exists("test.scrap"): + os.remove("test.scrap") + if os.path.exists("output.c"): + os.remove("output.c") + + def test_compile_command_invalid_file(self): + with captured_output() as (out, err): + args = argparse.Namespace( + file="nonexistent.scrap", + output="output.c", + format=False, + compile=False, + memory=None, + run=False, + debug=False, + check=False, + platform=os.path.join(os.path.dirname(__file__), "cli.c") + ) + with self.assertRaises(FileNotFoundError): + compile_command(args) + + def test_format_command_invalid_syntax(self): + with captured_output() as (out, err): + args = argparse.Namespace(input=io.StringIO("(1 2"), debug=False) # Unmatched parenthesis + with self.assertRaisesRegex(Exception, "Invalid syntax:"): + format_command(args) + + def test_pipe_command_invalid_command(self): + with captured_output() as (out, err): + args = argparse.Namespace(debug=False, command="invalid ->", input=self.stdin) + with self.assertRaises(Exception): + pipe_command(args) + if __name__ == "__main__": unittest.main()