@@ -87,6 +87,15 @@ def to_string(self, gvn: InstrId) -> str:
8787 return f"{ type (self ).__name__ } <{ self .value } >"
8888
8989
90+ @dataclasses .dataclass (eq = False )
91+ class CConst (Instr ):
92+ type : str
93+ value : str
94+
95+ def to_string (self , gvn : InstrId ) -> str :
96+ return f"{ type (self ).__name__ } <{ self .type } ; { self .value } >"
97+
98+
9099@dataclasses .dataclass (eq = False )
91100class Param (Instr ):
92101 idx : int
@@ -130,6 +139,12 @@ class IntLess(HasOperands):
130139 pass
131140
132141
142+ # TODO(max): Maybe start work on boxing/unboxing in the IR.
143+ @dataclasses .dataclass (init = False , eq = False )
144+ class CEqual (HasOperands ):
145+ pass
146+
147+
133148@dataclasses .dataclass (init = False , eq = False )
134149class RefineType (HasOperands ):
135150 def __init__ (self , value : Instr , ty : ConstantLattice ) -> None :
@@ -230,6 +245,11 @@ class NewRecord(Instr):
230245 num_fields : int
231246
232247
248+ @dataclasses .dataclass (init = False , eq = False )
249+ class IsRecord (HasOperands ):
250+ pass
251+
252+
233253@dataclasses .dataclass (eq = False )
234254class RecordSet (HasOperands ):
235255 idx : int
@@ -258,6 +278,11 @@ def to_string(self, gvn: InstrId) -> str:
258278 return stem + ", " .join (f"{ gvn .name (op )} " for op in self .operands )
259279
260280
281+ @dataclasses .dataclass (init = False , eq = False )
282+ class RecordNumFields (HasOperands ):
283+ pass
284+
285+
261286Env = Dict [str , Instr ]
262287
263288
@@ -484,6 +509,14 @@ def op(idx: int) -> str:
484509 return _decl ("bool" , f"is_list({ op (0 )} )" )
485510 if isinstance (instr , IsEmptyList ):
486511 return _decl ("bool" , f"{ op (0 )} == empty_list()" )
512+ if isinstance (instr , IsRecord ):
513+ return _decl ("bool" , f"is_record({ op (0 )} )" )
514+ if isinstance (instr , RecordNumFields ):
515+ return _decl ("uword" , f"record_num_fields({ op (0 )} )" )
516+ if isinstance (instr , CConst ):
517+ return _decl (instr .type , instr .value )
518+ if isinstance (instr , CEqual ):
519+ return _decl ("bool" , f"{ op (0 )} == { op (1 )} " )
487520 if isinstance (instr , Return ):
488521 return f"return { op (0 )} ;\n "
489522 if isinstance (instr , Jump ):
@@ -575,6 +608,35 @@ def compile_match_pattern(self, env: Env, param: Instr, pattern: Object, success
575608 is_empty = self .emit (IsEmptyList (the_list ))
576609 self .emit (CondBranch (is_empty , success , fallthrough ))
577610 return updates
611+ if isinstance (pattern , Record ):
612+ is_record = self .emit (IsRecord (param ))
613+ updates = {}
614+ is_record_block = self .fn .cfg .new_block ()
615+ self .emit (CondBranch (is_record , is_record_block , fallthrough ))
616+ self .block = is_record_block
617+ for key , pattern_value in pattern .data .items ():
618+ if isinstance (pattern_value , Spread ):
619+ if pattern_value .name :
620+ raise NotImplementedError ("named record spread not yet supported" )
621+ self .emit (Jump (success ))
622+ return updates
623+ key_idx = self .record_key (key )
624+ record_value = self .emit (RecordGet (param , key_idx ))
625+ is_null = self .emit (CEqual (record_value , self .emit (CConst ("struct object*" , "NULL" ))))
626+ recursive_block = self .fn .cfg .new_block ()
627+ self .emit (CondBranch (is_null , fallthrough , recursive_block ))
628+ self .block = recursive_block
629+ pattern_success = self .fn .cfg .new_block ()
630+ # Recursive pattern match
631+ updates .update (
632+ self .compile_match_pattern (env , record_value , pattern_value , pattern_success , fallthrough )
633+ )
634+ self .block = pattern_success
635+ # Too many fields
636+ num_fields = self .emit (RecordNumFields (param ))
637+ cmp = self .emit (CEqual (num_fields , self .emit (CConst ("uword" , str (len (pattern .data ))))))
638+ self .emit (CondBranch (cmp , success , fallthrough ))
639+ return updates
578640 raise NotImplementedError (f"pattern { type (pattern )} { pattern } " )
579641
580642 def compile_body (self , env : Env , exp : Object ) -> None :
@@ -855,6 +917,14 @@ def run(self) -> dict[Instr, ConstantLattice]:
855917 new_type = CTop ()
856918 elif isinstance (instr , ListRest ):
857919 new_type = CTop ()
920+ elif isinstance (instr , IsRecord ):
921+ new_type = CTop ()
922+ elif isinstance (instr , CConst ):
923+ new_type = CTop ()
924+ elif isinstance (instr , CEqual ):
925+ new_type = CTop ()
926+ elif isinstance (instr , RecordNumFields ):
927+ new_type = CTop ()
858928 else :
859929 raise NotImplementedError (f"SCCP { instr } " )
860930 old_type = self .type_of (instr )
@@ -1488,6 +1558,194 @@ def test_match_list_spread(self) -> None:
14881558}""" ,
14891559 )
14901560
1561+ def test_match_empty_record (self ) -> None :
1562+ compiler = Compiler ()
1563+ compiler .compile_body ({}, _parse ("| {} -> 1" ))
1564+ self .assertEqual (
1565+ compiler .fns [1 ].to_string (InstrId ()),
1566+ """\
1567+ fn1 {
1568+ bb0 {
1569+ v0 = Param<0; $clo>
1570+ v1 = Param<1; arg_0>
1571+ Jump bb2
1572+ }
1573+ bb2 {
1574+ v2 = IsRecord v1
1575+ CondBranch v2, bb4, bb1
1576+ }
1577+ bb4 {
1578+ v3 = RecordNumFields v1
1579+ v4 = CConst<uword; 0>
1580+ v5 = CEqual v3, v4
1581+ CondBranch v5, bb3, bb1
1582+ }
1583+ bb1 {
1584+ MatchFail
1585+ }
1586+ bb3 {
1587+ v6 = Const<1>
1588+ Return v6
1589+ }
1590+ }""" ,
1591+ )
1592+
1593+ def test_match_one_item_record (self ) -> None :
1594+ compiler = Compiler ()
1595+ compiler .compile_body ({}, _parse ("| {a=1} -> 1" ))
1596+ self .assertEqual (
1597+ compiler .fns [1 ].to_string (InstrId ()),
1598+ """\
1599+ fn1 {
1600+ bb0 {
1601+ v0 = Param<0; $clo>
1602+ v1 = Param<1; arg_0>
1603+ Jump bb2
1604+ }
1605+ bb2 {
1606+ v2 = IsRecord v1
1607+ CondBranch v2, bb4, bb1
1608+ }
1609+ bb4 {
1610+ v3 = RecordGet<Record_a> v1
1611+ v4 = CConst<struct object*; NULL>
1612+ v5 = CEqual v3, v4
1613+ CondBranch v5, bb1, bb5
1614+ }
1615+ bb5 {
1616+ v6 = IsIntEqualWord v3, 1
1617+ CondBranch v6, bb6, bb1
1618+ }
1619+ bb6 {
1620+ v7 = RecordNumFields v1
1621+ v8 = CConst<uword; 1>
1622+ v9 = CEqual v7, v8
1623+ CondBranch v9, bb3, bb1
1624+ }
1625+ bb3 {
1626+ v10 = Const<1>
1627+ Return v10
1628+ }
1629+ bb1 {
1630+ MatchFail
1631+ }
1632+ }""" ,
1633+ )
1634+
1635+ def test_match_two_item_record (self ) -> None :
1636+ compiler = Compiler ()
1637+ compiler .compile_body ({}, _parse ("| {a=1, b=2} -> 3" ))
1638+ self .assertEqual (
1639+ compiler .fns [1 ].to_string (InstrId ()),
1640+ """\
1641+ fn1 {
1642+ bb0 {
1643+ v0 = Param<0; $clo>
1644+ v1 = Param<1; arg_0>
1645+ Jump bb2
1646+ }
1647+ bb2 {
1648+ v2 = IsRecord v1
1649+ CondBranch v2, bb4, bb1
1650+ }
1651+ bb4 {
1652+ v3 = RecordGet<Record_a> v1
1653+ v4 = CConst<struct object*; NULL>
1654+ v5 = CEqual v3, v4
1655+ CondBranch v5, bb1, bb5
1656+ }
1657+ bb5 {
1658+ v6 = IsIntEqualWord v3, 1
1659+ CondBranch v6, bb6, bb1
1660+ }
1661+ bb6 {
1662+ v7 = RecordGet<Record_b> v1
1663+ v8 = CConst<struct object*; NULL>
1664+ v9 = CEqual v7, v8
1665+ CondBranch v9, bb1, bb7
1666+ }
1667+ bb7 {
1668+ v10 = IsIntEqualWord v7, 2
1669+ CondBranch v10, bb8, bb1
1670+ }
1671+ bb8 {
1672+ v11 = RecordNumFields v1
1673+ v12 = CConst<uword; 2>
1674+ v13 = CEqual v11, v12
1675+ CondBranch v13, bb3, bb1
1676+ }
1677+ bb3 {
1678+ v14 = Const<3>
1679+ Return v14
1680+ }
1681+ bb1 {
1682+ MatchFail
1683+ }
1684+ }""" ,
1685+ )
1686+
1687+ def test_match_record_spread (self ) -> None :
1688+ compiler = Compiler ()
1689+ compiler .compile_body ({}, _parse ("| {a=a, ...} -> a" ))
1690+ self .assertEqual (
1691+ compiler .fns [1 ].to_string (InstrId ()),
1692+ """\
1693+ fn1 {
1694+ bb0 {
1695+ v0 = Param<0; $clo>
1696+ v1 = Param<1; arg_0>
1697+ Jump bb2
1698+ }
1699+ bb2 {
1700+ v2 = IsRecord v1
1701+ CondBranch v2, bb4, bb1
1702+ }
1703+ bb4 {
1704+ v3 = RecordGet<Record_a> v1
1705+ v4 = CConst<struct object*; NULL>
1706+ v5 = CEqual v3, v4
1707+ CondBranch v5, bb1, bb5
1708+ }
1709+ bb5 {
1710+ Jump bb6
1711+ }
1712+ bb6 {
1713+ Jump bb3
1714+ }
1715+ bb3 {
1716+ Return v3
1717+ }
1718+ bb1 {
1719+ MatchFail
1720+ }
1721+ }""" ,
1722+ )
1723+ CleanCFG (compiler .fns [1 ]).run ()
1724+ self .assertEqual (
1725+ compiler .fns [1 ].to_string (InstrId ()),
1726+ """\
1727+ fn1 {
1728+ bb0 {
1729+ v0 = Param<0; $clo>
1730+ v1 = Param<1; arg_0>
1731+ v2 = IsRecord v1
1732+ CondBranch v2, bb4, bb1
1733+ }
1734+ bb4 {
1735+ v3 = RecordGet<Record_a> v1
1736+ v4 = CConst<struct object*; NULL>
1737+ v5 = CEqual v3, v4
1738+ CondBranch v5, bb1, bb5
1739+ }
1740+ bb5 {
1741+ Return v3
1742+ }
1743+ bb1 {
1744+ MatchFail
1745+ }
1746+ }""" ,
1747+ )
1748+
14911749 def test_apply_fn (self ) -> None :
14921750 compiler = Compiler ()
14931751 compiler .compile_body ({}, _parse ("f 1 . f = x -> x + 1" ))
@@ -2045,6 +2303,15 @@ def test_record_access(self) -> None:
20452303 def test_record_builder_access (self ) -> None :
20462304 self .assertEqual (_run ("(f 1 2)@a . f = x -> y -> {a = x, b = y}" ), "1\n " )
20472305
2306+ def test_match_record (self ) -> None :
2307+ self .assertEqual (_run ("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6" ), "6\n " )
2308+
2309+ def test_match_record_too_few_keys (self ) -> None :
2310+ self .assertEqual (_run ("f {a = 4, b = 5} . f = | {a = _} -> 3 | {a = _, b = _} -> 6" ), "6\n " )
2311+
2312+ def test_match_record_spread (self ) -> None :
2313+ self .assertEqual (_run ("f {a=1, b=2, c=3} . f = | {a=a, ...} -> a" ), "1\n " )
2314+
20482315 def test_hole (self ) -> None :
20492316 self .assertEqual (_run ("()" ), "()\n " )
20502317
0 commit comments