Skip to content

Commit 4df2dc8

Browse files
authored
Merge pull request #6 from tylerbessire/codex/fix-dsl-operation-signatures-and-integration-bugs
Fix DSL parameter grids and improve solver robustness
2 parents 5cd647f + e7549ec commit 4df2dc8

File tree

4 files changed

+67
-12
lines changed

4 files changed

+67
-12
lines changed

arc_solver/dsl.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ def pad(a: Array, top: int, bottom: int, left: int, right: int, fill_value: int
190190
def apply_program(a: Array, program: List[Tuple[str, Dict[str, Any]]]) -> Array:
191191
"""Apply a sequence of operations to the input grid."""
192192
out = a
193-
for name, params in program:
194-
out = apply_op(out, name, params)
193+
for idx, (name, params) in enumerate(program):
194+
try:
195+
out = apply_op(out, name, params)
196+
except Exception as exc:
197+
raise ValueError(
198+
f"Failed to apply operation '{name}' at position {idx} with params {params}"
199+
) from exc
195200
return out
196201

197202

arc_solver/neural/guidance.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,17 @@ class NeuralGuidance:
144144
def __init__(self, model_path: Optional[str] = None):
145145
self.heuristic_guidance = HeuristicGuidance()
146146
self.neural_model = None
147-
147+
expected_dim = 17
148+
148149
if model_path and os.path.exists(model_path):
149-
self.load_model(model_path)
150+
try:
151+
self.load_model(model_path)
152+
if getattr(self.neural_model, "input_dim", expected_dim) != expected_dim:
153+
raise ValueError("model input dimension mismatch")
154+
except Exception:
155+
self.neural_model = SimpleClassifier(expected_dim)
150156
else:
151-
# For now, create a dummy neural model
152-
self.neural_model = SimpleClassifier(17) # 17 features
157+
self.neural_model = SimpleClassifier(expected_dim)
153158

154159
def predict_operations(self, train_pairs: List[Tuple[Array, Array]]) -> List[str]:
155160
"""Predict which operations are likely relevant for the task."""

arc_solver/neural/sketches.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,5 +209,26 @@ def generate_parameter_grid(operation: str, constraints: Dict[str, Any] = None)
209209
return [{'dy': dy, 'dx': dx} for dy in dy_range for dx in dx_range]
210210
elif operation == 'identity':
211211
return [{}]
212+
elif operation == 'crop':
213+
tops = range(3) if 'top' not in constraints else constraints['top']
214+
lefts = range(3) if 'left' not in constraints else constraints['left']
215+
heights = range(1, 4) if 'height' not in constraints else constraints['height']
216+
widths = range(1, 4) if 'width' not in constraints else constraints['width']
217+
return [
218+
{"top": t, "left": l, "height": h, "width": w}
219+
for t in tops for l in lefts for h in heights for w in widths
220+
]
221+
elif operation == 'pad':
222+
out_h = range(5, 20) if 'out_h' not in constraints else constraints['out_h']
223+
out_w = range(5, 20) if 'out_w' not in constraints else constraints['out_w']
224+
return [{"out_h": h, "out_w": w} for h in out_h for w in out_w]
225+
elif operation == 'recolor':
226+
mappings = []
227+
colors = range(10)
228+
for src in colors:
229+
for dst in colors:
230+
if src != dst:
231+
mappings.append({"mapping": {src: dst}})
232+
return mappings
212233
else:
213-
return [{}] # Default empty params for other operations
234+
return [{}] # Default empty params for unknown operations

arc_solver/search.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,36 @@ def enumerate_programs(depth: int = 2) -> List[List[Tuple[str, Dict[str, int]]]]
3232
for dy in range(-2, 3)
3333
for dx in range(-2, 3)
3434
],
35-
"recolor": [], # Recolor handled separately via heuristics
36-
"crop": [],
37-
"pad": [],
35+
"crop": [
36+
{"top": t, "left": l, "height": h, "width": w}
37+
for t in range(3)
38+
for l in range(3)
39+
for h in range(1, 4)
40+
for w in range(1, 4)
41+
],
42+
"pad": [
43+
{"out_h": h, "out_w": w}
44+
for h in range(5, 20)
45+
for w in range(5, 20)
46+
],
47+
"recolor": [
48+
{"mapping": {i: j}}
49+
for i in range(10)
50+
for j in range(10)
51+
if i != j
52+
],
3853
"identity": [{}],
3954
}
40-
base_ops = ["rotate", "flip", "transpose", "translate", "identity"]
55+
base_ops = [
56+
"rotate",
57+
"flip",
58+
"transpose",
59+
"translate",
60+
"crop",
61+
"pad",
62+
"recolor",
63+
"identity",
64+
]
4165
# Generate all length-1 programs
4266
if depth == 1:
4367
return [[(op, params)] for op in base_ops for params in param_space[op]]
@@ -108,4 +132,4 @@ def predict_two(
108132
except Exception:
109133
outs.append(ti)
110134
attempts.append(outs)
111-
return attempts
135+
return attempts

0 commit comments

Comments
 (0)