|
| 1 | +import json |
| 2 | +from typing import Dict |
| 3 | +import sys |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from hypothesis import given, strategies as st |
| 8 | + |
| 9 | +sys.path.append(str(Path(__file__).parent.parent)) |
| 10 | + |
| 11 | +from arc_solver.grid import to_array |
| 12 | +from arc_solver.dsl import apply_program |
| 13 | +from arc_solver.heuristics_complete import detect_color_patterns |
| 14 | +from arc_solver.neural.episodic import Episode |
| 15 | + |
| 16 | + |
| 17 | +def test_detect_color_patterns_recolor_program() -> None: |
| 18 | + """Heuristic recolor programs use mapping parameter.""" |
| 19 | + inp = to_array([[1, 0], [0, 0]]) |
| 20 | + out = to_array([[2, 0], [0, 0]]) |
| 21 | + programs = detect_color_patterns(inp, out) |
| 22 | + assert [("recolor", {"mapping": {1: 2}})] in programs |
| 23 | + assert np.array_equal(apply_program(inp, programs[0]), out) |
| 24 | + |
| 25 | + |
| 26 | +@given(st.dictionaries(st.integers(min_value=1, max_value=9), |
| 27 | + st.integers(min_value=0, max_value=9), |
| 28 | + min_size=1, max_size=3).filter(lambda m: all(k != v for k, v in m.items()))) |
| 29 | +def test_episode_recolor_roundtrip(mapping: Dict[int, int]) -> None: |
| 30 | + """Episode serialization preserves integer recolor mappings.""" |
| 31 | + src, dst = next(iter(mapping.items())) |
| 32 | + inp = to_array([[src]]) |
| 33 | + out = to_array([[dst]]) |
| 34 | + episode = Episode(task_signature="sig", programs=[[('recolor', {'mapping': mapping})]], |
| 35 | + train_pairs=[(inp, out)]) |
| 36 | + data = json.loads(json.dumps(episode.to_dict())) |
| 37 | + loaded = Episode.from_dict(data) |
| 38 | + prog = loaded.programs[0] |
| 39 | + assert prog[0][1]['mapping'] == {int(k): int(v) for k, v in mapping.items()} |
| 40 | + assert np.array_equal(apply_program(inp, prog), out) |
0 commit comments