Skip to content

Commit 6fa9b87

Browse files
committed
wip
1 parent 4e45d06 commit 6fa9b87

File tree

1 file changed

+109
-8
lines changed

1 file changed

+109
-8
lines changed

scrapscript.py

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,17 +1096,26 @@ class MatchError(Exception):
10961096
pass
10971097

10981098

1099-
def match(obj: Object, pattern: Object) -> Optional[Env]:
1099+
def match_guard(env: Env, guard: Optional[Object]) -> bool:
1100+
if guard is None:
1101+
return True
1102+
return eval_exp(env, guard) == Symbol("true")
1103+
1104+
1105+
def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Optional[Object] = None) -> Optional[Env]:
1106+
if env is None:
1107+
env = {}
11001108
if isinstance(pattern, Int):
1101-
return {} if isinstance(obj, Int) and obj.value == pattern.value else None
1109+
return {} if isinstance(obj, Int) and obj.value == pattern.value and match_guard(env, guard) else None
11021110
if isinstance(pattern, Float):
11031111
raise MatchError("pattern matching is not supported for Floats")
11041112
if isinstance(pattern, String):
1105-
return {} if isinstance(obj, String) and obj.value == pattern.value else None
1113+
return {} if isinstance(obj, String) and obj.value == pattern.value and match_guard(env, guard) else None
11061114
if isinstance(pattern, Var):
1107-
return {pattern.name: obj}
1115+
env = {**env, pattern.name: obj}
1116+
return env if match_guard(env, guard) else None
11081117
if isinstance(pattern, Symbol):
1109-
return {} if isinstance(obj, Symbol) and obj.value == pattern.value else None
1118+
return {} if isinstance(obj, Symbol) and obj.value == pattern.value and match_guard(env, guard) else None
11101119
if isinstance(pattern, Record):
11111120
if not isinstance(obj, Record):
11121121
return None
@@ -1126,7 +1135,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11261135
result.update(part)
11271136
if not use_spread and len(pattern.data) != len(obj.data):
11281137
return None
1129-
return result
1138+
return result if match_guard(result, guard) else None
11301139
if isinstance(pattern, List):
11311140
if not isinstance(obj, List):
11321141
return None
@@ -1149,7 +1158,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11491158
result.update(part)
11501159
if not use_spread and len(pattern.items) != len(obj.items):
11511160
return None
1152-
return result
1161+
return result if match_guard(result, guard) else None
11531162
raise NotImplementedError(f"match not implemented for {type(pattern).__name__}")
11541163

11551164

@@ -1273,7 +1282,7 @@ def eval_exp(env: Env, exp: Object) -> Object:
12731282
elif isinstance(callee.func, MatchFunction):
12741283
arg = eval_exp(env, exp.arg)
12751284
for case in callee.func.cases:
1276-
m = match(arg, case.pattern)
1285+
m = match(arg, case.pattern, env, case.guard)
12771286
if m is None:
12781287
continue
12791288
return eval_exp({**callee.env, **m}, case.body)
@@ -3256,6 +3265,98 @@ def test_match_var_binds_var(self) -> None:
32563265
Int(3),
32573266
)
32583267

3268+
def test_match_guard_closure_var(self) -> None:
3269+
self.assertEqual(
3270+
self._run(
3271+
"""
3272+
id 1
3273+
. id =
3274+
| x guard cond -> "one"
3275+
| x -> "idk"
3276+
. cond = 2
3277+
"""
3278+
),
3279+
String("idk"),
3280+
)
3281+
3282+
def test_match_record_guard_pass(self) -> None:
3283+
self.assertEqual(
3284+
self._run(
3285+
"""
3286+
id {cond=#true}
3287+
. id =
3288+
| {cond=cond} guard cond -> "yes"
3289+
| x -> "no"
3290+
"""
3291+
),
3292+
String("yes"),
3293+
)
3294+
3295+
def test_match_record_guard_fail(self) -> None:
3296+
self.assertEqual(
3297+
self._run(
3298+
"""
3299+
id {cond=#false}
3300+
. id =
3301+
| {cond=cond} guard cond -> "yes"
3302+
| x -> "no"
3303+
"""
3304+
),
3305+
String("no"),
3306+
)
3307+
3308+
def test_match_list_guard_pass(self) -> None:
3309+
self.assertEqual(
3310+
self._run(
3311+
"""
3312+
id [#true]
3313+
. id =
3314+
| [cond] guard cond -> "yes"
3315+
| x -> "no"
3316+
"""
3317+
),
3318+
String("yes"),
3319+
)
3320+
3321+
def test_match_list_guard_fail(self) -> None:
3322+
self.assertEqual(
3323+
self._run(
3324+
"""
3325+
id [#false]
3326+
. id =
3327+
| [cond] guard cond -> "yes"
3328+
| x -> "no"
3329+
"""
3330+
),
3331+
String("no"),
3332+
)
3333+
3334+
def test_match_guard_pass(self) -> None:
3335+
self.assertEqual(
3336+
self._run(
3337+
"""
3338+
id 1
3339+
. id =
3340+
| x guard x==1 -> "one"
3341+
| x -> "idk"
3342+
"""
3343+
),
3344+
String("one"),
3345+
)
3346+
3347+
def test_match_guard_fail(self) -> None:
3348+
self.assertEqual(
3349+
self._run(
3350+
"""
3351+
id 2
3352+
. id =
3353+
| x guard x==1 -> "one"
3354+
| x -> "idk"
3355+
"""
3356+
),
3357+
String("idk"),
3358+
)
3359+
32593360
def test_match_var_binds_first_arm(self) -> None:
32603361
self.assertEqual(
32613362
self._run(

0 commit comments

Comments
 (0)