Skip to content

Commit cb3a2b9

Browse files
authored
Merge pull request #12 from tylerbessire/codex/outline-arc-training-and-evaluation-process
fix: normalize recolor mapping for training
2 parents fba8911 + 02b3608 commit cb3a2b9

File tree

8 files changed

+85
-15
lines changed

8 files changed

+85
-15
lines changed

AGENTS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ class MetaCognition:
479479
Notes: Resource limits and diversity enforced
480480
```
481481

482+
```
483+
[X] Step 4.3 UPDATE - Recolor parameter mismatch fixed, preventing training failures
484+
Date: 2025-09-12
485+
Test Result: pytest tests/test_recolor_fix.py passed
486+
Notes: Standardised 'mapping' parameter across heuristics; episodic loader normalises keys
487+
```
488+
482489
---
483490

484491
## Step 4.4: Final Validation

arc_solver/dsl.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,20 @@ def op_pad(a: Array, out_h: int, out_w: int) -> Array:
114114
_sem_cache: Dict[Tuple[bytes, str, Tuple[Tuple[str, Any], ...]], Array] = {}
115115

116116

117+
def _norm_params(params: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]:
118+
"""Normalise parameters to a hashable tuple."""
119+
items: List[Tuple[str, Any]] = []
120+
for k, v in sorted(params.items()):
121+
if isinstance(v, dict):
122+
items.append((k, tuple(sorted(v.items()))))
123+
else:
124+
items.append((k, v))
125+
return tuple(items)
126+
127+
117128
def apply_op(a: Array, name: str, params: Dict[str, Any]) -> Array:
118129
"""Apply a primitive operation with semantic caching."""
119-
key = (a.tobytes(), name, tuple(sorted(params.items())))
130+
key = (a.tobytes(), name, _norm_params(params))
120131
cached = _sem_cache.get(key)
121132
if cached is not None:
122133
return cached

arc_solver/dsl_complete.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,15 @@ def resize(grid: Array, new_height: int, new_width: int, method: str = 'nearest'
9898

9999
# ================== COLOR OPERATIONS ==================
100100

101-
def recolor(grid: Array, color_map: Dict[int, int]) -> Array:
102-
"""Recolor grid according to color mapping."""
101+
def recolor(
102+
grid: Array,
103+
color_map: Dict[int, int] | None = None,
104+
mapping: Dict[int, int] | None = None,
105+
) -> Array:
106+
"""Recolor grid according to a color mapping."""
107+
mapping = mapping if mapping is not None else (color_map or {})
103108
result = grid.copy()
104-
for old_color, new_color in color_map.items():
109+
for old_color, new_color in mapping.items():
105110
result[grid == old_color] = new_color
106111
return result
107112

@@ -692,7 +697,7 @@ def get_operation_signatures() -> Dict[str, List[str]]:
692697
'crop': ['top', 'bottom', 'left', 'right'],
693698
'pad': ['top', 'bottom', 'left', 'right', 'fill_value'],
694699
'resize': ['new_height', 'new_width', 'method'],
695-
'recolor': ['color_map'],
700+
'recolor': ['mapping'],
696701
'recolor_by_position': ['position_map'],
697702
'swap_colors': ['color1', 'color2'],
698703
'dominant_color_recolor': ['target_color'],

arc_solver/enhanced_search_complete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _generate_comprehensive_parameters(self, op_name: str,
368368
for old_color in colors_in_input:
369369
for new_color in range(1, 10):
370370
if new_color != old_color:
371-
param_combinations.append({'color_map': {old_color: new_color}})
371+
param_combinations.append({"mapping": {old_color: new_color}})
372372

373373
elif op_name == 'crop':
374374
for top in range(min(H, 3)):

arc_solver/heuristics_complete.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ def detect_color_patterns(inp: Array, out: Array) -> List[List[Tuple[str, Dict[s
202202
# Direct color mapping
203203
color_map = infer_color_mapping(inp, out)
204204
if color_map and len(color_map) > 0:
205-
recolored = apply_program(inp, [('recolor', {'color_map': color_map})])
205+
recolored = apply_program(inp, [("recolor", {"mapping": color_map})])
206206
if np.array_equal(recolored, out):
207-
programs.append([('recolor', {'color_map': color_map})])
207+
programs.append([("recolor", {"mapping": color_map})])
208208

209209
# Color swapping
210210
unique_colors = np.unique(inp)
@@ -324,9 +324,9 @@ def detect_multi_step_operations(inp: Array, out: Array) -> List[List[Tuple[str,
324324
intermediate = apply_program(inp, [op])
325325
color_map = infer_color_mapping(intermediate, out)
326326
if color_map:
327-
final = apply_program(intermediate, [('recolor', {'color_map': color_map})])
327+
final = apply_program(intermediate, [("recolor", {"mapping": color_map})])
328328
if np.array_equal(final, out):
329-
programs.append([op, ('recolor', {'color_map': color_map})])
329+
programs.append([op, ("recolor", {"mapping": color_map})])
330330

331331
return programs
332332

arc_solver/neural/episodic.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "Episode":
8888
(np.array(inp, dtype=int), np.array(out, dtype=int))
8989
for inp, out in data.get("train_pairs", [])
9090
]
91+
programs: List[Program] = []
92+
for program in data.get("programs", []):
93+
prog_ops: Program = []
94+
for op, params in program:
95+
if op == "recolor":
96+
mapping = params.get("mapping") or params.get("color_map") or {}
97+
params = {"mapping": {int(k): int(v) for k, v in mapping.items()}}
98+
prog_ops.append((op, params))
99+
programs.append(prog_ops)
100+
91101
episode = cls(
92102
task_signature=data["task_signature"],
93-
programs=[
94-
[(op, params) for op, params in program]
95-
for program in data.get("programs", [])
96-
],
103+
programs=programs,
97104
task_id=data.get("task_id", ""),
98105
train_pairs=train_pairs,
99106
success_count=data.get("success_count", 1),

tests/test_beam_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_beam_search_rotation_property(grid, k):
2929

3030
def test_beam_search_no_solution():
3131
a = to_array([[0]])
32-
b = to_array([[1]])
32+
b = to_array([[1, 1]])
3333
progs, _ = beam_search([(a, b)], beam_width=3, depth=1)
3434
assert progs == []
3535

tests/test_recolor_fix.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import json
2+
from typing import Dict
3+
import sys
4+
from pathlib import Path
5+
6+
import numpy as np
7+
from hypothesis import given, strategies as st
8+
9+
sys.path.append(str(Path(__file__).parent.parent))
10+
11+
from arc_solver.grid import to_array
12+
from arc_solver.dsl import apply_program
13+
from arc_solver.heuristics_complete import detect_color_patterns
14+
from arc_solver.neural.episodic import Episode
15+
16+
17+
def test_detect_color_patterns_recolor_program() -> None:
    """A one-colour change is reported as a recolor program keyed by 'mapping'."""
    source_grid = to_array([[1, 0], [0, 0]])
    target_grid = to_array([[2, 0], [0, 0]])

    found = detect_color_patterns(source_grid, target_grid)

    expected_program = [("recolor", {"mapping": {1: 2}})]
    assert expected_program in found

    # The first reported program must reproduce the target when replayed.
    replayed = apply_program(source_grid, found[0])
    assert np.array_equal(replayed, target_grid)
24+
25+
26+
@given(
    st.dictionaries(
        st.integers(min_value=1, max_value=9),
        st.integers(min_value=0, max_value=9),
        min_size=1,
        max_size=3,
    ).filter(lambda m: all(k != v for k, v in m.items()))
)
def test_episode_recolor_roundtrip(mapping: Dict[int, int]) -> None:
    """Recolor mappings keep integer keys and values across a JSON round trip."""
    # Build a 1x1 task from the first mapping entry.
    first_src, first_dst = next(iter(mapping.items()))
    grid_in = to_array([[first_src]])
    grid_out = to_array([[first_dst]])

    original = Episode(
        task_signature="sig",
        programs=[[('recolor', {'mapping': mapping})]],
        train_pairs=[(grid_in, grid_out)],
    )

    # Serialise through JSON text and load the episode back.
    serialised = json.dumps(original.to_dict())
    restored = Episode.from_dict(json.loads(serialised))

    restored_program = restored.programs[0]
    _, restored_params = restored_program[0]
    expected_mapping = {int(k): int(v) for k, v in mapping.items()}
    assert restored_params['mapping'] == expected_mapping
    assert np.array_equal(apply_program(grid_in, restored_program), grid_out)

0 commit comments

Comments
 (0)