@@ -1096,17 +1096,26 @@ class MatchError(Exception):
10961096 pass
10971097
10981098
1099- def match (obj : Object , pattern : Object ) -> Optional [Env ]:
1099+ def match_guard (env : Env , guard : Optional [Object ]) -> bool :
1100+ if guard is None :
1101+ return True
1102+ return eval_exp (env , guard ) == Symbol ("true" )
1103+
1104+
1105+ def match (obj : Object , pattern : Object , env : Optional [Env ] = None , guard : Optional [Object ] = None ) -> Optional [Env ]:
1106+ if env is None :
1107+ env = {}
11001108 if isinstance (pattern , Int ):
1101- return {} if isinstance (obj , Int ) and obj .value == pattern .value else None
1109+ return {} if isinstance (obj , Int ) and obj .value == pattern .value and match_guard ( env , guard ) else None
11021110 if isinstance (pattern , Float ):
11031111 raise MatchError ("pattern matching is not supported for Floats" )
11041112 if isinstance (pattern , String ):
1105- return {} if isinstance (obj , String ) and obj .value == pattern .value else None
1113+ return {} if isinstance (obj , String ) and obj .value == pattern .value and match_guard ( env , guard ) else None
11061114 if isinstance (pattern , Var ):
1107- return {pattern .name : obj }
1115+ env = {** env , pattern .name : obj }
1116+ return env if match_guard (env , guard ) else None
11081117 if isinstance (pattern , Symbol ):
1109- return {} if isinstance (obj , Symbol ) and obj .value == pattern .value else None
1118+ return {} if isinstance (obj , Symbol ) and obj .value == pattern .value and match_guard ( env , guard ) else None
11101119 if isinstance (pattern , Record ):
11111120 if not isinstance (obj , Record ):
11121121 return None
@@ -1126,7 +1135,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11261135 result .update (part )
11271136 if not use_spread and len (pattern .data ) != len (obj .data ):
11281137 return None
1129- return result
1138+ return result if match_guard ( result , guard ) else None
11301139 if isinstance (pattern , List ):
11311140 if not isinstance (obj , List ):
11321141 return None
@@ -1149,7 +1158,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11491158 result .update (part )
11501159 if not use_spread and len (pattern .items ) != len (obj .items ):
11511160 return None
1152- return result
1161+ return result if match_guard ( result , guard ) else None
11531162 raise NotImplementedError (f"match not implemented for { type (pattern ).__name__ } " )
11541163
11551164
@@ -1273,7 +1282,7 @@ def eval_exp(env: Env, exp: Object) -> Object:
12731282 elif isinstance (callee .func , MatchFunction ):
12741283 arg = eval_exp (env , exp .arg )
12751284 for case in callee .func .cases :
1276- m = match (arg , case .pattern )
1285+ m = match (arg , case .pattern , env , case . guard )
12771286 if m is None :
12781287 continue
12791288 return eval_exp ({** callee .env , ** m }, case .body )
@@ -3256,6 +3265,98 @@ def test_match_var_binds_var(self) -> None:
32563265 Int (3 ),
32573266 )
32583267
3268+ def test_match_guard_closure_var (self ) -> None :
3269+ self .assertEqual (
3270+ self ._run (
3271+ """
3272+ id 1
3273+ . id =
3274+ | x guard cond -> "one"
3275+ | x -> "idk"
3276+ . cond = 2
3277+ """
3278+ ),
3279+ String ("idk" ),
3280+ )
3281+
3282+ def test_match_record_guard_pass (self ) -> None :
3283+ self .assertEqual (
3284+ self ._run (
3285+ """
3286+ id {cond=#true}
3287+ . id =
3288+ | {cond=cond} guard cond -> "yes"
3289+ | x -> "no"
3290+ """
3291+ ),
3292+ String ("yes" ),
3293+ )
3294+
3295+ def test_match_record_guard_fail (self ) -> None :
3296+ self .assertEqual (
3297+ self ._run (
3298+ """
3299+ id {cond=#false}
3300+ . id =
3301+ | {cond=cond} guard cond -> "yes"
3302+ | x -> "no"
3303+ """
3304+ ),
3305+ String ("no" ),
3306+ )
3307+
3308+ def test_match_list_guard_pass (self ) -> None :
3309+ self .assertEqual (
3310+ self ._run (
3311+ """
3312+ id [#true]
3313+ . id =
3314+ | [cond] guard cond -> "yes"
3315+ | x -> "no"
3316+ """
3317+ ),
3318+ String ("yes" ),
3319+ )
3320+
3321+ def test_match_list_guard_fail (self ) -> None :
3322+ self .assertEqual (
3323+ self ._run (
3324+ """
3325+ id [#false]
3326+ . id =
3327+ | [cond] guard cond -> "yes"
3328+ | x -> "no"
3329+ """
3330+ ),
3331+ String ("no" ),
3332+ )
3333+
3334+ def test_match_guard_pass (self ) -> None :
3335+ self .assertEqual (
3336+ self ._run (
3337+ """
3338+ id 1
3339+ . id =
3340+ | x guard x==1 -> "one"
3341+ | x -> "idk"
3342+ """
3343+ ),
3344+ String ("one" ),
3345+ )
3346+
3347+ def test_match_guard_fail (self ) -> None :
3348+ self .assertEqual (
3349+ self ._run (
3350+ """
3351+ id 2
3352+ . id =
3353+ | x guard x==1 -> "one"
3354+ | x -> "idk"
3355+ """
3356+ ),
3357+ String ("idk" ),
3358+ )
3359+
32593360 def test_match_var_binds_first_arm (self ) -> None :
32603361 self .assertEqual (
32613362 self ._run (
0 commit comments