Skip to content

Commit 69b5fa1

Browse files
committed
Support record pattern matching
1 parent 17cb498 commit 69b5fa1

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

ir.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ def to_string(self, gvn: InstrId) -> str:
8787
return f"{type(self).__name__}<{self.value}>"
8888

8989

90+
@dataclasses.dataclass(eq=False)
91+
class CConst(Instr):
92+
type: str
93+
value: str
94+
95+
def to_string(self, gvn: InstrId) -> str:
96+
return f"{type(self).__name__}<{self.type}; {self.value}>"
97+
98+
9099
@dataclasses.dataclass(eq=False)
91100
class Param(Instr):
92101
idx: int
@@ -130,6 +139,12 @@ class IntLess(HasOperands):
130139
pass
131140

132141

142+
# TODO(max): Maybe start work on boxing/unboxing in the IR.
143+
@dataclasses.dataclass(init=False, eq=False)
144+
class CEqual(HasOperands):
145+
pass
146+
147+
133148
@dataclasses.dataclass(init=False, eq=False)
134149
class RefineType(HasOperands):
135150
def __init__(self, value: Instr, ty: ConstantLattice) -> None:
@@ -230,6 +245,11 @@ class NewRecord(Instr):
230245
num_fields: int
231246

232247

248+
@dataclasses.dataclass(init=False, eq=False)
249+
class IsRecord(HasOperands):
250+
pass
251+
252+
233253
@dataclasses.dataclass(eq=False)
234254
class RecordSet(HasOperands):
235255
idx: int
@@ -258,6 +278,11 @@ def to_string(self, gvn: InstrId) -> str:
258278
return stem + ", ".join(f"{gvn.name(op)}" for op in self.operands)
259279

260280

281+
@dataclasses.dataclass(init=False, eq=False)
282+
class RecordNumFields(HasOperands):
283+
pass
284+
285+
261286
Env = Dict[str, Instr]
262287

263288

@@ -484,6 +509,14 @@ def op(idx: int) -> str:
484509
return _decl("bool", f"is_list({op(0)})")
485510
if isinstance(instr, IsEmptyList):
486511
return _decl("bool", f"{op(0)} == empty_list()")
512+
if isinstance(instr, IsRecord):
513+
return _decl("bool", f"is_record({op(0)})")
514+
if isinstance(instr, RecordNumFields):
515+
return _decl("uword", f"record_num_fields({op(0)})")
516+
if isinstance(instr, CConst):
517+
return _decl(instr.type, instr.value)
518+
if isinstance(instr, CEqual):
519+
return _decl("bool", f"{op(0)} == {op(1)}")
487520
if isinstance(instr, Return):
488521
return f"return {op(0)};\n"
489522
if isinstance(instr, Jump):
@@ -575,6 +608,35 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success
575608
is_empty = self.emit(IsEmptyList(the_list))
576609
self.emit(CondBranch(is_empty, success, fallthrough))
577610
return updates
611+
if isinstance(pattern, Record):
612+
is_record = self.emit(IsRecord(param))
613+
updates = {}
614+
is_record_block = self.fn.cfg.new_block()
615+
self.emit(CondBranch(is_record, is_record_block, fallthrough))
616+
self.block = is_record_block
617+
for key, pattern_value in pattern.data.items():
618+
if isinstance(pattern_value, Spread):
619+
if pattern_value.name:
620+
raise NotImplementedError("named record spread not yet supported")
621+
self.emit(Jump(success))
622+
return updates
623+
key_idx = self.record_key(key)
624+
record_value = self.emit(RecordGet(param, key_idx))
625+
is_null = self.emit(CEqual(record_value, self.emit(CConst("struct object*", "NULL"))))
626+
recursive_block = self.fn.cfg.new_block()
627+
self.emit(CondBranch(is_null, fallthrough, recursive_block))
628+
self.block = recursive_block
629+
pattern_success = self.fn.cfg.new_block()
630+
# Recursive pattern match
631+
updates.update(
632+
self.compile_match_pattern(env, record_value, pattern_value, pattern_success, fallthrough)
633+
)
634+
self.block = pattern_success
635+
# Too many fields
636+
num_fields = self.emit(RecordNumFields(param))
637+
cmp = self.emit(CEqual(num_fields, self.emit(CConst("uword", str(len(pattern.data))))))
638+
self.emit(CondBranch(cmp, success, fallthrough))
639+
return updates
578640
raise NotImplementedError(f"pattern {type(pattern)} {pattern}")
579641

580642
def compile_body(self, env: Env, exp: Object) -> None:
@@ -855,6 +917,14 @@ def run(self) -> dict[Instr, ConstantLattice]:
855917
new_type = CTop()
856918
elif isinstance(instr, ListRest):
857919
new_type = CTop()
920+
elif isinstance(instr, IsRecord):
921+
new_type = CTop()
922+
elif isinstance(instr, CConst):
923+
new_type = CTop()
924+
elif isinstance(instr, CEqual):
925+
new_type = CTop()
926+
elif isinstance(instr, RecordNumFields):
927+
new_type = CTop()
858928
else:
859929
raise NotImplementedError(f"SCCP {instr}")
860930
old_type = self.type_of(instr)
@@ -1488,6 +1558,194 @@ def test_match_list_spread(self) -> None:
14881558
}""",
14891559
)
14901560

1561+
def test_match_empty_record(self) -> None:
1562+
compiler = Compiler()
1563+
compiler.compile_body({}, _parse("| {} -> 1"))
1564+
self.assertEqual(
1565+
compiler.fns[1].to_string(InstrId()),
1566+
"""\
1567+
fn1 {
1568+
bb0 {
1569+
v0 = Param<0; $clo>
1570+
v1 = Param<1; arg_0>
1571+
Jump bb2
1572+
}
1573+
bb2 {
1574+
v2 = IsRecord v1
1575+
CondBranch v2, bb4, bb1
1576+
}
1577+
bb4 {
1578+
v3 = RecordNumFields v1
1579+
v4 = CConst<uword; 0>
1580+
v5 = CEqual v3, v4
1581+
CondBranch v5, bb3, bb1
1582+
}
1583+
bb1 {
1584+
MatchFail
1585+
}
1586+
bb3 {
1587+
v6 = Const<1>
1588+
Return v6
1589+
}
1590+
}""",
1591+
)
1592+
1593+
def test_match_one_item_record(self) -> None:
1594+
compiler = Compiler()
1595+
compiler.compile_body({}, _parse("| {a=1} -> 1"))
1596+
self.assertEqual(
1597+
compiler.fns[1].to_string(InstrId()),
1598+
"""\
1599+
fn1 {
1600+
bb0 {
1601+
v0 = Param<0; $clo>
1602+
v1 = Param<1; arg_0>
1603+
Jump bb2
1604+
}
1605+
bb2 {
1606+
v2 = IsRecord v1
1607+
CondBranch v2, bb4, bb1
1608+
}
1609+
bb4 {
1610+
v3 = RecordGet<Record_a> v1
1611+
v4 = CConst<struct object*; NULL>
1612+
v5 = CEqual v3, v4
1613+
CondBranch v5, bb1, bb5
1614+
}
1615+
bb5 {
1616+
v6 = IsIntEqualWord v3, 1
1617+
CondBranch v6, bb6, bb1
1618+
}
1619+
bb6 {
1620+
v7 = RecordNumFields v1
1621+
v8 = CConst<uword; 1>
1622+
v9 = CEqual v7, v8
1623+
CondBranch v9, bb3, bb1
1624+
}
1625+
bb3 {
1626+
v10 = Const<1>
1627+
Return v10
1628+
}
1629+
bb1 {
1630+
MatchFail
1631+
}
1632+
}""",
1633+
)
1634+
1635+
def test_match_two_item_record(self) -> None:
1636+
compiler = Compiler()
1637+
compiler.compile_body({}, _parse("| {a=1, b=2} -> 3"))
1638+
self.assertEqual(
1639+
compiler.fns[1].to_string(InstrId()),
1640+
"""\
1641+
fn1 {
1642+
bb0 {
1643+
v0 = Param<0; $clo>
1644+
v1 = Param<1; arg_0>
1645+
Jump bb2
1646+
}
1647+
bb2 {
1648+
v2 = IsRecord v1
1649+
CondBranch v2, bb4, bb1
1650+
}
1651+
bb4 {
1652+
v3 = RecordGet<Record_a> v1
1653+
v4 = CConst<struct object*; NULL>
1654+
v5 = CEqual v3, v4
1655+
CondBranch v5, bb1, bb5
1656+
}
1657+
bb5 {
1658+
v6 = IsIntEqualWord v3, 1
1659+
CondBranch v6, bb6, bb1
1660+
}
1661+
bb6 {
1662+
v7 = RecordGet<Record_b> v1
1663+
v8 = CConst<struct object*; NULL>
1664+
v9 = CEqual v7, v8
1665+
CondBranch v9, bb1, bb7
1666+
}
1667+
bb7 {
1668+
v10 = IsIntEqualWord v7, 2
1669+
CondBranch v10, bb8, bb1
1670+
}
1671+
bb8 {
1672+
v11 = RecordNumFields v1
1673+
v12 = CConst<uword; 2>
1674+
v13 = CEqual v11, v12
1675+
CondBranch v13, bb3, bb1
1676+
}
1677+
bb3 {
1678+
v14 = Const<3>
1679+
Return v14
1680+
}
1681+
bb1 {
1682+
MatchFail
1683+
}
1684+
}""",
1685+
)
1686+
1687+
def test_match_record_spread(self) -> None:
1688+
compiler = Compiler()
1689+
compiler.compile_body({}, _parse("| {a=a, ...} -> a"))
1690+
self.assertEqual(
1691+
compiler.fns[1].to_string(InstrId()),
1692+
"""\
1693+
fn1 {
1694+
bb0 {
1695+
v0 = Param<0; $clo>
1696+
v1 = Param<1; arg_0>
1697+
Jump bb2
1698+
}
1699+
bb2 {
1700+
v2 = IsRecord v1
1701+
CondBranch v2, bb4, bb1
1702+
}
1703+
bb4 {
1704+
v3 = RecordGet<Record_a> v1
1705+
v4 = CConst<struct object*; NULL>
1706+
v5 = CEqual v3, v4
1707+
CondBranch v5, bb1, bb5
1708+
}
1709+
bb5 {
1710+
Jump bb6
1711+
}
1712+
bb6 {
1713+
Jump bb3
1714+
}
1715+
bb3 {
1716+
Return v3
1717+
}
1718+
bb1 {
1719+
MatchFail
1720+
}
1721+
}""",
1722+
)
1723+
CleanCFG(compiler.fns[1]).run()
1724+
self.assertEqual(
1725+
compiler.fns[1].to_string(InstrId()),
1726+
"""\
1727+
fn1 {
1728+
bb0 {
1729+
v0 = Param<0; $clo>
1730+
v1 = Param<1; arg_0>
1731+
v2 = IsRecord v1
1732+
CondBranch v2, bb4, bb1
1733+
}
1734+
bb4 {
1735+
v3 = RecordGet<Record_a> v1
1736+
v4 = CConst<struct object*; NULL>
1737+
v5 = CEqual v3, v4
1738+
CondBranch v5, bb1, bb5
1739+
}
1740+
bb5 {
1741+
Return v3
1742+
}
1743+
bb1 {
1744+
MatchFail
1745+
}
1746+
}""",
1747+
)
1748+
14911749
def test_apply_fn(self) -> None:
14921750
compiler = Compiler()
14931751
compiler.compile_body({}, _parse("f 1 . f = x -> x + 1"))
@@ -2045,6 +2303,15 @@ def test_record_access(self) -> None:
20452303
def test_record_builder_access(self) -> None:
20462304
self.assertEqual(_run("(f 1 2)@a . f = x -> y -> {a = x, b = y}"), "1\n")
20472305

2306+
def test_match_record(self) -> None:
2307+
self.assertEqual(_run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n")
2308+
2309+
def test_match_record_too_few_keys(self) -> None:
2310+
self.assertEqual(_run("f {a = 4, b = 5} . f = | {a = _} -> 3 | {a = _, b = _} -> 6"), "6\n")
2311+
2312+
def test_match_record_spread(self) -> None:
2313+
self.assertEqual(_run("f {a=1, b=2, c=3} . f = | {a=a, ...} -> a"), "1\n")
2314+
20482315
def test_hole(self) -> None:
20492316
self.assertEqual(_run("()"), "()\n")
20502317

0 commit comments

Comments
 (0)