Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,13 @@ class MetaCognition:
Notes: Resource limits and diversity enforced
```

```
[X] Step 4.3 UPDATE - Recolor parameter mismatch fixed preventing training failures
Date: 2025-09-12
Test Result: pytest tests/test_recolor_fix.py passed
Notes: Standardised 'mapping' parameter across heuristics; episodic loader normalises keys
```

---

## Step 4.4: Final Validation
Expand Down
13 changes: 12 additions & 1 deletion arc_solver/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,20 @@ def op_pad(a: Array, out_h: int, out_w: int) -> Array:
_sem_cache: Dict[Tuple[bytes, str, Tuple[Tuple[str, Any], ...]], Array] = {}


def _norm_params(params: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]:
"""Normalise parameters to a hashable tuple."""
items: List[Tuple[str, Any]] = []
for k, v in sorted(params.items()):
if isinstance(v, dict):
items.append((k, tuple(sorted(v.items()))))
else:
items.append((k, v))
return tuple(items)


def apply_op(a: Array, name: str, params: Dict[str, Any]) -> Array:
"""Apply a primitive operation with semantic caching."""
key = (a.tobytes(), name, tuple(sorted(params.items())))
key = (a.tobytes(), name, _norm_params(params))
cached = _sem_cache.get(key)
if cached is not None:
return cached
Expand Down
13 changes: 9 additions & 4 deletions arc_solver/dsl_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ def resize(grid: Array, new_height: int, new_width: int, method: str = 'nearest'

# ================== COLOR OPERATIONS ==================

def recolor(grid: Array, color_map: Dict[int, int]) -> Array:
"""Recolor grid according to color mapping."""
def recolor(
grid: Array,
color_map: Dict[int, int] | None = None,
mapping: Dict[int, int] | None = None,
) -> Array:
"""Recolor grid according to a color mapping."""
mapping = mapping if mapping is not None else (color_map or {})
result = grid.copy()
for old_color, new_color in color_map.items():
for old_color, new_color in mapping.items():
result[grid == old_color] = new_color
return result

Expand Down Expand Up @@ -692,7 +697,7 @@ def get_operation_signatures() -> Dict[str, List[str]]:
'crop': ['top', 'bottom', 'left', 'right'],
'pad': ['top', 'bottom', 'left', 'right', 'fill_value'],
'resize': ['new_height', 'new_width', 'method'],
'recolor': ['color_map'],
'recolor': ['mapping'],
'recolor_by_position': ['position_map'],
'swap_colors': ['color1', 'color2'],
'dominant_color_recolor': ['target_color'],
Expand Down
2 changes: 1 addition & 1 deletion arc_solver/enhanced_search_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _generate_comprehensive_parameters(self, op_name: str,
for old_color in colors_in_input:
for new_color in range(1, 10):
if new_color != old_color:
param_combinations.append({'color_map': {old_color: new_color}})
param_combinations.append({"mapping": {old_color: new_color}})

elif op_name == 'crop':
for top in range(min(H, 3)):
Expand Down
8 changes: 4 additions & 4 deletions arc_solver/heuristics_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def detect_color_patterns(inp: Array, out: Array) -> List[List[Tuple[str, Dict[s
# Direct color mapping
color_map = infer_color_mapping(inp, out)
if color_map and len(color_map) > 0:
recolored = apply_program(inp, [('recolor', {'color_map': color_map})])
recolored = apply_program(inp, [("recolor", {"mapping": color_map})])
if np.array_equal(recolored, out):
programs.append([('recolor', {'color_map': color_map})])
programs.append([("recolor", {"mapping": color_map})])

# Color swapping
unique_colors = np.unique(inp)
Expand Down Expand Up @@ -324,9 +324,9 @@ def detect_multi_step_operations(inp: Array, out: Array) -> List[List[Tuple[str,
intermediate = apply_program(inp, [op])
color_map = infer_color_mapping(intermediate, out)
if color_map:
final = apply_program(intermediate, [('recolor', {'color_map': color_map})])
final = apply_program(intermediate, [("recolor", {"mapping": color_map})])
if np.array_equal(final, out):
programs.append([op, ('recolor', {'color_map': color_map})])
programs.append([op, ("recolor", {"mapping": color_map})])

return programs

Expand Down
15 changes: 11 additions & 4 deletions arc_solver/neural/episodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "Episode":
(np.array(inp, dtype=int), np.array(out, dtype=int))
for inp, out in data.get("train_pairs", [])
]
programs: List[Program] = []
for program in data.get("programs", []):
prog_ops: Program = []
for op, params in program:
if op == "recolor":
mapping = params.get("mapping") or params.get("color_map") or {}
params = {"mapping": {int(k): int(v) for k, v in mapping.items()}}
prog_ops.append((op, params))
programs.append(prog_ops)

episode = cls(
task_signature=data["task_signature"],
programs=[
[(op, params) for op, params in program]
for program in data.get("programs", [])
],
programs=programs,
task_id=data.get("task_id", ""),
train_pairs=train_pairs,
success_count=data.get("success_count", 1),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_beam_search_rotation_property(grid, k):

def test_beam_search_no_solution():
a = to_array([[0]])
b = to_array([[1]])
b = to_array([[1, 1]])
progs, _ = beam_search([(a, b)], beam_width=3, depth=1)
assert progs == []

Expand Down
40 changes: 40 additions & 0 deletions tests/test_recolor_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json
from typing import Dict
import sys
from pathlib import Path

import numpy as np
from hypothesis import given, strategies as st

sys.path.append(str(Path(__file__).parent.parent))

from arc_solver.grid import to_array
from arc_solver.dsl import apply_program
from arc_solver.heuristics_complete import detect_color_patterns
from arc_solver.neural.episodic import Episode


def test_detect_color_patterns_recolor_program() -> None:
"""Heuristic recolor programs use mapping parameter."""
inp = to_array([[1, 0], [0, 0]])
out = to_array([[2, 0], [0, 0]])
programs = detect_color_patterns(inp, out)
assert [("recolor", {"mapping": {1: 2}})] in programs
assert np.array_equal(apply_program(inp, programs[0]), out)


@given(st.dictionaries(st.integers(min_value=1, max_value=9),
st.integers(min_value=0, max_value=9),
min_size=1, max_size=3).filter(lambda m: all(k != v for k, v in m.items())))
def test_episode_recolor_roundtrip(mapping: Dict[int, int]) -> None:
"""Episode serialization preserves integer recolor mappings."""
src, dst = next(iter(mapping.items()))
inp = to_array([[src]])
out = to_array([[dst]])
episode = Episode(task_signature="sig", programs=[[('recolor', {'mapping': mapping})]],
train_pairs=[(inp, out)])
data = json.loads(json.dumps(episode.to_dict()))
loaded = Episode.from_dict(data)
prog = loaded.programs[0]
assert prog[0][1]['mapping'] == {int(k): int(v) for k, v in mapping.items()}
assert np.array_equal(apply_program(inp, prog), out)
Loading