Skip to content

Commit 0521227

Browse files
authored
Merge pull request #13 from tylerbessire/codex/outline-arc-training-and-evaluation-process-ngi3f3
fix: standardize translate fill parameter
2 parents cb3a2b9 + c18f264 commit 0521227

File tree

5 files changed

+63
-13
lines changed

5 files changed

+63
-13
lines changed

AGENTS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,13 @@ class MetaCognition:
484484
Date: 2025-09-12
485485
Test Result: pytest tests/test_recolor_fix.py passed
486486
Notes: Standardised 'mapping' parameter across heuristics; episodic loader normalises keys
487+
488+
[X] Step 4.3 UPDATE2 - Translate parameter mismatch fixed preventing training warnings
489+
Date: 2025-09-13
490+
Test Result: pytest tests/test_translate_fix.py passed; python tools/train_guidance_on_arc.py --epochs 1
491+
Notes: Canonicalised 'fill' parameter for translate; legacy 'fill_value' still accepted
492+
493+
487494
```
488495

489496
---

arc_solver/dsl.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def op_transpose(a: Array) -> Array:
6060
return transpose_grid(a)
6161

6262

63-
def op_translate(a: Array, dy: int, dx: int, fill: Optional[int] = None) -> Array:
63+
def op_translate(a: Array, dy: int, dx: int, fill: Optional[int] = None, *, fill_value: Optional[int] = None) -> Array:
6464
"""Translate the grid by ``(dy, dx)`` filling uncovered cells.
6565
6666
Parameters
@@ -69,11 +69,13 @@ def op_translate(a: Array, dy: int, dx: int, fill: Optional[int] = None) -> Arra
6969
Input grid.
7070
dy, dx:
7171
Translation offsets. Positive values move content down/right.
72-
fill:
73-
Optional fill value for uncovered cells. If ``None`` the background
74-
colour of ``a`` is used.
72+
fill, fill_value:
73+
Optional fill value for uncovered cells. ``fill_value`` is an alias for
74+
backward compatibility. When both are ``None`` the background colour of
75+
``a`` is used.
7576
"""
76-
fill_val = 0 if fill is None else fill
77+
chosen = fill if fill is not None else fill_value
78+
fill_val = 0 if chosen is None else chosen
7779
return translate_grid(a, dy, dx, fill=fill_val)
7880

7981

@@ -114,6 +116,20 @@ def op_pad(a: Array, out_h: int, out_w: int) -> Array:
114116
_sem_cache: Dict[Tuple[bytes, str, Tuple[Tuple[str, Any], ...]], Array] = {}
115117

116118

119+
def _canonical_params(name: str, params: Dict[str, Any]) -> Dict[str, Any]:
120+
"""Return a copy of ``params`` with legacy aliases normalised."""
121+
if name == "recolor" and "mapping" not in params and "color_map" in params:
122+
new_params = dict(params)
123+
new_params["mapping"] = new_params.pop("color_map")
124+
return new_params
125+
if name == "translate" and "fill" not in params and "fill_value" in params:
126+
new_params = dict(params)
127+
new_params["fill"] = new_params.pop("fill_value")
128+
return new_params
129+
return params
130+
131+
132+
117133
def _norm_params(params: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]:
118134
"""Normalise parameters to a hashable tuple."""
119135
items: List[Tuple[str, Any]] = []
@@ -127,6 +143,7 @@ def _norm_params(params: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]:
127143

128144
def apply_op(a: Array, name: str, params: Dict[str, Any]) -> Array:
129145
"""Apply a primitive operation with semantic caching."""
146+
params = _canonical_params(name, params)
130147
key = (a.tobytes(), name, _norm_params(params))
131148
cached = _sem_cache.get(key)
132149
if cached is not None:

arc_solver/dsl_complete.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ def transpose(grid: Array) -> Array:
3939
return grid.T
4040

4141

42-
def translate(grid: Array, dx: int = 0, dy: int = 0, fill_value: int = 0) -> Array:
43-
"""Translate grid by (dx, dy) with wraparound or filling."""
42+
def translate(grid: Array, dy: int = 0, dx: int = 0, fill: int = 0, *, fill_value: int | None = None) -> Array:
43+
"""Translate grid by (dy, dx) with wraparound or filling.
44+
45+
``fill_value`` is accepted for backward compatibility.
46+
"""
47+
if fill_value is not None and fill == 0:
48+
fill = fill_value
4449
H, W = grid.shape
45-
result = np.full_like(grid, fill_value)
50+
result = np.full_like(grid, fill)
4651

4752
# Source bounds
4853
src_y_start = max(0, -dy)
4954
src_y_end = min(H, H - dy)
50-
src_x_start = max(0, -dx)
55+
src_x_start = max(0, -dx)
5156
src_x_end = min(W, W - dx)
5257

5358
# Destination bounds
@@ -693,7 +698,7 @@ def get_operation_signatures() -> Dict[str, List[str]]:
693698
'rotate': ['k'],
694699
'flip': ['axis'],
695700
'transpose': [],
696-
'translate': ['dx', 'dy', 'fill_value'],
701+
'translate': ['dy', 'dx', 'fill'],
697702
'crop': ['top', 'bottom', 'left', 'right'],
698703
'pad': ['top', 'bottom', 'left', 'right', 'fill_value'],
699704
'resize': ['new_height', 'new_width', 'method'],

arc_solver/heuristics_complete.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def detect_basic_transformations(inp: Array, out: Array) -> List[List[Tuple[str,
104104

105105
translated = translate_safe(inp, dx, dy)
106106
if np.array_equal(translated, out):
107-
programs.append([('translate', {'dx': dx, 'dy': dy, 'fill_value': 0})])
107+
programs.append([('translate', {'dy': dy, 'dx': dx, 'fill': 0})])
108108

109109
return programs
110110

@@ -497,8 +497,8 @@ def infer_color_mapping(inp: Array, out: Array) -> Optional[Dict[int, int]]:
497497

498498
for i in range(inp.shape[0]):
499499
for j in range(inp.shape[1]):
500-
inp_color = inp[i, j]
501-
out_color = out[i, j]
500+
inp_color = int(inp[i, j])
501+
out_color = int(out[i, j])
502502

503503
if inp_color in color_map:
504504
if color_map[inp_color] != out_color:

tests/test_translate_fix.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from pathlib import Path
2+
import sys
3+
4+
import numpy as np
5+
6+
sys.path.append(str(Path(__file__).resolve().parents[1]))
7+
from arc_solver.dsl import apply_program
8+
from arc_solver.heuristics_complete import detect_basic_transformations
9+
10+
11+
def test_translate_fill_value_alias_and_detection():
12+
inp = np.array([[1, 0], [0, 0]])
13+
expected = np.array([[0, 0], [0, 1]])
14+
15+
# legacy alias still works
16+
legacy_prog = [("translate", {"dy": 1, "dx": 1, "fill_value": 0})]
17+
assert np.array_equal(apply_program(inp, legacy_prog), expected)
18+
19+
# heuristics now emit canonical parameter name
20+
detected = detect_basic_transformations(inp, expected)
21+
assert [("translate", {"dy": 1, "dx": 1, "fill": 0})] in detected

0 commit comments

Comments
 (0)