Skip to content

Commit ea0ba32

Browse files
authored
Merge pull request #5 from tylerbessire/codex/update-arc-solver-files-for-sota
Add ARC solver evaluation pipeline and improve utilities
2 parents 1a232bb + 8c98177 commit ea0ba32

File tree

6 files changed

+233
-12
lines changed

6 files changed

+233
-12
lines changed

arc_solver/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
"""ARC Solver Package
1+
"""ARC Solver Package.
22
3-
Enhanced ARC solver with neural guidance, episodic retrieval, and test-time training.
3+
This package exposes the high-level :class:`ARCSolver` alongside common
4+
utilities for interacting with ARC datasets. The solver integrates neural
5+
guidance, episodic retrieval and test-time training into a cohesive system.
46
"""
7+
8+
from .solver import ARCSolver
9+
from .io_utils import load_rerun_json, save_submission
10+
from .grid import Array
11+
12+
__all__ = ["ARCSolver", "load_rerun_json", "save_submission", "Array"]

arc_solver/grid.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@
1616
# Type alias for clarity. ARC grids are small 2D arrays of integers.
1717
Array = np.ndarray
1818

19+
__all__ = [
20+
"Array",
21+
"to_array",
22+
"to_list",
23+
"same_shape",
24+
"rotate90",
25+
"flip",
26+
"transpose",
27+
"pad_to",
28+
"crop",
29+
"translate",
30+
"color_map",
31+
"histogram",
32+
"eq",
33+
"bg_color",
34+
]
35+
1936

2037
def to_array(grid: List[List[int]]) -> Array:
2138
"""Convert a nested Python list into a numpy array of dtype int16."""

arc_solver/heuristics.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,25 @@
1010

1111
from __future__ import annotations
1212

13+
import logging
1314
import numpy as np
1415
from typing import List, Dict, Tuple, Optional
1516

1617
from .grid import Array, eq, rotate90, flip, histogram, bg_color, to_array
1718
from .dsl import apply_program
1819

20+
logger = logging.getLogger(__name__)
21+
22+
__all__ = [
23+
"infer_color_mapping",
24+
"match_rotation_reflection",
25+
"infer_translation",
26+
"consistent_program_single_step",
27+
"guess_output_shape",
28+
"score_candidate",
29+
"diversify_programs",
30+
]
31+
1932

2033
def infer_color_mapping(inp: Array, out: Array) -> Optional[Dict[int, int]]:
2134
"""Try to infer a one-to-one color mapping between input and output grids.
@@ -123,8 +136,8 @@ def score_candidate(program: List[Tuple[str, Dict[str, int]]], train_pairs: List
123136
try:
124137
out = apply_program(a, program)
125138
good += int(eq(out, b))
126-
except Exception:
127-
pass
139+
except Exception as exc:
140+
logger.warning("Program execution failed on training pair: %s", exc)
128141
return good / len(train_pairs)
129142

130143

arc_solver/io_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
"/kaggle/input/arc-agi-2/arc-agi_test_challenges.json",
2323
]
2424

25+
__all__ = ["load_rerun_json", "save_submission"]
26+
2527

2628
def load_rerun_json() -> Dict[str, Any]:
2729
"""Load the JSON file containing all test tasks for the competition.

arc_solver/ttt.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99

1010
from __future__ import annotations
1111

12+
import logging
1213
import numpy as np
1314
from typing import List, Tuple, Dict, Any, Optional
1415
from copy import deepcopy
1516

1617
from .grid import Array, eq
1718
from .dsl import apply_program
1819

20+
logger = logging.getLogger(__name__)
21+
22+
__all__ = ["AdaptiveScorer", "TestTimeTrainer", "DataAugmentation"]
23+
1924

2025
class AdaptiveScorer:
2126
"""Adaptive scoring function that can be fine-tuned at test time."""
@@ -51,8 +56,11 @@ def extract_program_features(self, program: List[Tuple[str, Dict[str, Any]]],
5156
# Compute partial match (e.g., correct shape)
5257
if pred_out.shape == target_out.shape:
5358
partial_matches += 1
54-
except Exception:
55-
pass
59+
except Exception as exc:
60+
logger.warning(
61+
"Program execution failed during feature extraction: %s", exc
62+
)
63+
continue
5664

5765
features[2] = exact_matches / len(train_pairs)
5866
features[3] = partial_matches / len(train_pairs)
@@ -167,8 +175,11 @@ def _evaluate_program(self, program: List[Tuple[str, Dict[str, Any]]],
167175
pred_out = apply_program(inp, program)
168176
if eq(pred_out, target_out):
169177
successes += 1
170-
except Exception:
171-
pass
178+
except Exception as exc:
179+
logger.warning(
180+
"Program evaluation failed during adaptation: %s", exc
181+
)
182+
continue
172183

173184
return successes / len(train_pairs) if train_pairs else 0.0
174185

@@ -216,8 +227,11 @@ def augment_training_pairs(train_pairs: List[Tuple[Array, Array]],
216227
aug_inp = np.rot90(inp, k)
217228
aug_out = np.rot90(out, k)
218229
augmented.append((aug_inp, aug_out))
219-
except Exception:
220-
pass
230+
except Exception as exc:
231+
logger.warning(
232+
"Rotation augmentation failed (k=%s): %s", k, exc
233+
)
234+
continue
221235

222236
# Try reflections
223237
for axis in [0, 1]:
@@ -227,8 +241,11 @@ def augment_training_pairs(train_pairs: List[Tuple[Array, Array]],
227241
aug_inp = np.flip(inp, axis=axis)
228242
aug_out = np.flip(out, axis=axis)
229243
augmented.append((aug_inp, aug_out))
230-
except Exception:
231-
pass
244+
except Exception as exc:
245+
logger.warning(
246+
"Reflection augmentation failed (axis=%s): %s", axis, exc
247+
)
248+
continue
232249

233250
return augmented[:max_augmentations]
234251

tools/colab_eval.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Train and evaluate the ARC solver in Kaggle/Colab environments.
2+
3+
This script provides a minimal end-to-end pipeline for training the neural
4+
guidance classifier and producing Kaggle-compatible submission files. When
5+
ground-truth solutions are provided, it also reports accuracy and per-task
6+
differences between predictions and targets.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import argparse
12+
import json
13+
import sys
14+
from pathlib import Path
15+
from typing import Any, Dict, List, Optional, Tuple
16+
17+
import numpy as np
18+
19+
# Ensure repository root is on the path so arc_solver can be imported when this
20+
# script runs in Kaggle/Colab notebooks.
21+
sys.path.append(str(Path(__file__).parent.parent))
22+
23+
from arc_solver.solver import ARCSolver
24+
from arc_solver.grid import to_array, eq
25+
from arc_solver.io_utils import save_submission
26+
from train_guidance import (
27+
load_training_data,
28+
extract_training_features_and_labels,
29+
train_classifier,
30+
save_classifier,
31+
)
32+
33+
34+
def train_guidance_model(
35+
train_json: str,
36+
solutions_json: Optional[str],
37+
model_path: str,
38+
epochs: int = 100,
39+
) -> str:
40+
"""Train the neural guidance classifier.
41+
42+
Args:
43+
train_json: Path to the ARC training challenges JSON.
44+
solutions_json: Optional path to training solutions for supervised labels.
45+
model_path: Where to persist the trained classifier.
46+
epochs: Number of training epochs.
47+
48+
Returns:
49+
Path to the saved model.
50+
"""
51+
tasks = load_training_data(train_json, solutions_json)
52+
features, labels = extract_training_features_and_labels(tasks)
53+
classifier = train_classifier(features, labels, epochs)
54+
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
55+
save_classifier(classifier, model_path)
56+
return model_path
57+
58+
59+
def evaluate_solver(
60+
test_json: str,
61+
model_path: str,
62+
solutions_json: Optional[str],
63+
out_path: str,
64+
) -> Tuple[float, Dict[str, List[List[List[int]]]]]:
65+
"""Run the solver on evaluation tasks and optionally score against solutions.
66+
67+
Args:
68+
test_json: Path to evaluation challenges JSON.
69+
model_path: Path to trained guidance model.
70+
solutions_json: Optional path to ground-truth solutions for scoring.
71+
out_path: Where to write the Kaggle submission JSON.
72+
73+
Returns:
74+
Tuple of overall accuracy (0-1) and a mapping of task ids to diff grids.
75+
"""
76+
solver = ARCSolver(use_enhancements=True, guidance_model_path=model_path)
77+
78+
with open(test_json, "r") as f:
79+
test_tasks: Dict[str, Any] = json.load(f)
80+
81+
solutions: Dict[str, Any] = {}
82+
if solutions_json and Path(solutions_json).exists():
83+
with open(solutions_json, "r") as f:
84+
solutions = json.load(f)
85+
86+
predictions: Dict[str, Dict[str, List[List[List[int]]]]] = {}
87+
diffs: Dict[str, List[List[List[int]]]] = {}
88+
correct = 0
89+
total = 0
90+
91+
for task_id, task in test_tasks.items():
92+
result = solver.solve_task(task)
93+
predictions[task_id] = result
94+
95+
if task_id in solutions:
96+
target_grids = [pair["output"] for pair in solutions[task_id]["test"]]
97+
pred_grids = result["attempt_1"]
98+
diff_grids: List[List[List[int]]] = []
99+
all_match = True
100+
101+
for pred, target in zip(pred_grids, target_grids):
102+
pa = to_array(pred)
103+
ta = to_array(target)
104+
all_match &= eq(pa, ta)
105+
diff_grids.append((pa != ta).astype(int).tolist())
106+
107+
if all_match:
108+
correct += 1
109+
diffs[task_id] = diff_grids
110+
total += 1
111+
112+
save_submission(predictions, out_path)
113+
accuracy = correct / total if total else 0.0
114+
return accuracy, diffs
115+
116+
117+
def main() -> None:
118+
parser = argparse.ArgumentParser(description="Train and evaluate ARC solver")
119+
parser.add_argument("--train-json", help="Path to training challenges JSON")
120+
parser.add_argument(
121+
"--train-solutions", help="Path to training solutions JSON", default=None
122+
)
123+
parser.add_argument(
124+
"--model-path",
125+
default="neural_guidance_model.json",
126+
help="Where to save or load the guidance model",
127+
)
128+
parser.add_argument("--test-json", required=True, help="Path to evaluation JSON")
129+
parser.add_argument(
130+
"--test-solutions",
131+
help="Path to evaluation solutions JSON for scoring",
132+
default=None,
133+
)
134+
parser.add_argument(
135+
"--out", default="submission.json", help="Output path for submission JSON"
136+
)
137+
parser.add_argument("--epochs", type=int, default=100, help="Training epochs")
138+
139+
args = parser.parse_args()
140+
141+
if args.train_json:
142+
train_guidance_model(
143+
args.train_json, args.train_solutions, args.model_path, args.epochs
144+
)
145+
146+
accuracy, diffs = evaluate_solver(
147+
args.test_json, args.model_path, args.test_solutions, args.out
148+
)
149+
150+
if args.test_solutions:
151+
print(f"Accuracy: {accuracy * 100:.2f}%")
152+
for task_id, diff in diffs.items():
153+
if any(np.any(np.array(d)) for d in diff):
154+
status = "incorrect"
155+
else:
156+
status = "correct"
157+
print(f"Task {task_id}: {status}")
158+
159+
print(f"Submission file written to {args.out}")
160+
161+
162+
if __name__ == "__main__":
163+
main()
164+

0 commit comments

Comments
 (0)