@@ -1096,17 +1096,26 @@ class MatchError(Exception):
1096
1096
pass
1097
1097
1098
1098
1099
- def match (obj : Object , pattern : Object ) -> Optional [Env ]:
1099
+ def match_guard (env : Env , guard : Optional [Object ]) -> bool :
1100
+ if guard is None :
1101
+ return True
1102
+ return eval_exp (env , guard ) == Symbol ("true" )
1103
+
1104
+
1105
+ def match (obj : Object , pattern : Object , env : Optional [Env ] = None , guard : Optional [Object ] = None ) -> Optional [Env ]:
1106
+ if env is None :
1107
+ env = {}
1100
1108
if isinstance (pattern , Int ):
1101
- return {} if isinstance (obj , Int ) and obj .value == pattern .value else None
1109
+ return {} if isinstance (obj , Int ) and obj .value == pattern .value and match_guard ( env , guard ) else None
1102
1110
if isinstance (pattern , Float ):
1103
1111
raise MatchError ("pattern matching is not supported for Floats" )
1104
1112
if isinstance (pattern , String ):
1105
- return {} if isinstance (obj , String ) and obj .value == pattern .value else None
1113
+ return {} if isinstance (obj , String ) and obj .value == pattern .value and match_guard ( env , guard ) else None
1106
1114
if isinstance (pattern , Var ):
1107
- return {pattern .name : obj }
1115
+ env = {** env , pattern .name : obj }
1116
+ return env if match_guard (env , guard ) else None
1108
1117
if isinstance (pattern , Symbol ):
1109
- return {} if isinstance (obj , Symbol ) and obj .value == pattern .value else None
1118
+ return {} if isinstance (obj , Symbol ) and obj .value == pattern .value and match_guard ( env , guard ) else None
1110
1119
if isinstance (pattern , Record ):
1111
1120
if not isinstance (obj , Record ):
1112
1121
return None
@@ -1126,7 +1135,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
1126
1135
result .update (part )
1127
1136
if not use_spread and len (pattern .data ) != len (obj .data ):
1128
1137
return None
1129
- return result
1138
+ return result if match_guard ( result , guard ) else None
1130
1139
if isinstance (pattern , List ):
1131
1140
if not isinstance (obj , List ):
1132
1141
return None
@@ -1149,7 +1158,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
1149
1158
result .update (part )
1150
1159
if not use_spread and len (pattern .items ) != len (obj .items ):
1151
1160
return None
1152
- return result
1161
+ return result if match_guard ( result , guard ) else None
1153
1162
raise NotImplementedError (f"match not implemented for { type (pattern ).__name__ } " )
1154
1163
1155
1164
@@ -1273,7 +1282,7 @@ def eval_exp(env: Env, exp: Object) -> Object:
1273
1282
elif isinstance (callee .func , MatchFunction ):
1274
1283
arg = eval_exp (env , exp .arg )
1275
1284
for case in callee .func .cases :
1276
- m = match (arg , case .pattern )
1285
+ m = match (arg , case .pattern , env , case . guard )
1277
1286
if m is None :
1278
1287
continue
1279
1288
return eval_exp ({** callee .env , ** m }, case .body )
@@ -3256,6 +3265,98 @@ def test_match_var_binds_var(self) -> None:
3256
3265
Int (3 ),
3257
3266
)
3258
3267
3268
+ def test_match_guard_closure_var (self ) -> None :
3269
+ self .assertEqual (
3270
+ self ._run (
3271
+ """
3272
+ id 1
3273
+ . id =
3274
+ | x guard cond -> "one"
3275
+ | x -> "idk"
3276
+ . cond = 2
3277
+ """
3278
+ ),
3279
+ String ("idk" ),
3280
+ )
3281
+
3282
+ def test_match_record_guard_pass (self ) -> None :
3283
+ self .assertEqual (
3284
+ self ._run (
3285
+ """
3286
+ id {cond=#true}
3287
+ . id =
3288
+ | {cond=cond} guard cond -> "yes"
3289
+ | x -> "no"
3290
+ """
3291
+ ),
3292
+ String ("yes" ),
3293
+ )
3294
+
3295
+ def test_match_record_guard_fail (self ) -> None :
3296
+ self .assertEqual (
3297
+ self ._run (
3298
+ """
3299
+ id {cond=#false}
3300
+ . id =
3301
+ | {cond=cond} guard cond -> "yes"
3302
+ | x -> "no"
3303
+ """
3304
+ ),
3305
+ String ("no" ),
3306
+ )
3307
+
3308
+ def test_match_list_guard_pass (self ) -> None :
3309
+ self .assertEqual (
3310
+ self ._run (
3311
+ """
3312
+ id [#true]
3313
+ . id =
3314
+ | [cond] guard cond -> "yes"
3315
+ | x -> "no"
3316
+ """
3317
+ ),
3318
+ String ("yes" ),
3319
+ )
3320
+
3321
+ def test_match_list_guard_fail (self ) -> None :
3322
+ self .assertEqual (
3323
+ self ._run (
3324
+ """
3325
+ id [#false]
3326
+ . id =
3327
+ | [cond] guard cond -> "yes"
3328
+ | x -> "no"
3329
+ """
3330
+ ),
3331
+ String ("no" ),
3332
+ )
3333
+
3334
+ def test_match_guard_pass (self ) -> None :
3335
+ self .assertEqual (
3336
+ self ._run (
3337
+ """
3338
+ id 1
3339
+ . id =
3340
+ | x guard x==1 -> "one"
3341
+ | x -> "idk"
3342
+ """
3343
+ ),
3344
+ String ("one" ),
3345
+ )
3346
+
3347
+ def test_match_guard_fail (self ) -> None :
3348
+ self .assertEqual (
3349
+ self ._run (
3350
+ """
3351
+ id 2
3352
+ . id =
3353
+ | x guard x==1 -> "one"
3354
+ | x -> "idk"
3355
+ """
3356
+ ),
3357
+ String ("idk" ),
3358
+ )
3359
+
3259
3360
def test_match_var_binds_first_arm (self ) -> None :
3260
3361
self .assertEqual (
3261
3362
self ._run (
0 commit comments