Skip to content

Commit 849e403

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Generalize tuple pattern matching to sequences
In addition to enabling pattern matching for lists and other sequences, this enables using pattern combinators for strings to do regex-like matching: ``` pattern = ['a', 'b', Star('c'), 'd'] assert is_match(pattern, 'abcccccd') ``` PiperOrigin-RevId: 377157879
1 parent c716e6f commit 849e403

File tree

2 files changed

+105
-45
lines changed

2 files changed

+105
-45
lines changed

spinoffs/oryx/oryx/experimental/matching/matcher.py

Lines changed: 77 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
match(1, 2) # ==> MatchError!
5959
```
6060
61-
A `tuple` pattern matches a `tuple` expression if each of its elements matches.
61+
A `Sequence` pattern matches a `Sequence` expression if each of its elements
62+
matches.
6263
Similarly, a `dict` pattern matches a `dict` expression if their keys are
6364
the same and the corresponding values match.
6465
@@ -127,10 +128,11 @@
127128
128129
## `Star`, `Plus`, and `Segment`
129130
130-
The `Star`, `Plus` and `Segment` patterns are used when matching tuples.
131-
Specifically, they can match variable-length slices of a tuple. `Star` takes
132-
in a pattern and matches a tuple slice if all of its elements match the pattern.
133-
`Star` must be used inside of a tuple.
131+
The `Star`, `Plus` and `Segment` patterns are used when matching sequences.
132+
Specifically, they can match variable-length slices of a sequence. `Star` takes
133+
in a pattern and matches a sequence slice if all of its elements match the
134+
pattern.
135+
`Star` must be used inside of a sequence.
134136
135137
```python
136138
match((Star(1),), ()) # ==> {}
@@ -161,7 +163,8 @@
161163
If the pattern inside of a `Star` has any names in it (from `Var`s or
162164
nested `Star`s), the same value needs to match for the entire slice. If name
163165
`'x'` needs to be bound to `1` to make an element of a slice match, it needs to
164-
be `1` for every subsequent element of the tuple. For example, in the following
166+
be `1` for every subsequent element of the sequence. For example, in the
167+
following
165168
snippets, `x` cannot be bound to multiple values to make the match succeed.
166169
167170
```python
@@ -176,7 +179,7 @@
176179
this means that the value bound to its name must match across the slice.
177180
When we accumulate for a name inside of a `Star`, rather than enforcing that the
178181
match is the same across the slice, we collect the individual matches into a
179-
tuple and bind the name to the tuple.
182+
sequence and bind the name to the sequence.
180183
```python
181184
match((Star(Var('x'), accumulate=['x']),), (1, 2, 3)) # ==> {'x': (1, 2, 3)}
182185
match((Star((Var('x'), Var('y')), accumulate=['y']),), ((1, 2), (1, 3)))
@@ -196,7 +199,7 @@
196199
```
197200
198201
A `Segment` is shorthand for a named `Star` that has the `Dot` pattern,
199-
meaning it matches slices of a tuple regardless of the individual values.
202+
meaning it matches slices of a sequence regardless of the individual values.
200203
Specifically, `Segment(name)` matches the same expressions as
201204
`Star(Dot, name=name)`.
202205
```python
@@ -220,7 +223,7 @@
220223
expression are equal according to Python equality. To extend it for a custom
221224
type `Foo` that we'd like to match against, we can call
222225
`matcher.register(Foo)(custom_matcher)` to define a custom matcher for `Foo`
223-
objects. This is how we define matchers for tuples and dictionaries.
226+
objects. This is how we define matchers for sequences and dictionaries.
224227
225228
2. Alternatively, we provide the `Pattern` class which is the parent class
226229
of the various combinators (`Choice`, `Not`, `Star`, etc.). By subclassing
@@ -229,7 +232,7 @@
229232
"""
230233
import functools
231234

232-
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, TypeVar
235+
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, TypeVar
233236

234237
import dataclasses
235238

@@ -254,8 +257,8 @@
254257
Bindings = Dict[str, Expr]
255258
Success = Iterator[T]
256259
Continuation = Callable[[Bindings], Success]
257-
TupleSegment = Tuple[Any]
258-
StarContinuation = Callable[[Bindings, TupleSegment], Success]
260+
SequenceSegment = Sequence[Any]
261+
StarContinuation = Callable[[Bindings, SequenceSegment], Success]
259262
Matcher = Callable[[Expr, Bindings, Continuation], Success]
260263

261264
id_success = lambda x: (yield x)
@@ -297,7 +300,7 @@ class MatchError(Exception):
297300
def is_match(pattern: Any, expr: Expr) -> bool:
298301
"""Returns whether or not an expression matches a pattern."""
299302
if isinstance(pattern, Star):
300-
raise ValueError('`Star` pattern must be inside of a tuple.')
303+
raise ValueError('`Star` pattern must be inside of a sequence.')
301304
for _ in matcher(pattern)(expr, {}, id_success):
302305
return True
303306
return False
@@ -306,7 +309,7 @@ def is_match(pattern: Any, expr: Expr) -> bool:
306309
def match(pattern: Any, expr: Expr) -> Bindings:
307310
"""Returns a single match for pattern and expression or errors otherwise."""
308311
if isinstance(pattern, Star):
309-
raise ValueError('`Star` pattern must be inside of a tuple.')
312+
raise ValueError('`Star` pattern must be inside of a sequence.')
310313
for bindings in matcher(pattern)(expr, {}, id_success):
311314
return bindings
312315
raise MatchError(f'No match found. Pattern: {pattern}, Expression: {expr}')
@@ -315,7 +318,7 @@ def match(pattern: Any, expr: Expr) -> Bindings:
315318
def match_all(pattern: Any, expr: Expr) -> Iterator[Bindings]:
316319
"""Returns an iterator over all bindings matching a pattern to an expression."""
317320
if isinstance(pattern, Star):
318-
raise ValueError('`Star` pattern must be inside of a tuple.')
321+
raise ValueError('`Star` pattern must be inside of a sequence.')
319322
yield from matcher(pattern)(expr, {}, id_success)
320323

321324

@@ -422,14 +425,14 @@ def __str__(self):
422425

423426
@dataclasses.dataclass(frozen=True)
424427
class Star(Pattern):
425-
"""A pattern for repeated sub-patterns inside of a tuple.
428+
"""A pattern for repeated sub-patterns inside of a sequence.
426429
427430
Attributes:
428-
pattern: an object that will be matched against elements of a tuple.
431+
pattern: an object that will be matched against elements of a sequence.
429432
name: an optional `str` name to bind the result of the star match.
430433
accumulate: a sequence of `str` names corresponding to `Var`s in `pattern`
431-
that will be accumulated into a tuple instead of having to match across
432-
the elements of the tuple.
434+
that will be accumulated into a sequence instead of having to match across
435+
the elements of the sequence.
433436
greedy: a `bool` that sets whether or not the `Star` greedily matches a
434437
sequence. A greedy `Star` will try to match slices starting from the
435438
largest possible and then trying smaller ones. A non-greedy `Star` will
@@ -458,7 +461,7 @@ def accumulate_value(self, bindings: Bindings, name: str,
458461

459462
def accumulate_match(self, expr: Expr, bindings: Bindings,
460463
succeed: Continuation) -> Success:
461-
"""Matches each element of a tuple to this `Star`'s pattern.
464+
"""Matches each element of a sequence to this `Star`'s pattern.
462465
463466
Iteratively matches each element of `expr` with `self.pattern`. For any
464467
created as the result of each match, they are accumulated if the names
@@ -498,7 +501,8 @@ def match(self, expr: Expr, bindings: Bindings,
498501
"""Matches the `Star` pattern against an expression.
499502
500503
Constructs all splits of the expression and performs an `accumulate_match`
501-
on each of the left halves. The right half is matched using tuple matching.
504+
on each of the left halves. The right half is matched using sequence
505+
matching.
502506
503507
Args:
504508
expr: An expression to match.
@@ -511,7 +515,7 @@ def match(self, expr: Expr, bindings: Bindings,
511515
The results of the `succeed` continuation function, augmented with
512516
bindings corresponding to matches made over the course of the Star match.
513517
"""
514-
if not isinstance(expr, tuple):
518+
if not isinstance(expr, Sequence):
515519
return
516520
# If name appears in bindings, we have already matched and need to verify
517521
# that bound value matches the current expression
@@ -545,14 +549,14 @@ def __str__(self):
545549

546550

547551
class Plus(Star):
548-
"""A pattern for repeated sub-patterns inside of a tuple.
552+
"""A pattern for repeated sub-patterns inside of a sequence.
549553
550554
Attributes:
551-
pattern: an object that will be matched against elements of a tuple.
555+
pattern: an object that will be matched against elements of a sequence.
552556
name: an optional `str` name to bind the result of the star match.
553557
accumulate: a sequence of `str` names corresponding to `Var`s in `pattern`
554-
that will be accumulated into a tuple instead of having to match across
555-
the elements of the tuple.
558+
that will be accumulated into a sequence instead of having to match across
559+
the elements of the sequence.
556560
greedy: a `bool` that sets whether or not the `Plus` greedily matches a
557561
sequence. A greedy `Plus` will try to match slices starting from the
558562
largest possible and then trying smaller ones. A non-greedy `Plus` will
@@ -570,14 +574,14 @@ def __init__(self,
570574

571575

572576
class Segment(Star):
573-
"""Matches any slice of a tuple.
577+
"""Matches any slice of a sequence.
574578
575579
Attributes:
576580
name: a `str` name to bind the result of the segment match. If `name` is
577581
`None`, a match produces no binding.
578582
accumulate: a sequence of `str` names corresponding to `Var`s in `pattern`
579-
that will be accumulated into a tuple instead of having to match across
580-
the elements of the tuple.
583+
that will be accumulated into a sequence instead of having to match across
584+
the elements of the sequence.
581585
greedy: a `bool` that sets whether or not the `Segment` greedily matches a
582586
sequence. A greedy `Segment` will try to match slices starting from the
583587
largest possible and then trying smaller ones. A non-greedy `Segment` will
@@ -597,39 +601,40 @@ def __init__(self,
597601
Dot, name=name, accumulate=accumulate, plus=plus, greedy=greedy)
598602

599603

600-
@matcher.register(tuple)
601-
def tuple_matcher(pattern: Tuple[Any]):
602-
"""Returns a matcher for a given tuple pattern."""
604+
@matcher.register(Sequence)
605+
def sequence_matcher(pattern: Sequence[Any]):
606+
"""Returns a matcher for a given sequence pattern."""
603607

604-
def tuple_match(expr: Expr, bindings: Bindings,
605-
succeed: Continuation) -> Success:
606-
"""Matches a tuple expression against a tuple pattern.
608+
def sequence_match(expr: Expr, bindings: Bindings,
609+
succeed: Continuation) -> Success:
610+
"""Matches a sequence expression against a sequence pattern.
607611
608-
Matches each element of the tuple pattern against each element of the
609-
tuple expression. When there is a `Star` in the tuple pattern, the tuple
610-
matcher calls the `Star` pattern's matcher, which calls a special success
611-
function that takes in the remaining part of the tuple to match.
612+
Matches each element of the sequence pattern against each element of the
613+
sequence expression. When there is a `Star` in the sequence pattern, the
614+
sequence matcher calls the `Star` pattern's matcher, which calls a special
615+
success function that takes in the remaining part of the sequence to match.
612616
613617
Args:
614-
expr: An expression to match.
618+
expr: A sequence to match.
615619
bindings: A dictionary mapping string names to values representing the
616620
results of previous matches.
617621
succeed: A function that when passed in `bindings` returns a generator
618622
over results.
619623
620624
Yields:
621625
The results of the `succeed` continuation function, augmented with
622-
bindings corresponding to matches made over the course of the tuple match.
626+
bindings corresponding to matches made over the course of the sequence
627+
match.
623628
"""
624-
if not isinstance(expr, tuple):
629+
if not isinstance(expr, Sequence):
625630
return
626631
# Special case Star here
627632
if pattern and isinstance(pattern[0], Star):
628633
star_pattern, rest_pattern = pattern[0], pattern[1:]
629634

630635
# A star continuation takes in an additional `remaining` argument that
631-
# contains the remaining slice of the tuple to be matched.
632-
def star_succeed(bindings: Bindings, remaining: Tuple[Any]) -> Success:
636+
# contains the remaining slice of the sequence to be matched.
637+
def star_succeed(bindings: Bindings, remaining: Sequence[Any]) -> Success:
633638
post_star_match = matcher(rest_pattern)
634639
yield from post_star_match(remaining, bindings, succeed)
635640

@@ -654,7 +659,34 @@ def rest_succeed(bindings):
654659
first_match = matcher(pattern[0])
655660
yield from first_match(expr[0], bindings, rest_succeed)
656661

657-
return tuple_match
662+
return sequence_match
663+
664+
665+
@matcher.register(str)
666+
def str_matcher(pattern: Any):
667+
"""Overrides default sequence matcher for strings to avoid infinite recursion.
668+
669+
Strings are a tricky case of sequence because indexing into a string returns
670+
a length-1 string. This, by default, triggers an infinite recursion in the
671+
sequence matcher. To avoid this, we special-case 1-length strings to do a
672+
manual match and use the sequence matcher for other strings.
673+
674+
Args:
675+
pattern: A pattern used to match a string.
676+
Returns:
677+
A pattern matcher for string expressions.
678+
"""
679+
def str_match(expr: Expr,
680+
bindings: Bindings,
681+
succeed: Continuation) -> Success:
682+
if not isinstance(expr, str):
683+
return
684+
if len(expr) == 1 and isinstance(pattern, str):
685+
if pattern == expr:
686+
yield from succeed(bindings)
687+
return
688+
yield from sequence_matcher(pattern)(expr, bindings, succeed)
689+
return str_match
658690

659691

660692
@matcher.register(dict)

spinoffs/oryx/oryx/experimental/matching/matcher_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ def test_match_all_errors_with_star_pattern(self):
5252
with self.assertRaises(ValueError):
5353
list(matcher.match_all(matcher.Star(1), 1.))
5454

55+
def test_list_patterns_match_equal_lists(self):
56+
self.assertDictEqual(matcher.match([1, 2, 3], [1, 2, 3]), {})
57+
self.assertDictEqual(matcher.match([(1, 2), 2, 3], [(1, 2), 2, 3]), {})
58+
5559
def test_tuple_patterns_match_equal_tuples(self):
5660
self.assertDictEqual(matcher.match((1, 2, 3), (1, 2, 3)), {})
5761
self.assertDictEqual(matcher.match(((1, 2), 2, 3), ((1, 2), 2, 3)), {})
@@ -225,6 +229,30 @@ def test_segment_must_be_the_same_when_given_same_name(self):
225229
for i in range(len(matches)):
226230
self.assertDictEqual(matches[i], dict(x=(1,) * i, y=(1,) * (10 - 2 * i)))
227231

232+
def test_can_match_string_literals(self):
233+
pattern = 'abcd'
234+
self.assertDictEqual(matcher.match(pattern, 'abcd'), {})
235+
with self.assertRaises(matcher.MatchError):
236+
matcher.match(pattern, 'dcba')
237+
238+
def test_can_match_var_with_length_one_string(self):
239+
pattern = matcher.Var('x')
240+
self.assertDictEqual(matcher.match(pattern, 'a'), {'x': 'a'})
241+
242+
def test_can_use_star_patterns_in_string_patterns(self):
243+
pattern = ['a', 'b', matcher.Segment('rest')]
244+
self.assertDictEqual(matcher.match(pattern, 'abcd'), {'rest': 'cd'})
245+
246+
with self.assertRaises(matcher.MatchError):
247+
matcher.match(pattern, 'acd')
248+
249+
pattern = ['a', 'b', matcher.Star('c'), 'd']
250+
self.assertDictEqual(matcher.match(pattern, 'abccccd'), {})
251+
self.assertDictEqual(matcher.match(pattern, 'abd'), {})
252+
253+
with self.assertRaises(matcher.MatchError):
254+
matcher.match(pattern, 'abccc')
255+
228256

229257
if __name__ == '__main__':
230258
absltest.main()

0 commit comments

Comments
 (0)