@@ -1093,17 +1093,26 @@ class MatchError(Exception):
10931093 pass
10941094
10951095
1096- def match (obj : Object , pattern : Object ) -> Optional [Env ]:
1096+ def match_guard (env : Env , guard : Optional [Object ]) -> bool :
1097+ if guard is None :
1098+ return True
1099+ return eval_exp (env , guard ) == Symbol ("true" )
1100+
1101+
1102+ def match (obj : Object , pattern : Object , env : Optional [Env ] = None , guard : Optional [Object ] = None ) -> Optional [Env ]:
1103+ if env is None :
1104+ env = {}
10971105 if isinstance (pattern , Int ):
1098- return {} if isinstance (obj , Int ) and obj .value == pattern .value else None
1106+ return {} if isinstance (obj , Int ) and obj .value == pattern .value and match_guard ( env , guard ) else None
10991107 if isinstance (pattern , Float ):
11001108 raise MatchError ("pattern matching is not supported for Floats" )
11011109 if isinstance (pattern , String ):
1102- return {} if isinstance (obj , String ) and obj .value == pattern .value else None
1110+ return {} if isinstance (obj , String ) and obj .value == pattern .value and match_guard ( env , guard ) else None
11031111 if isinstance (pattern , Var ):
1104- return {pattern .name : obj }
1112+ env = {** env , pattern .name : obj }
1113+ return env if match_guard (env , guard ) else None
11051114 if isinstance (pattern , Symbol ):
1106- return {} if isinstance (obj , Symbol ) and obj .value == pattern .value else None
1115+ return {} if isinstance (obj , Symbol ) and obj .value == pattern .value and match_guard ( env , guard ) else None
11071116 if isinstance (pattern , Record ):
11081117 if not isinstance (obj , Record ):
11091118 return None
@@ -1123,7 +1132,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11231132 result .update (part )
11241133 if not use_spread and len (pattern .data ) != len (obj .data ):
11251134 return None
1126- return result
1135+ return result if match_guard ( result , guard ) else None
11271136 if isinstance (pattern , List ):
11281137 if not isinstance (obj , List ):
11291138 return None
@@ -1146,7 +1155,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11461155 result .update (part )
11471156 if not use_spread and len (pattern .items ) != len (obj .items ):
11481157 return None
1149- return result
1158+ return result if match_guard ( result , guard ) else None
11501159 raise NotImplementedError (f"match not implemented for { type (pattern ).__name__ } " )
11511160
11521161
@@ -1272,7 +1281,7 @@ def eval_exp(env: Env, exp: Object) -> Object:
12721281 elif isinstance (callee .func , MatchFunction ):
12731282 arg = eval_exp (env , exp .arg )
12741283 for case in callee .func .cases :
1275- m = match (arg , case .pattern )
1284+ m = match (arg , case .pattern , env , case . guard )
12761285 if m is None :
12771286 continue
12781287 return eval_exp ({** callee .env , ** m }, case .body )
@@ -3257,6 +3266,98 @@ def test_match_var_binds_var(self) -> None:
32573266 Int (3 ),
32583267 )
32593268
3269+ def test_match_guard_closure_var (self ) -> None :
3270+ self .assertEqual (
3271+ self ._run (
3272+ """
3273+ id 1
3274+ . id =
3275+ | x guard cond -> "one"
3276+ | x -> "idk"
3277+ . cond = 2
3278+ """
3279+ ),
3280+ String ("idk" ),
3281+ )
3282+
3283+ def test_match_record_guard_pass (self ) -> None :
3284+ self .assertEqual (
3285+ self ._run (
3286+ """
3287+ id {cond=#true}
3288+ . id =
3289+ | {cond=cond} guard cond -> "yes"
3290+ | x -> "no"
3291+ """
3292+ ),
3293+ String ("yes" ),
3294+ )
3295+
3296+ def test_match_record_guard_fail (self ) -> None :
3297+ self .assertEqual (
3298+ self ._run (
3299+ """
3300+ id {cond=#false}
3301+ . id =
3302+ | {cond=cond} guard cond -> "yes"
3303+ | x -> "no"
3304+ """
3305+ ),
3306+ String ("no" ),
3307+ )
3308+
3309+ def test_match_list_guard_pass (self ) -> None :
3310+ self .assertEqual (
3311+ self ._run (
3312+ """
3313+ id [#true]
3314+ . id =
3315+ | [cond] guard cond -> "yes"
3316+ | x -> "no"
3317+ """
3318+ ),
3319+ String ("yes" ),
3320+ )
3321+
3322+ def test_match_list_guard_fail (self ) -> None :
3323+ self .assertEqual (
3324+ self ._run (
3325+ """
3326+ id [#false]
3327+ . id =
3328+ | [cond] guard cond -> "yes"
3329+ | x -> "no"
3330+ """
3331+ ),
3332+ String ("no" ),
3333+ )
3334+
3335+ def test_match_guard_pass (self ) -> None :
3336+ self .assertEqual (
3337+ self ._run (
3338+ """
3339+ id 1
3340+ . id =
3341+ | x guard x==1 -> "one"
3342+ | x -> "idk"
3343+ """
3344+ ),
3345+ String ("one" ),
3346+ )
3347+
3348+ def test_match_guard_fail (self ) -> None :
3349+ self .assertEqual (
3350+ self ._run (
3351+ """
3352+ id 2
3353+ . id =
3354+ | x guard x==1 -> "one"
3355+ | x -> "idk"
3356+ """
3357+ ),
3358+ String ("idk" ),
3359+ )
3360+
32603361 def test_match_var_binds_first_arm (self ) -> None :
32613362 self .assertEqual (
32623363 self ._run (
0 commit comments