Skip to content

Commit 756753c

Browse files
committed
Add semantic cache and full DSL operation wrappers
1 parent c613fae commit 756753c

File tree

11 files changed

+544
-106
lines changed

11 files changed

+544
-106
lines changed

arc_solver/canonical.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Canonicalisation utilities for ARC grids.
2+
3+
This module provides functions to normalise grids under the D4 symmetry group
4+
(rotations and reflections) and canonicalise colour labels. Canonicalisation
5+
reduces the search space by treating symmetric grids as identical.
6+
"""
7+
from __future__ import annotations
8+
9+
from typing import Dict, Iterable, Tuple
10+
11+
import numpy as np
12+
13+
Array = np.ndarray
14+
15+
# Precompute the eight transformations in the D4 symmetry group.
16+
D4: Tuple[callable, ...] = (
17+
lambda g: g,
18+
lambda g: np.rot90(g, 1),
19+
lambda g: np.rot90(g, 2),
20+
lambda g: np.rot90(g, 3),
21+
lambda g: np.flipud(g),
22+
lambda g: np.fliplr(g),
23+
lambda g: np.rot90(np.flipud(g), 1),
24+
lambda g: np.rot90(np.fliplr(g), 1),
25+
)
26+
27+
28+
def canonicalize_colors(grid: Array) -> Tuple[Array, Dict[int, int]]:
29+
"""Relabel colours in ``grid`` in descending frequency order.
30+
31+
Parameters
32+
----------
33+
grid:
34+
Input array containing integer colour labels.
35+
36+
Returns
37+
-------
38+
canonical:
39+
Array with colours mapped to ``0..n-1`` in frequency order.
40+
mapping:
41+
Dictionary mapping original colours to canonical labels.
42+
43+
Raises
44+
------
45+
TypeError
46+
If ``grid`` is not a ``numpy.ndarray`` or is not of integer type.
47+
"""
48+
if not isinstance(grid, np.ndarray):
49+
raise TypeError("grid must be a numpy array")
50+
if not np.issubdtype(grid.dtype, np.integer):
51+
raise TypeError("grid dtype must be integer")
52+
53+
vals, counts = np.unique(grid, return_counts=True)
54+
order = [int(v) for v, _ in sorted(zip(vals, counts), key=lambda t: (-t[1], t[0]))]
55+
mapping = {c: i for i, c in enumerate(order)}
56+
vect_map = np.vectorize(mapping.get)
57+
canonical = vect_map(grid)
58+
return canonical.astype(np.int16), mapping
59+
60+
61+
def canonicalize_D4(grid: Array) -> Array:
62+
"""Return the lexicographically smallest grid under D4 symmetries.
63+
64+
The grid is first transformed by each D4 element, then colour-canonicalised.
65+
The transformation with the smallest shape and byte representation is chosen
66+
as the canonical representative.
67+
68+
Parameters
69+
----------
70+
grid:
71+
Input array to canonicalise.
72+
73+
Returns
74+
-------
75+
np.ndarray
76+
Canonicalised grid.
77+
78+
Raises
79+
------
80+
TypeError
81+
If ``grid`` is not a ``numpy.ndarray`` or is not of integer type.
82+
"""
83+
if not isinstance(grid, np.ndarray):
84+
raise TypeError("grid must be a numpy array")
85+
if not np.issubdtype(grid.dtype, np.integer):
86+
raise TypeError("grid dtype must be integer")
87+
88+
best: Array | None = None
89+
best_key: Tuple[Tuple[int, int], bytes] | None = None
90+
for transform in D4:
91+
transformed = transform(grid)
92+
canonical, _ = canonicalize_colors(transformed)
93+
key = (canonical.shape, canonical.tobytes())
94+
if best_key is None or key < best_key:
95+
best, best_key = canonical, key
96+
if best is None:
97+
# This should not occur because D4 contains identity, but guard anyway.
98+
return grid.copy()
99+
return best

arc_solver/dsl.py

Lines changed: 148 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,212 @@
1-
"""
2-
Domain-specific language (DSL) primitives for ARC program synthesis.
1+
"""Domain-specific language (DSL) primitives for ARC program synthesis.
32
4-
This module defines a small set of composable operations that act on grids.
5-
Each operation is represented by an `Op` object with a name, a function, and
6-
metadata about its parameters. Programs are sequences of these operations.
3+
This module defines a set of composable operations that act on grids. Each
4+
operation is represented by an :class:`Op` and registered in :data:`OPS`.
5+
Programs are sequences of these operations applied to a grid.
76
"""
87

98
from __future__ import annotations
109

10+
from typing import Any, Callable, Dict, List, Tuple, Optional
1111
import numpy as np
12-
from typing import Any, Callable, Dict, List, Tuple
1312

14-
from .grid import Array, rotate90, flip, transpose, translate, color_map, crop, pad_to, bg_color
13+
from arc_solver.grid import (
14+
Array,
15+
rotate90,
16+
flip as flip_grid,
17+
transpose as transpose_grid,
18+
translate as translate_grid,
19+
color_map as color_map_grid,
20+
crop as crop_array,
21+
pad_to,
22+
bg_color,
23+
)
1524

1625

1726
class Op:
18-
"""Represents a primitive transformation on a grid.
19-
20-
Attributes
21-
----------
22-
name : str
23-
Human-readable name of the operation.
24-
fn : Callable
25-
Function implementing the operation.
26-
arity : int
27-
Number of input grids (arity=1 for single-grid ops).
28-
param_names : List[str]
29-
Names of parameters accepted by the operation.
30-
"""
27+
"""Represents a primitive transformation on a grid."""
3128

3229
def __init__(self, name: str, fn: Callable[..., Array], arity: int, param_names: List[str]):
3330
self.name = name
3431
self.fn = fn
3532
self.arity = arity
3633
self.param_names = param_names
3734

38-
def __call__(self, *args, **kwargs) -> Array:
35+
def __call__(self, *args: Any, **kwargs: Any) -> Array:
3936
return self.fn(*args, **kwargs)
4037

4138

42-
# Primitive operations (single-grid)
39+
# ---------------------------------------------------------------------------
40+
# Primitive operation implementations
41+
# ---------------------------------------------------------------------------
42+
4343
def op_identity(a: Array) -> Array:
44-
return a
44+
"""Return a copy of the input grid."""
45+
return a.copy()
4546

4647

4748
def op_rotate(a: Array, k: int) -> Array:
48-
return rotate90(a, k)
49+
"""Rotate grid by ``k`` quarter turns clockwise."""
50+
return rotate90(a, -k)
4951

5052

5153
def op_flip(a: Array, axis: int) -> Array:
52-
return flip(a, axis)
54+
"""Flip grid along the specified axis (0=vertical, 1=horizontal)."""
55+
return flip_grid(a, axis)
5356

5457

5558
def op_transpose(a: Array) -> Array:
56-
return transpose(a)
59+
"""Transpose the grid."""
60+
return transpose_grid(a)
5761

5862

59-
def op_translate(a: Array, dy: int, dx: int) -> Array:
60-
return translate(a, dy, dx, fill=bg_color(a))
63+
def op_translate(a: Array, dy: int, dx: int, fill: Optional[int] = None) -> Array:
64+
"""Translate the grid by ``(dy, dx)`` filling uncovered cells.
65+
66+
Parameters
67+
----------
68+
a:
69+
Input grid.
70+
dy, dx:
71+
Translation offsets. Positive values move content down/right.
72+
fill:
73+
Optional fill value for uncovered cells. If ``None`` the background
74+
colour of ``a`` is used.
75+
"""
76+
fill_val = 0 if fill is None else fill
77+
return translate_grid(a, dy, dx, fill=fill_val)
6178

6279

6380
def op_recolor(a: Array, mapping: Dict[int, int]) -> Array:
64-
return color_map(a, mapping)
81+
"""Recolour grid according to a mapping from old to new colours."""
82+
return color_map_grid(a, mapping)
6583

6684

6785
def op_crop_bbox(a: Array, top: int, left: int, height: int, width: int) -> Array:
68-
# ensure cropping stays inside bounds
86+
"""Crop a bounding box from the grid ensuring bounds are valid."""
6987
h, w = a.shape
7088
top = max(0, min(top, h - 1))
7189
left = max(0, min(left, w - 1))
7290
height = max(1, min(height, h - top))
7391
width = max(1, min(width, w - left))
74-
return crop(a, top, left, height, width)
92+
return crop_array(a, top, left, height, width)
7593

7694

7795
def op_pad(a: Array, out_h: int, out_w: int) -> Array:
96+
"""Pad grid to a specific height and width using background colour."""
7897
return pad_to(a, (out_h, out_w), fill=bg_color(a))
7998

8099

81-
# Register operations in a dictionary for easy lookup
100+
# Registry of primitive operations ---------------------------------------------------------
82101
OPS: Dict[str, Op] = {
83102
"identity": Op("identity", op_identity, 1, []),
84103
"rotate": Op("rotate", op_rotate, 1, ["k"]),
85104
"flip": Op("flip", op_flip, 1, ["axis"]),
86105
"transpose": Op("transpose", op_transpose, 1, []),
87-
"translate": Op("translate", op_translate, 1, ["dy", "dx"]),
106+
"translate": Op("translate", op_translate, 1, ["dy", "dx", "fill"]),
88107
"recolor": Op("recolor", op_recolor, 1, ["mapping"]),
89108
"crop": Op("crop", op_crop_bbox, 1, ["top", "left", "height", "width"]),
90109
"pad": Op("pad", op_pad, 1, ["out_h", "out_w"]),
91110
}
92111

93112

94-
def apply_program(a: Array, program: List[Tuple[str, Dict[str, Any]]]) -> Array:
95-
"""Apply a sequence of operations (program) to the input array.
113+
# Semantic cache -------------------------------------------------------------------------
114+
_sem_cache: Dict[Tuple[bytes, str, Tuple[Tuple[str, Any], ...]], Array] = {}
96115

97-
Parameters
98-
----------
99-
a : Array
100-
Input grid.
101-
program : List of (op_name, params)
102-
Sequence of operations with parameters. The operations are looked up in
103-
OPS.
104-
105-
Returns
106-
-------
107-
Array
108-
Resulting grid after applying the program.
116+
117+
def apply_op(a: Array, name: str, params: Dict[str, Any]) -> Array:
118+
"""Apply a primitive operation with semantic caching."""
119+
key = (a.tobytes(), name, tuple(sorted(params.items())))
120+
cached = _sem_cache.get(key)
121+
if cached is not None:
122+
return cached
123+
op = OPS[name]
124+
out = op(a, **params)
125+
_sem_cache[key] = out
126+
return out
127+
128+
129+
# User-facing convenience wrappers --------------------------------------------------------
130+
131+
def identity(a: Array) -> Array:
132+
"""Return a copy of the input grid."""
133+
return op_identity(a)
134+
135+
136+
def rotate(a: Array, k: int) -> Array:
137+
"""Rotate grid by ``k`` quarter turns clockwise."""
138+
return op_rotate(a, k)
139+
140+
141+
def flip(a: Array, axis: int) -> Array:
142+
"""Flip grid along the specified axis."""
143+
return op_flip(a, axis)
144+
145+
146+
def transpose(a: Array) -> Array:
147+
"""Transpose the grid."""
148+
return op_transpose(a)
149+
150+
151+
def translate(a: Array, dx: int, dy: int, fill_value: Optional[int] = None) -> Array:
152+
"""Translate grid by ``(dy, dx)`` with optional fill value."""
153+
return op_translate(a, dy, dx, fill=fill_value)
154+
155+
156+
def recolor(a: Array, color_map: Dict[int, int]) -> Array:
157+
"""Recolour grid according to a mapping."""
158+
return op_recolor(a, color_map)
159+
160+
161+
def crop(a: Array, top: int, bottom: int, left: int, right: int) -> Array:
162+
"""Crop a region specified by inclusive-exclusive bounds.
163+
164+
Args:
165+
top, bottom, left, right: Bounds following Python slice semantics where
166+
``bottom`` and ``right`` are exclusive.
109167
"""
168+
if bottom <= top or right <= left:
169+
raise ValueError("Invalid crop bounds")
170+
h, w = a.shape
171+
top = max(0, min(top, h))
172+
bottom = max(top, min(bottom, h))
173+
left = max(0, min(left, w))
174+
right = max(left, min(right, w))
175+
return a[top:bottom, left:right].copy()
176+
177+
178+
def pad(a: Array, top: int, bottom: int, left: int, right: int, fill_value: int = 0) -> Array:
179+
"""Pad grid with ``fill_value`` on each side."""
180+
if min(top, bottom, left, right) < 0:
181+
raise ValueError("Pad widths must be non-negative")
182+
h, w = a.shape
183+
out = np.full((h + top + bottom, w + left + right), fill_value, dtype=a.dtype)
184+
out[top:top + h, left:left + w] = a
185+
return out
186+
187+
188+
# Program application --------------------------------------------------------------------
189+
190+
def apply_program(a: Array, program: List[Tuple[str, Dict[str, Any]]]) -> Array:
191+
"""Apply a sequence of operations to the input grid."""
110192
out = a
111193
for name, params in program:
112-
op = OPS[name]
113-
out = op(out, **params)
114-
return out
194+
out = apply_op(out, name, params)
195+
return out
196+
197+
198+
__all__ = [
199+
"Array",
200+
"Op",
201+
"OPS",
202+
"apply_program",
203+
"apply_op",
204+
"identity",
205+
"rotate",
206+
"flip",
207+
"transpose",
208+
"translate",
209+
"recolor",
210+
"crop",
211+
"pad",
212+
]

0 commit comments

Comments
 (0)