1- """
2- Domain-specific language (DSL) primitives for ARC program synthesis.
1+ """Domain-specific language (DSL) primitives for ARC program synthesis.
32
4- This module defines a small set of composable operations that act on grids.
5- Each operation is represented by an `Op` object with a name, a function, and
6- metadata about its parameters. Programs are sequences of these operations.
3+ This module defines a set of composable operations that act on grids. Each
4+ operation is represented by an :class: `Op` and registered in :data:`OPS`.
5+ Programs are sequences of these operations applied to a grid .
76"""
87
98from __future__ import annotations
109
10+ from typing import Any , Callable , Dict , List , Tuple , Optional
1111import numpy as np
12- from typing import Any , Callable , Dict , List , Tuple
1312
14- from .grid import Array , rotate90 , flip , transpose , translate , color_map , crop , pad_to , bg_color
13+ from arc_solver .grid import (
14+ Array ,
15+ rotate90 ,
16+ flip as flip_grid ,
17+ transpose as transpose_grid ,
18+ translate as translate_grid ,
19+ color_map as color_map_grid ,
20+ crop as crop_array ,
21+ pad_to ,
22+ bg_color ,
23+ )
1524
1625
1726class Op :
18- """Represents a primitive transformation on a grid.
19-
20- Attributes
21- ----------
22- name : str
23- Human-readable name of the operation.
24- fn : Callable
25- Function implementing the operation.
26- arity : int
27- Number of input grids (arity=1 for single-grid ops).
28- param_names : List[str]
29- Names of parameters accepted by the operation.
30- """
27+ """Represents a primitive transformation on a grid."""
3128
3229 def __init__ (self , name : str , fn : Callable [..., Array ], arity : int , param_names : List [str ]):
3330 self .name = name
3431 self .fn = fn
3532 self .arity = arity
3633 self .param_names = param_names
3734
38- def __call__ (self , * args , ** kwargs ) -> Array :
35+ def __call__ (self , * args : Any , ** kwargs : Any ) -> Array :
3936 return self .fn (* args , ** kwargs )
4037
4138
42- # Primitive operations (single-grid)
39+ # ---------------------------------------------------------------------------
40+ # Primitive operation implementations
41+ # ---------------------------------------------------------------------------
42+
4343def op_identity (a : Array ) -> Array :
44- return a
44+ """Return a copy of the input grid."""
45+ return a .copy ()
4546
4647
4748def op_rotate (a : Array , k : int ) -> Array :
48- return rotate90 (a , k )
49+ """Rotate grid by ``k`` quarter turns clockwise."""
50+ return rotate90 (a , - k )
4951
5052
5153def op_flip (a : Array , axis : int ) -> Array :
52- return flip (a , axis )
54+ """Flip grid along the specified axis (0=vertical, 1=horizontal)."""
55+ return flip_grid (a , axis )
5356
5457
5558def op_transpose (a : Array ) -> Array :
56- return transpose (a )
59+ """Transpose the grid."""
60+ return transpose_grid (a )
5761
5862
59- def op_translate (a : Array , dy : int , dx : int ) -> Array :
60- return translate (a , dy , dx , fill = bg_color (a ))
63+ def op_translate (a : Array , dy : int , dx : int , fill : Optional [int ] = None ) -> Array :
64+ """Translate the grid by ``(dy, dx)`` filling uncovered cells.
65+
66+ Parameters
67+ ----------
68+ a:
69+ Input grid.
70+ dy, dx:
71+ Translation offsets. Positive values move content down/right.
72+ fill:
73+ Optional fill value for uncovered cells. If ``None`` the background
74+ colour of ``a`` is used.
75+ """
76+ fill_val = 0 if fill is None else fill
77+ return translate_grid (a , dy , dx , fill = fill_val )
6178
6279
6380def op_recolor (a : Array , mapping : Dict [int , int ]) -> Array :
64- return color_map (a , mapping )
81+ """Recolour grid according to a mapping from old to new colours."""
82+ return color_map_grid (a , mapping )
6583
6684
6785def op_crop_bbox (a : Array , top : int , left : int , height : int , width : int ) -> Array :
68- # ensure cropping stays inside bounds
86+ """Crop a bounding box from the grid ensuring bounds are valid."""
6987 h , w = a .shape
7088 top = max (0 , min (top , h - 1 ))
7189 left = max (0 , min (left , w - 1 ))
7290 height = max (1 , min (height , h - top ))
7391 width = max (1 , min (width , w - left ))
74- return crop (a , top , left , height , width )
92+ return crop_array (a , top , left , height , width )
7593
7694
7795def op_pad (a : Array , out_h : int , out_w : int ) -> Array :
96+ """Pad grid to a specific height and width using background colour."""
7897 return pad_to (a , (out_h , out_w ), fill = bg_color (a ))
7998
8099
81- # Register operations in a dictionary for easy lookup
100+ # Registry of primitive operations ---------------------------------------------------------
82101OPS : Dict [str , Op ] = {
83102 "identity" : Op ("identity" , op_identity , 1 , []),
84103 "rotate" : Op ("rotate" , op_rotate , 1 , ["k" ]),
85104 "flip" : Op ("flip" , op_flip , 1 , ["axis" ]),
86105 "transpose" : Op ("transpose" , op_transpose , 1 , []),
87- "translate" : Op ("translate" , op_translate , 1 , ["dy" , "dx" ]),
106+ "translate" : Op ("translate" , op_translate , 1 , ["dy" , "dx" , "fill" ]),
88107 "recolor" : Op ("recolor" , op_recolor , 1 , ["mapping" ]),
89108 "crop" : Op ("crop" , op_crop_bbox , 1 , ["top" , "left" , "height" , "width" ]),
90109 "pad" : Op ("pad" , op_pad , 1 , ["out_h" , "out_w" ]),
91110}
92111
93112
94- def apply_program ( a : Array , program : List [ Tuple [ str , Dict [ str , Any ]]]) -> Array :
95- """Apply a sequence of operations (program) to the input array.
113+ # Semantic cache -------------------------------------------------------------------------
114+ _sem_cache : Dict [ Tuple [ bytes , str , Tuple [ Tuple [ str , Any ], ...]], Array ] = {}
96115
97- Parameters
98- ----------
99- a : Array
100- Input grid.
101- program : List of (op_name, params)
102- Sequence of operations with parameters. The operations are looked up in
103- OPS.
104-
105- Returns
106- -------
107- Array
108- Resulting grid after applying the program.
116+
117+ def apply_op (a : Array , name : str , params : Dict [str , Any ]) -> Array :
118+ """Apply a primitive operation with semantic caching."""
119+ key = (a .tobytes (), name , tuple (sorted (params .items ())))
120+ cached = _sem_cache .get (key )
121+ if cached is not None :
122+ return cached
123+ op = OPS [name ]
124+ out = op (a , ** params )
125+ _sem_cache [key ] = out
126+ return out
127+
128+
129+ # User-facing convenience wrappers --------------------------------------------------------
130+
131+ def identity (a : Array ) -> Array :
132+ """Return a copy of the input grid."""
133+ return op_identity (a )
134+
135+
136+ def rotate (a : Array , k : int ) -> Array :
137+ """Rotate grid by ``k`` quarter turns clockwise."""
138+ return op_rotate (a , k )
139+
140+
141+ def flip (a : Array , axis : int ) -> Array :
142+ """Flip grid along the specified axis."""
143+ return op_flip (a , axis )
144+
145+
146+ def transpose (a : Array ) -> Array :
147+ """Transpose the grid."""
148+ return op_transpose (a )
149+
150+
151+ def translate (a : Array , dx : int , dy : int , fill_value : Optional [int ] = None ) -> Array :
152+ """Translate grid by ``(dy, dx)`` with optional fill value."""
153+ return op_translate (a , dy , dx , fill = fill_value )
154+
155+
156+ def recolor (a : Array , color_map : Dict [int , int ]) -> Array :
157+ """Recolour grid according to a mapping."""
158+ return op_recolor (a , color_map )
159+
160+
161+ def crop (a : Array , top : int , bottom : int , left : int , right : int ) -> Array :
162+ """Crop a region specified by inclusive-exclusive bounds.
163+
164+ Args:
165+ top, bottom, left, right: Bounds following Python slice semantics where
166+ ``bottom`` and ``right`` are exclusive.
109167 """
168+ if bottom <= top or right <= left :
169+ raise ValueError ("Invalid crop bounds" )
170+ h , w = a .shape
171+ top = max (0 , min (top , h ))
172+ bottom = max (top , min (bottom , h ))
173+ left = max (0 , min (left , w ))
174+ right = max (left , min (right , w ))
175+ return a [top :bottom , left :right ].copy ()
176+
177+
178+ def pad (a : Array , top : int , bottom : int , left : int , right : int , fill_value : int = 0 ) -> Array :
179+ """Pad grid with ``fill_value`` on each side."""
180+ if min (top , bottom , left , right ) < 0 :
181+ raise ValueError ("Pad widths must be non-negative" )
182+ h , w = a .shape
183+ out = np .full ((h + top + bottom , w + left + right ), fill_value , dtype = a .dtype )
184+ out [top :top + h , left :left + w ] = a
185+ return out
186+
187+
188+ # Program application --------------------------------------------------------------------
189+
190+ def apply_program (a : Array , program : List [Tuple [str , Dict [str , Any ]]]) -> Array :
191+ """Apply a sequence of operations to the input grid."""
110192 out = a
111193 for name , params in program :
112- op = OPS [name ]
113- out = op (out , ** params )
114- return out
194+ out = apply_op (out , name , params )
195+ return out
196+
197+
198+ __all__ = [
199+ "Array" ,
200+ "Op" ,
201+ "OPS" ,
202+ "apply_program" ,
203+ "apply_op" ,
204+ "identity" ,
205+ "rotate" ,
206+ "flip" ,
207+ "transpose" ,
208+ "translate" ,
209+ "recolor" ,
210+ "crop" ,
211+ "pad" ,
212+ ]
0 commit comments