Skip to content

Commit a579ef9

Browse files
authored
Merge pull request #1 from tylerbessire/codex/review-code-for-arc-2025-readiness
Harden submission runner with deterministic budgets
2 parents c613fae + ad73bf3 commit a579ef9

File tree

9 files changed

+379
-56
lines changed

9 files changed

+379
-56
lines changed

arc_solver/canonical.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Canonicalisation utilities for ARC grids.
2+
3+
This module provides functions to normalise grids under the D4 symmetry group
4+
(rotations and reflections) and canonicalise colour labels. Canonicalisation
5+
reduces the search space by treating symmetric grids as identical.
6+
"""
7+
from __future__ import annotations
8+
9+
from typing import Dict, Iterable, Tuple
10+
11+
import numpy as np
12+
13+
Array = np.ndarray
14+
15+
# Precompute the eight transformations in the D4 symmetry group.
16+
D4: Tuple[callable, ...] = (
17+
lambda g: g,
18+
lambda g: np.rot90(g, 1),
19+
lambda g: np.rot90(g, 2),
20+
lambda g: np.rot90(g, 3),
21+
lambda g: np.flipud(g),
22+
lambda g: np.fliplr(g),
23+
lambda g: np.rot90(np.flipud(g), 1),
24+
lambda g: np.rot90(np.fliplr(g), 1),
25+
)
26+
27+
28+
def canonicalize_colors(grid: Array) -> Tuple[Array, Dict[int, int]]:
29+
"""Relabel colours in ``grid`` in descending frequency order.
30+
31+
Parameters
32+
----------
33+
grid:
34+
Input array containing integer colour labels.
35+
36+
Returns
37+
-------
38+
canonical:
39+
Array with colours mapped to ``0..n-1`` in frequency order.
40+
mapping:
41+
Dictionary mapping original colours to canonical labels.
42+
43+
Raises
44+
------
45+
TypeError
46+
If ``grid`` is not a ``numpy.ndarray`` or is not of integer type.
47+
"""
48+
if not isinstance(grid, np.ndarray):
49+
raise TypeError("grid must be a numpy array")
50+
if not np.issubdtype(grid.dtype, np.integer):
51+
raise TypeError("grid dtype must be integer")
52+
53+
vals, counts = np.unique(grid, return_counts=True)
54+
order = [int(v) for v, _ in sorted(zip(vals, counts), key=lambda t: (-t[1], t[0]))]
55+
mapping = {c: i for i, c in enumerate(order)}
56+
vect_map = np.vectorize(mapping.get)
57+
canonical = vect_map(grid)
58+
return canonical.astype(np.int16), mapping
59+
60+
61+
def canonicalize_D4(grid: Array) -> Array:
62+
"""Return the lexicographically smallest grid under D4 symmetries.
63+
64+
The grid is first transformed by each D4 element, then colour-canonicalised.
65+
The transformation with the smallest shape and byte representation is chosen
66+
as the canonical representative.
67+
68+
Parameters
69+
----------
70+
grid:
71+
Input array to canonicalise.
72+
73+
Returns
74+
-------
75+
np.ndarray
76+
Canonicalised grid.
77+
78+
Raises
79+
------
80+
TypeError
81+
If ``grid`` is not a ``numpy.ndarray`` or is not of integer type.
82+
"""
83+
if not isinstance(grid, np.ndarray):
84+
raise TypeError("grid must be a numpy array")
85+
if not np.issubdtype(grid.dtype, np.integer):
86+
raise TypeError("grid dtype must be integer")
87+
88+
best: Array | None = None
89+
best_key: Tuple[Tuple[int, int], bytes] | None = None
90+
for transform in D4:
91+
transformed = transform(grid)
92+
canonical, _ = canonicalize_colors(transformed)
93+
key = (canonical.shape, canonical.tobytes())
94+
if best_key is None or key < best_key:
95+
best, best_key = canonical, key
96+
if best is None:
97+
# This should not occur because D4 contains identity, but guard anyway.
98+
return grid.copy()
99+
return best

arc_solver/enhanced_search.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -236,20 +236,19 @@ def save_components(self):
236236
self.sketch_miner.save_sketches("sketches.json")
237237

238238

239-
def predict_two_enhanced(progs: List[List[Tuple[str, Dict[str, int]]]],
240-
test_inputs: List[Array]) -> List[List[Array]]:
239+
def predict_two_enhanced(
240+
progs: List[List[Tuple[str, Dict[str, int]]]],
241+
test_inputs: List[Array],
242+
prefer_diverse: bool = False,
243+
) -> List[List[Array]]:
241244
"""Enhanced prediction with better fallback strategies."""
242-
if len(progs) == 0:
243-
# No programs found, use identity
245+
if not progs:
244246
picks = [[("identity", {})], [("identity", {})]]
245-
elif len(progs) == 1:
246-
# Only one program, use it twice with slight variation if possible
247-
main_prog = progs[0]
248-
picks = [main_prog, main_prog]
247+
elif prefer_diverse and len(progs) > 1:
248+
picks = [progs[0], progs[1]]
249249
else:
250-
# Use top 2 programs
251-
picks = progs[:2]
252-
250+
picks = progs[:2] if len(progs) >= 2 else [progs[0], progs[0]]
251+
253252
attempts: List[List[Array]] = []
254253
for program in picks:
255254
outs: List[Array] = []
@@ -258,25 +257,26 @@ def predict_two_enhanced(progs: List[List[Tuple[str, Dict[str, int]]]],
258257
result = apply_program(ti, program)
259258
outs.append(result)
260259
except Exception:
261-
# Fallback to identity on failure
262260
outs.append(ti)
263261
attempts.append(outs)
264-
262+
265263
return attempts
266264

267265

268266
# Integration function to use enhanced search in the main solver
269-
def synthesize_with_enhancements(train_pairs: List[Tuple[Array, Array]],
270-
max_programs: int = 256) -> List[List[Tuple[str, Dict[str, int]]]]:
267+
def synthesize_with_enhancements(
268+
train_pairs: List[Tuple[Array, Array]],
269+
max_programs: int = 256,
270+
force_alt: bool = False,
271+
) -> List[List[Tuple[str, Dict[str, int]]]]:
271272
"""Main function to synthesize programs with all enhancements."""
272-
273-
# Initialize enhanced search (this will be cached across calls in practice)
273+
274274
enhanced_search = EnhancedSearch()
275-
276-
# Try enhanced synthesis
277275
programs = enhanced_search.synthesize_enhanced(train_pairs, max_programs)
278-
279-
# Save learned components periodically
276+
277+
if force_alt and len(programs) > 1:
278+
programs = programs[1:]
279+
280280
enhanced_search.save_components()
281-
281+
282282
return programs

arc_solver/enhanced_solver.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,22 @@
88

99
from __future__ import annotations
1010

11-
from typing import Dict, List, Tuple
11+
from typing import Any, Dict, List, Optional, Tuple
1212
import numpy as np
1313
import os
1414

1515
from .grid import to_array, to_list, Array
16-
from .search import synthesize, predict_two # Keep original as fallback
16+
from .search import (
17+
synthesize as synth_baseline,
18+
predict_two as predict_two_baseline,
19+
)
1720
from .enhanced_search import synthesize_with_enhancements, predict_two_enhanced
1821

1922

2023
class ARCSolver:
2124
"""Enhanced ARC solver with neural components and episodic memory."""
2225

23-
def __init__(self, use_enhancements: bool = True,
26+
def __init__(self, use_enhancements: bool = True,
2427
guidance_model_path: str = None,
2528
episode_db_path: str = "episodes.json"):
2629
self.use_enhancements = use_enhancements
@@ -32,6 +35,7 @@ def __init__(self, use_enhancements: bool = True,
3235
'enhancement_success_rate': 0.0,
3336
'fallback_used': 0,
3437
}
38+
self._last_outputs: Optional[Tuple[List[List[List[int]]], List[List[List[int]]]]] = None
3539

3640
def solve_task(self, task: Dict[str, List[Dict[str, List[List[int]]]]]) -> Dict[str, List[List[List[int]]]]:
3741
"""Solve a single ARC task using enhanced or baseline methods."""
@@ -71,14 +75,77 @@ def solve_task(self, task: Dict[str, List[Dict[str, List[List[int]]]]]) -> Dict[
7175

7276
except Exception:
7377
# Fall back to baseline approach
74-
progs = synthesize(train_pairs)
75-
attempts = predict_two(progs, test_inputs)
78+
progs = synth_baseline(train_pairs)
79+
attempts = predict_two_baseline(progs, test_inputs)
7680

7781
# Convert outputs back to nested lists
7882
return {
7983
"attempt_1": [to_list(arr) for arr in attempts[0]],
8084
"attempt_2": [to_list(arr) for arr in attempts[1]],
8185
}
86+
87+
def solve_task_two_attempts(
88+
self, task: Dict[str, List[Dict[str, List[List[int]]]]]
89+
) -> Tuple[List[List[List[int]]], List[List[List[int]]]]:
90+
"""Solve a task and ensure two diverse attempts.
91+
92+
Args:
93+
task: ARC task specification.
94+
95+
Returns:
96+
A tuple ``(attempt1, attempt2)`` each being a list of output grids
97+
corresponding to the test inputs.
98+
"""
99+
100+
result = self.solve_task(task)
101+
attempt1 = result["attempt_1"]
102+
attempt2 = result["attempt_2"]
103+
104+
if attempt1 == attempt2:
105+
alt = self._second_pass_diversified(task)
106+
if alt is not None:
107+
attempt2 = alt
108+
109+
self._last_outputs = (attempt1, attempt2)
110+
return attempt1, attempt2
111+
112+
def _second_pass_diversified(
113+
self, task: Dict[str, List[Dict[str, List[List[int]]]]]
114+
) -> Optional[List[List[List[int]]]]:
115+
"""Run a diversified second search pass to obtain an alternative output."""
116+
117+
train_pairs = [
118+
(to_array(p["input"]), to_array(p["output"])) for p in task["train"]
119+
]
120+
test_inputs = [to_array(p["input"]) for p in task["test"]]
121+
122+
try:
123+
programs = synthesize_with_enhancements(train_pairs, force_alt=True)
124+
attempts = predict_two_enhanced(programs, test_inputs, prefer_diverse=True)
125+
return [to_list(x) for x in attempts[0]]
126+
except Exception:
127+
try:
128+
programs = synth_baseline(train_pairs)
129+
attempts = predict_two_baseline(
130+
programs, test_inputs, prefer_diverse=True
131+
)
132+
return [to_list(x) for x in attempts[0]]
133+
except Exception:
134+
return None
135+
136+
def best_so_far(
137+
self, task: Dict[str, List[Dict[str, List[List[int]]]]]
138+
) -> List[List[List[int]]]:
139+
"""Return the best outputs computed so far for the current task.
140+
141+
If the solver has produced at least one attempt, that attempt is
142+
returned. Otherwise, the identity transformation of the first test
143+
input is used as a safe fallback.
144+
"""
145+
146+
if self._last_outputs is not None:
147+
return self._last_outputs[0]
148+
return [task["test"][0]["input"]]
82149

83150
def _validate_solution(self, attempts: List[List[Array]], test_inputs: List[Array]) -> bool:
84151
"""Basic validation to check if solution seems reasonable."""

arc_solver/features.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from .grid import Array, histogram, bg_color, eq
1515
from .objects import connected_components
16+
from .canonical import canonicalize_D4
1617

1718

1819
def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, Any]:
@@ -21,7 +22,15 @@ def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, A
2122
These features capture task-level properties that can help predict which
2223
DSL operations are likely to be relevant for solving the task.
2324
"""
24-
features = {}
25+
try:
26+
train_pairs = [
27+
(canonicalize_D4(inp), canonicalize_D4(out))
28+
for inp, out in train_pairs
29+
]
30+
except TypeError as exc:
31+
raise ValueError(f"invalid grid in train_pairs: {exc}") from exc
32+
33+
features: Dict[str, Any] = {}
2534

2635
# Basic grid statistics
2736
input_shapes = [inp.shape for inp, _ in train_pairs]

arc_solver/objects.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import List, Tuple, Dict, Any
1414

1515
from .grid import Array, bg_color
16+
from .canonical import canonicalize_D4
1617

1718

1819
def neighbors4(y: int, x: int) -> List[Tuple[int, int]]:
@@ -21,14 +22,21 @@ def neighbors4(y: int, x: int) -> List[Tuple[int, int]]:
2122

2223

2324
def connected_components(a: Array) -> List[Dict[str, Any]]:
24-
"""Find all 4-connected components in the grid, grouping by exact color.
25+
"""Find all 4-connected components in a canonicalised grid.
2526
26-
Each component dictionary contains:
27+
The input grid is first normalised under D4 symmetries and colour
28+
relabelling to ensure deterministic component extraction. Each component
29+
dictionary contains:
2730
- color: the color value of the component
2831
- bbox: (top, left, height, width) of the bounding box
2932
- mask: a 2D array of shape (height, width) with the component values
3033
- pixels: list of (row, col) indices in original grid
3134
"""
35+
try:
36+
a = canonicalize_D4(a)
37+
except TypeError as exc:
38+
raise ValueError(f"invalid grid: {exc}") from exc
39+
3240
h, w = a.shape
3341
visited = np.zeros_like(a, dtype=bool)
3442
comps: List[Dict[str, Any]] = []
@@ -65,11 +73,13 @@ def connected_components(a: Array) -> List[Dict[str, Any]]:
6573

6674

6775
def infer_symmetries(a: Array) -> Dict[str, bool]:
68-
"""Return a simple dictionary of possible symmetries in the grid.
76+
"""Return a dictionary of potential symmetries for a canonicalised grid."""
77+
try:
78+
a = canonicalize_D4(a)
79+
except TypeError as exc:
80+
raise ValueError(f"invalid grid: {exc}") from exc
6981

70-
For speed, this function does not check each symmetry but sets flags to True.
71-
More precise symmetry detection can be incorporated later based on heuristics.
72-
"""
82+
# Placeholder flags; symmetry detection can be refined with heuristics.
7383
return {
7484
"rot90": True,
7585
"rot180": True,

0 commit comments

Comments
 (0)