From 73bd6b49cf1b97f39c0edf5b2228721018c56cd8 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 20 Jan 2024 15:31:26 -0500 Subject: [PATCH 01/11] Parse guards (ew) --- scrapscript.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index 8cceeb14..c7b582c8 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -253,6 +253,8 @@ def read_var(self, first_char: str) -> Token: while self.has_input() and is_identifier_char(c := self.peek_char()): self.read_char() buf += c + if buf == "guard": + return self.make_token(Operator, "guard") return self.make_token(Name, buf) def read_bytes(self) -> Token: @@ -303,6 +305,7 @@ def xp(n: float) -> Prec: "::": lp(2000), "@": rp(1001), "": rp(1000), + "guard": rp(5.5), ">>": lp(14), "<<": lp(14), "^": rp(13), @@ -342,7 +345,7 @@ def xp(n: float) -> Prec: HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values()) -OPER_CHARS = set("".join(PS.keys())) +OPER_CHARS = set("".join(PS.keys())) - set("guard") assert " " not in OPER_CHARS @@ -493,6 +496,8 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object": elif op == Operator("@"): # TODO: revisit whether to use @ or . for field access l = Access(l, parse(tokens, pr)) + elif op == Operator("guard"): + l = Guard(l, parse(tokens, pr)) else: assert not isinstance(op, Juxt) assert isinstance(op, Operator) @@ -865,6 +870,12 @@ def __str__(self) -> str: return f"EnvObject(keys={self.env.keys()})" +@dataclass(eq=True, frozen=True, unsafe_hash=True) +class Guard(Object): + pattern: Object + cond: Object + + @dataclass(eq=True, frozen=True, unsafe_hash=True) class MatchCase(Object): pattern: Object @@ -2328,6 +2339,18 @@ def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None: def test_parse_symbol_returns_symbol(self) -> None: self.assertEqual(parse([SymbolToken("abc")]), Symbol("abc")) + def test_parse_guard(self) -> None: + self.assertEqual( + parse(tokenize("| x guard y -> x")), + MatchFunction([MatchCase(Guard(Var("x"), Var("y")), Var("x"))]), + ) + + def test_parse_guard_exp(self) -> None: + self.assertEqual( + parse(tokenize("| x guard x==1 -> x")), + MatchFunction([MatchCase(Guard(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1))), Var("x"))]), + ) + class MatchTests(unittest.TestCase): def test_match_with_equal_ints_returns_empty_dict(self) -> None: From db0825809ed51a8d368193f397bb4d245048c6b3 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 20 Jan 2024 15:41:04 -0500 Subject: [PATCH 02/11] wip --- scrapscript.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index c7b582c8..d4099482 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -367,6 +367,13 @@ def parse_assign(tokens: typing.List[Token], p: float = 0) -> "Assign": return assign +def build_match_case(expr: "Object") -> "MatchCase": + if not isinstance(expr, Function): + raise ParseError(f"expected function in match expression {expr!r}") + arg, body = expr.arg, expr.body + return MatchCase(arg, body) + + def parse(tokens: typing.List[Token], p: float = 0) -> "Object": if not tokens: raise UnexpectedEOFError("unexpected end of input") @@ -404,15 +411,11 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object": l = Spread() elif token == Operator("|"): expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities - if not isinstance(expr, Function): - raise ParseError(f"expected function in match expression {expr!r}") - cases = [MatchCase(expr.arg, expr.body)] + cases = [build_match_case(expr)] while tokens and tokens[0] == Operator("|"): tokens.pop(0) expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities - if not isinstance(expr, Function): - raise ParseError(f"expected function in match expression {expr!r}") - cases.append(MatchCase(expr.arg, expr.body)) + cases.append(build_match_case(expr)) l = MatchFunction(cases) elif isinstance(token, LeftParen): if isinstance(tokens[0], RightParen): From a295d051077e41e5cf0b985b71d963d1283a09af Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 20 Jan 2024 15:50:34 -0500 Subject: [PATCH 03/11] Fix tests --- scrapscript.py | 92 +++++++++++++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index d4099482..c3d30e00 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -370,8 +370,12 @@ def parse_assign(tokens: typing.List[Token], p: float = 0) -> "Assign": def build_match_case(expr: "Object") -> "MatchCase": if not isinstance(expr, Function): raise ParseError(f"expected function in match expression {expr!r}") - arg, body = expr.arg, expr.body - return MatchCase(arg, body) + pattern, body = expr.arg, expr.body + guard = None + if isinstance(pattern, Binop) and pattern.op == BinopKind.GUARD: + guard = pattern.right + pattern = pattern.left + return MatchCase(pattern, guard, body) def parse(tokens: typing.List[Token], p: float = 0) -> "Object": @@ -499,8 +503,6 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object": elif op == Operator("@"): # TODO: revisit whether to use @ or . for field access l = Access(l, parse(tokens, pr)) - elif op == Operator("guard"): - l = Guard(l, parse(tokens, pr)) else: assert not isinstance(op, Juxt) assert isinstance(op, Operator) @@ -687,6 +689,7 @@ class BinopKind(enum.Enum): HASTYPE = auto() PIPE = auto() REVERSE_PIPE = auto() + GUARD = auto() @classmethod def from_str(cls, x: str) -> "BinopKind": @@ -713,6 +716,7 @@ def from_str(cls, x: str) -> "BinopKind": ":": cls.HASTYPE, "|>": cls.PIPE, "<|": cls.REVERSE_PIPE, + "guard": cls.GUARD, }[x] @classmethod @@ -739,6 +743,7 @@ def to_str(cls, binop_kind: "BinopKind") -> str: cls.HASTYPE: ":", cls.PIPE: "|>", cls.REVERSE_PIPE: "<|", + cls.GUARD: "guard", }[binop_kind] @@ -873,15 +878,10 @@ def __str__(self) -> str: return f"EnvObject(keys={self.env.keys()})" -@dataclass(eq=True, frozen=True, unsafe_hash=True) -class Guard(Object): - pattern: Object - cond: Object - - @dataclass(eq=True, frozen=True, unsafe_hash=True) class MatchCase(Object): pattern: Object + guard: Optional[Object] body: Object def __str__(self) -> str: @@ -2203,7 +2203,7 @@ def test_parse_match_no_cases_raises_parse_error(self) -> None: def test_parse_match_one_case(self) -> None: self.assertEqual( parse([Operator("|"), IntLit(1), Operator("->"), IntLit(2)]), - MatchFunction([MatchCase(Int(1), Int(2))]), + MatchFunction([MatchCase(Int(1), None, Int(2))]), ) def test_parse_match_two_cases(self) -> None: @@ -2222,8 +2222,8 @@ def test_parse_match_two_cases(self) -> None: ), MatchFunction( [ - MatchCase(Int(1), Int(2)), - MatchCase(Int(2), Int(3)), + MatchCase(Int(1), None, Int(2)), + MatchCase(Int(2), None, Int(3)), ] ), ) @@ -2342,17 +2342,17 @@ def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None: def test_parse_symbol_returns_symbol(self) -> None: self.assertEqual(parse([SymbolToken("abc")]), Symbol("abc")) - def test_parse_guard(self) -> None: - self.assertEqual( - parse(tokenize("| x guard y -> x")), - MatchFunction([MatchCase(Guard(Var("x"), Var("y")), Var("x"))]), - ) + # def test_parse_guard(self) -> None: + # self.assertEqual( + # parse(tokenize("| x guard y -> x")), + # MatchFunction([MatchCase(Guard(Var("x"), Var("y")), Var("x"))]), + # ) - def test_parse_guard_exp(self) -> None: - self.assertEqual( - parse(tokenize("| x guard x==1 -> x")), - MatchFunction([MatchCase(Guard(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1))), Var("x"))]), - ) + # def test_parse_guard_exp(self) -> None: + # self.assertEqual( + # parse(tokenize("| x guard x==1 -> x")), + # MatchFunction([MatchCase(Guard(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1))), Var("x"))]), + # ) class MatchTests(unittest.TestCase): @@ -2530,7 +2530,8 @@ def test_parse_match_with_left_apply(self) -> None: ) ast = parse(tokens) self.assertEqual( - ast, MatchFunction([MatchCase(Var("a"), Apply(Var("b"), Var("c"))), MatchCase(Var("d"), Var("e"))]) + ast, + MatchFunction([MatchCase(Var("a"), None, Apply(Var("b"), Var("c"))), MatchCase(Var("d"), None, Var("e"))]), ) def test_parse_match_with_right_apply(self) -> None: @@ -2544,9 +2545,10 @@ def test_parse_match_with_right_apply(self) -> None: ast, MatchFunction( [ - MatchCase(Int(1), Int(19)), + MatchCase(Int(1), None, Int(19)), MatchCase( Var("a"), + None, Apply( Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))), Var("a"), @@ -2901,26 +2903,29 @@ def test_match_no_cases_raises_match_error(self) -> None: eval_exp({}, exp) def test_match_int_with_equal_int_matches(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(1)) + exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(1)) self.assertEqual(eval_exp({}, exp), Int(2)) def test_match_int_with_inequal_int_raises_match_error(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(3)) + exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(3)) with self.assertRaisesRegex(MatchError, "no matching cases"): eval_exp({}, exp) def test_match_string_with_equal_string_matches(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("a")) + exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("a")) self.assertEqual(eval_exp({}, exp), String("b")) def test_match_string_with_inequal_string_raises_match_error(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("c")) + exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("c")) with self.assertRaisesRegex(MatchError, "no matching cases"): eval_exp({}, exp) def test_match_falls_through_to_next(self) -> None: exp = Apply( - MatchFunction([MatchCase(pattern=Int(3), body=Int(4)), MatchCase(pattern=Int(1), body=Int(2))]), Int(1) + MatchFunction( + [MatchCase(pattern=Int(3), guard=None, body=Int(4)), MatchCase(pattern=Int(1), guard=None, body=Int(2))] + ), + Int(1), ) self.assertEqual(eval_exp({}, exp), Int(2)) @@ -2969,7 +2974,7 @@ def test_eval_apply_quote_returns_ast(self) -> None: self.assertIs(eval_exp({}, exp), ast) def test_eval_apply_closure_with_match_function_has_access_to_closure_vars(self) -> None: - ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), Var("x"))])), Int(2)) + ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), None, Var("x"))])), Int(2)) self.assertEqual(eval_exp({}, ast), Int(1)) def test_eval_less_returns_bool(self) -> None: @@ -3580,38 +3585,39 @@ def test_match_function(self) -> None: self.assertEqual(free_in(exp), {"x", "y"}) def test_match_case_int(self) -> None: - exp = MatchCase(Int(1), Var("x")) + exp = MatchCase(Int(1), None, Var("x")) self.assertEqual(free_in(exp), {"x"}) def test_match_case_var(self) -> None: - exp = MatchCase(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Var("x"), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_list(self) -> None: - exp = MatchCase(List([Var("x")]), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(List([Var("x")]), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_list_spread(self) -> None: - exp = MatchCase(List([Spread()]), Binop(BinopKind.ADD, Var("xs"), Var("y"))) + exp = MatchCase(List([Spread()]), None, Binop(BinopKind.ADD, Var("xs"), Var("y"))) self.assertEqual(free_in(exp), {"xs", "y"}) def test_match_case_list_spread_name(self) -> None: - exp = MatchCase(List([Spread("xs")]), Binop(BinopKind.ADD, Var("xs"), Var("y"))) + exp = MatchCase(List([Spread("xs")]), None, Binop(BinopKind.ADD, Var("xs"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_record(self) -> None: exp = MatchCase( Record({"x": Int(1), "y": Var("y"), "a": Var("z")}), + None, Binop(BinopKind.ADD, Binop(BinopKind.ADD, Var("x"), Var("y")), Var("z")), ) self.assertEqual(free_in(exp), {"x"}) def test_match_case_record_spread(self) -> None: - exp = MatchCase(Record({"...": Spread()}), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Record({"...": Spread()}), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"x", "y"}) def test_match_case_record_spread_name(self) -> None: - exp = MatchCase(Record({"...": Spread("x")}), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Record({"...": Spread("x")}), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_apply(self) -> None: @@ -4336,12 +4342,14 @@ def test_pretty_print_envobject(self) -> None: self.assertEqual(str(obj), "EnvObject(keys=dict_keys(['x']))") def test_pretty_print_matchcase(self) -> None: - obj = MatchCase(pattern=Int(1), body=Int(2)) - self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), body=Int(value=2))") + obj = MatchCase(pattern=Int(1), guard=None, body=Int(2)) + self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), guard=None, body=Int(value=2))") def test_pretty_print_matchfunction(self) -> None: - obj = MatchFunction([MatchCase(Var("y"), Var("x"))]) - self.assertEqual(str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), body=Var(name='x'))])") + obj = MatchFunction([MatchCase(Var("y"), None, Var("x"))]) + self.assertEqual( + str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), guard=None, body=Var(name='x'))])" + ) def test_pretty_print_relocation(self) -> None: obj = Relocation("relocate") From eaff4219091773e87f67925eab8d85e52f1f9774 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sat, 20 Jan 2024 15:50:39 -0500 Subject: [PATCH 04/11] wip --- scrapscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index c3d30e00..fe1ac771 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3.10 +#!/usr/bin/env python3.8 import argparse import base64 import code From fe1791431643ea6a14d2eda4ac8298e93e073f6c Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 11:55:14 -0500 Subject: [PATCH 05/11] Fix tests --- scrapscript.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index fe1ac771..b897e850 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -2342,17 +2342,28 @@ def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None: def test_parse_symbol_returns_symbol(self) -> None: self.assertEqual(parse([SymbolToken("abc")]), Symbol("abc")) - # def test_parse_guard(self) -> None: - # self.assertEqual( - # parse(tokenize("| x guard y -> x")), - # MatchFunction([MatchCase(Guard(Var("x"), Var("y")), Var("x"))]), - # ) - - # def test_parse_guard_exp(self) -> None: - # self.assertEqual( - # parse(tokenize("| x guard x==1 -> x")), - # MatchFunction([MatchCase(Guard(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1))), Var("x"))]), - # ) + def test_parse_guard(self) -> None: + self.assertEqual( + parse(tokenize("| x guard y -> x")), + MatchFunction([MatchCase(Var("x"), Var("y"), Var("x"))]), + ) + + def test_parse_guard_exp(self) -> None: + self.assertEqual( + parse(tokenize("| x guard x==1 -> x")), + MatchFunction([MatchCase(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1)), Var("x"))]), + ) + + def test_parse_multiple_guards(self) -> None: + self.assertEqual( + parse(tokenize("| x guard y -> x | a guard b -> 1")), + MatchFunction( + [ + MatchCase(Var("x"), Var("y"), Var("x")), + MatchCase(Var("a"), Var("b"), Int(1)), + ] + ), + ) class MatchTests(unittest.TestCase): From e99c777ab72d3fa0b05fa411832a62bdab5778f3 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 11:59:02 -0500 Subject: [PATCH 06/11] wip --- scrapscript.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index b897e850..2b30f0bd 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -305,7 +305,6 @@ def xp(n: float) -> Prec: "::": lp(2000), "@": rp(1001), "": rp(1000), - "guard": rp(5.5), ">>": lp(14), "<<": lp(14), "^": rp(13), @@ -329,6 +328,7 @@ def xp(n: float) -> Prec: "||": rp(7), "|>": rp(6), "<|": lp(6), + "guard": rp(5.5), "->": lp(5), "|": rp(4.5), ":": lp(4.5), @@ -345,7 +345,9 @@ def xp(n: float) -> Prec: HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values()) -OPER_CHARS = set("".join(PS.keys())) - set("guard") +# TODO(max): Consider making "guard" an operator with only punctuation (but +# leave syntax-level "guard" keyword) +OPER_CHARS = set(c for c in "".join(PS.keys()) if not c.isalpha()) assert " " not in OPER_CHARS @@ -2365,6 +2367,12 @@ def test_parse_multiple_guards(self) -> None: ), ) + def test_parse_guard_pipe(self) -> None: + self.assertEqual( + parse(tokenize("| x guard x |> f -> x")), + MatchFunction([MatchCase(Var("x"), Apply(Var("f"), Var("x")), Var("x"))]), + ) + class MatchTests(unittest.TestCase): def test_match_with_equal_ints_returns_empty_dict(self) -> None: From 5b98e3bee3b51f78ce355d548c6a9a7d28525307 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 12:14:59 -0500 Subject: [PATCH 07/11] wip --- scrapscript.py | 117 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 8 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 2b30f0bd..16dfe9d2 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1093,17 +1093,26 @@ class MatchError(Exception): pass -def match(obj: Object, pattern: Object) -> Optional[Env]: +def match_guard(env: Env, guard: Optional[Object]) -> bool: + if guard is None: + return True + return eval_exp(env, guard) == Symbol("true") + + +def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Optional[Object] = None) -> Optional[Env]: + if env is None: + env = {} if isinstance(pattern, Int): - return {} if isinstance(obj, Int) and obj.value == pattern.value else None + return {} if isinstance(obj, Int) and obj.value == pattern.value and match_guard(env, guard) else None if isinstance(pattern, Float): raise MatchError("pattern matching is not supported for Floats") if isinstance(pattern, String): - return {} if isinstance(obj, String) and obj.value == pattern.value else None + return {} if isinstance(obj, String) and obj.value == pattern.value and match_guard(env, guard) else None if isinstance(pattern, Var): - return {pattern.name: obj} + env = {**env, pattern.name: obj} + return env if match_guard(env, guard) else None if isinstance(pattern, Symbol): - return {} if isinstance(obj, Symbol) and obj.value == pattern.value else None + return {} if isinstance(obj, Symbol) and obj.value == pattern.value and match_guard(env, guard) else None if isinstance(pattern, Record): if not isinstance(obj, Record): return None @@ -1123,7 +1132,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]: result.update(part) if not use_spread and len(pattern.data) != len(obj.data): return None - return result + return result if match_guard(result, guard) else None if isinstance(pattern, List): if not isinstance(obj, List): return None @@ -1146,7 +1155,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]: result.update(part) if not use_spread and len(pattern.items) != len(obj.items): return None - return result + return result if match_guard(result, guard) else None raise NotImplementedError(f"match not implemented for {type(pattern).__name__}") @@ -1272,7 +1281,7 @@ def eval_exp(env: Env, exp: Object) -> Object: elif isinstance(callee.func, MatchFunction): arg = eval_exp(env, exp.arg) for case in callee.func.cases: - m = match(arg, case.pattern) + m = match(arg, case.pattern, env, case.guard) if m is None: continue return eval_exp({**callee.env, **m}, case.body) @@ -3257,6 +3266,98 @@ def test_match_var_binds_var(self) -> None: Int(3), ) + def test_match_guard_closure_var(self) -> None: + self.assertEqual( + self._run( + """ + id 1 + . id = + | x guard cond -> "one" + | x -> "idk" + . cond = 2 + """ + ), + String("idk"), + ) + + def test_match_record_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id {cond=#true} + . id = + | {cond=cond} guard cond -> "yes" + | x -> "no" + """ + ), + String("yes"), + ) + + def test_match_record_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id {cond=#false} + . id = + | {cond=cond} guard cond -> "yes" + | x -> "no" + """ + ), + String("no"), + ) + + def test_match_list_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id [#true] + . id = + | [cond] guard cond -> "yes" + | x -> "no" + """ + ), + String("yes"), + ) + + def test_match_list_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id [#false] + . id = + | [cond] guard cond -> "yes" + | x -> "no" + """ + ), + String("no"), + ) + + def test_match_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id 1 + . id = + | x guard x==1 -> "one" + | x -> "idk" + """ + ), + String("one"), + ) + + def test_match_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id 2 + . id = + | x guard x==1 -> "one" + | x -> "idk" + """ + ), + String("idk"), + ) + def test_match_var_binds_first_arm(self) -> None: self.assertEqual( self._run( From 7ba29ba60d9880fd3953cd02c941bc7e7f7c404d Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 12:16:33 -0500 Subject: [PATCH 08/11] wip --- scrapscript.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index 16dfe9d2..01fed36c 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1281,9 +1281,11 @@ def eval_exp(env: Env, exp: Object) -> Object: elif isinstance(callee.func, MatchFunction): arg = eval_exp(env, exp.arg) for case in callee.func.cases: - m = match(arg, case.pattern, env, case.guard) + m = match(arg, case.pattern) if m is None: continue + if case.guard is not None and eval_exp({**env, **m}, case.guard) != Symbol("true"): + continue return eval_exp({**callee.env, **m}, case.body) raise MatchError("no matching cases") else: From 412aacb93b75b197fac6687e8e4e129e4875a265 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 12:17:42 -0500 Subject: [PATCH 09/11] wip --- scrapscript.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/scrapscript.py b/scrapscript.py index 01fed36c..23bc7e16 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -305,6 +305,7 @@ def xp(n: float) -> Prec: "::": lp(2000), "@": rp(1001), "": rp(1000), + "guard": rp(5.5), ">>": lp(14), "<<": lp(14), "^": rp(13), @@ -328,7 +329,6 @@ def xp(n: float) -> Prec: "||": rp(7), "|>": rp(6), "<|": lp(6), - "guard": rp(5.5), "->": lp(5), "|": rp(4.5), ":": lp(4.5), @@ -345,9 +345,7 @@ def xp(n: float) -> Prec: HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values()) -# TODO(max): Consider making "guard" an operator with only punctuation (but -# leave syntax-level "guard" keyword) -OPER_CHARS = set(c for c in "".join(PS.keys()) if not c.isalpha()) +OPER_CHARS = set("".join(PS.keys())) - set("guard") assert " " not in OPER_CHARS @@ -1093,26 +1091,17 @@ class MatchError(Exception): pass -def match_guard(env: Env, guard: Optional[Object]) -> bool: - if guard is None: - return True - return eval_exp(env, guard) == Symbol("true") - - -def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Optional[Object] = None) -> Optional[Env]: - if env is None: - env = {} +def match(obj: Object, pattern: Object) -> Optional[Env]: if isinstance(pattern, Int): - return {} if isinstance(obj, Int) and obj.value == pattern.value and match_guard(env, guard) else None + return {} if isinstance(obj, Int) and obj.value == pattern.value else None if isinstance(pattern, Float): raise MatchError("pattern matching is not supported for Floats") if isinstance(pattern, String): - return {} if isinstance(obj, String) and obj.value == pattern.value and match_guard(env, guard) else None + return {} if isinstance(obj, String) and obj.value == pattern.value else None if isinstance(pattern, Var): - env = {**env, pattern.name: obj} - return env if match_guard(env, guard) else None + return {pattern.name: obj} if isinstance(pattern, Symbol): - return {} if isinstance(obj, Symbol) and obj.value == pattern.value and match_guard(env, guard) else None + return {} if isinstance(obj, Symbol) and obj.value == pattern.value else None if isinstance(pattern, Record): if not isinstance(obj, Record): return None @@ -1132,7 +1121,7 @@ def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Option result.update(part) if not use_spread and len(pattern.data) != len(obj.data): return None - return result if match_guard(result, guard) else None + return result if isinstance(pattern, List): if not isinstance(obj, List): return None @@ -1155,7 +1144,7 @@ def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Option result.update(part) if not use_spread and len(pattern.items) != len(obj.items): return None - return result if match_guard(result, guard) else None + return result raise NotImplementedError(f"match not implemented for {type(pattern).__name__}") @@ -2378,12 +2367,6 @@ def test_parse_multiple_guards(self) -> None: ), ) - def test_parse_guard_pipe(self) -> None: - self.assertEqual( - parse(tokenize("| x guard x |> f -> x")), - MatchFunction([MatchCase(Var("x"), Apply(Var("f"), Var("x")), Var("x"))]), - ) - class MatchTests(unittest.TestCase): def test_match_with_equal_ints_returns_empty_dict(self) -> None: From 2fc3e66609dfe4c6e4e824598faedd7ad45b9e92 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 12:22:00 -0500 Subject: [PATCH 10/11] Don't special case guard --- scrapscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index 23bc7e16..db858bbc 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -345,7 +345,7 @@ def xp(n: float) -> Prec: HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values()) -OPER_CHARS = set("".join(PS.keys())) - set("guard") +OPER_CHARS = set(c for c in "".join(PS.keys()) if not c.isalpha()) assert " " not in OPER_CHARS From c1f50cac79bcf4425a2f3875a3d19124ee4bd0ee Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 21 Jan 2024 09:24:34 -0800 Subject: [PATCH 11/11] Update scrapscript.py --- scrapscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scrapscript.py b/scrapscript.py index db858bbc..f44491ae 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3.8 +#!/usr/bin/env python3.10 import argparse import base64 import code