Skip to content

Commit 5b98e3b

Browse files
committed
wip
1 parent e99c777 commit 5b98e3b

File tree

1 file changed

+109
-8
lines changed

1 file changed

+109
-8
lines changed

scrapscript.py

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,17 +1093,26 @@ class MatchError(Exception):
10931093
pass
10941094

10951095

1096-
def match(obj: Object, pattern: Object) -> Optional[Env]:
1096+
def match_guard(env: Env, guard: Optional[Object]) -> bool:
1097+
if guard is None:
1098+
return True
1099+
return eval_exp(env, guard) == Symbol("true")
1100+
1101+
1102+
def match(obj: Object, pattern: Object, env: Optional[Env] = None, guard: Optional[Object] = None) -> Optional[Env]:
1103+
if env is None:
1104+
env = {}
10971105
if isinstance(pattern, Int):
1098-
return {} if isinstance(obj, Int) and obj.value == pattern.value else None
1106+
return {} if isinstance(obj, Int) and obj.value == pattern.value and match_guard(env, guard) else None
10991107
if isinstance(pattern, Float):
11001108
raise MatchError("pattern matching is not supported for Floats")
11011109
if isinstance(pattern, String):
1102-
return {} if isinstance(obj, String) and obj.value == pattern.value else None
1110+
return {} if isinstance(obj, String) and obj.value == pattern.value and match_guard(env, guard) else None
11031111
if isinstance(pattern, Var):
1104-
return {pattern.name: obj}
1112+
env = {**env, pattern.name: obj}
1113+
return env if match_guard(env, guard) else None
11051114
if isinstance(pattern, Symbol):
1106-
return {} if isinstance(obj, Symbol) and obj.value == pattern.value else None
1115+
return {} if isinstance(obj, Symbol) and obj.value == pattern.value and match_guard(env, guard) else None
11071116
if isinstance(pattern, Record):
11081117
if not isinstance(obj, Record):
11091118
return None
@@ -1123,7 +1132,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11231132
result.update(part)
11241133
if not use_spread and len(pattern.data) != len(obj.data):
11251134
return None
1126-
return result
1135+
return result if match_guard(result, guard) else None
11271136
if isinstance(pattern, List):
11281137
if not isinstance(obj, List):
11291138
return None
@@ -1146,7 +1155,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11461155
result.update(part)
11471156
if not use_spread and len(pattern.items) != len(obj.items):
11481157
return None
1149-
return result
1158+
return result if match_guard(result, guard) else None
11501159
raise NotImplementedError(f"match not implemented for {type(pattern).__name__}")
11511160

11521161

@@ -1272,7 +1281,7 @@ def eval_exp(env: Env, exp: Object) -> Object:
12721281
elif isinstance(callee.func, MatchFunction):
12731282
arg = eval_exp(env, exp.arg)
12741283
for case in callee.func.cases:
1275-
m = match(arg, case.pattern)
1284+
m = match(arg, case.pattern, env, case.guard)
12761285
if m is None:
12771286
continue
12781287
return eval_exp({**callee.env, **m}, case.body)
@@ -3257,6 +3266,98 @@ def test_match_var_binds_var(self) -> None:
32573266
Int(3),
32583267
)
32593268

3269+
def test_match_guard_closure_var(self) -> None:
3270+
self.assertEqual(
3271+
self._run(
3272+
"""
3273+
id 1
3274+
. id =
3275+
| x guard cond -> "one"
3276+
| x -> "idk"
3277+
. cond = 2
3278+
"""
3279+
),
3280+
String("idk"),
3281+
)
3282+
3283+
def test_match_record_guard_pass(self) -> None:
3284+
self.assertEqual(
3285+
self._run(
3286+
"""
3287+
id {cond=#true}
3288+
. id =
3289+
| {cond=cond} guard cond -> "yes"
3290+
| x -> "no"
3291+
"""
3292+
),
3293+
String("yes"),
3294+
)
3295+
3296+
def test_match_record_guard_fail(self) -> None:
3297+
self.assertEqual(
3298+
self._run(
3299+
"""
3300+
id {cond=#false}
3301+
. id =
3302+
| {cond=cond} guard cond -> "yes"
3303+
| x -> "no"
3304+
"""
3305+
),
3306+
String("no"),
3307+
)
3308+
3309+
def test_match_list_guard_pass(self) -> None:
3310+
self.assertEqual(
3311+
self._run(
3312+
"""
3313+
id [#true]
3314+
. id =
3315+
| [cond] guard cond -> "yes"
3316+
| x -> "no"
3317+
"""
3318+
),
3319+
String("yes"),
3320+
)
3321+
3322+
def test_match_list_guard_fail(self) -> None:
3323+
self.assertEqual(
3324+
self._run(
3325+
"""
3326+
id [#false]
3327+
. id =
3328+
| [cond] guard cond -> "yes"
3329+
| x -> "no"
3330+
"""
3331+
),
3332+
String("no"),
3333+
)
3334+
3335+
def test_match_guard_pass(self) -> None:
3336+
self.assertEqual(
3337+
self._run(
3338+
"""
3339+
id 1
3340+
. id =
3341+
| x guard x==1 -> "one"
3342+
| x -> "idk"
3343+
"""
3344+
),
3345+
String("one"),
3346+
)
3347+
3348+
def test_match_guard_fail(self) -> None:
3349+
self.assertEqual(
3350+
self._run(
3351+
"""
3352+
id 2
3353+
. id =
3354+
| x guard x==1 -> "one"
3355+
| x -> "idk"
3356+
"""
3357+
),
3358+
String("idk"),
3359+
)
3360+
32603361
def test_match_var_binds_first_arm(self) -> None:
32613362
self.assertEqual(
32623363
self._run(

0 commit comments

Comments
 (0)