Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions arc_solver/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,13 @@ def pad(a: Array, top: int, bottom: int, left: int, right: int, fill_value: int
def apply_program(a: Array, program: List[Tuple[str, Dict[str, Any]]]) -> Array:
"""Apply a sequence of operations to the input grid."""
out = a
for name, params in program:
out = apply_op(out, name, params)
for idx, (name, params) in enumerate(program):
try:
out = apply_op(out, name, params)
except Exception as exc:
raise ValueError(
f"Failed to apply operation '{name}' at position {idx} with params {params}"
) from exc
return out


Expand Down
13 changes: 9 additions & 4 deletions arc_solver/neural/guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,17 @@ class NeuralGuidance:
def __init__(self, model_path: Optional[str] = None):
self.heuristic_guidance = HeuristicGuidance()
self.neural_model = None

expected_dim = 17

if model_path and os.path.exists(model_path):
self.load_model(model_path)
try:
self.load_model(model_path)
if getattr(self.neural_model, "input_dim", expected_dim) != expected_dim:
raise ValueError("model input dimension mismatch")
except Exception:
self.neural_model = SimpleClassifier(expected_dim)
else:
# For now, create a dummy neural model
self.neural_model = SimpleClassifier(17) # 17 features
self.neural_model = SimpleClassifier(expected_dim)

def predict_operations(self, train_pairs: List[Tuple[Array, Array]]) -> List[str]:
"""Predict which operations are likely relevant for the task."""
Expand Down
23 changes: 22 additions & 1 deletion arc_solver/neural/sketches.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,26 @@ def generate_parameter_grid(operation: str, constraints: Dict[str, Any] = None)
return [{'dy': dy, 'dx': dx} for dy in dy_range for dx in dx_range]
elif operation == 'identity':
return [{}]
elif operation == 'crop':
tops = range(3) if 'top' not in constraints else constraints['top']
lefts = range(3) if 'left' not in constraints else constraints['left']
heights = range(1, 4) if 'height' not in constraints else constraints['height']
widths = range(1, 4) if 'width' not in constraints else constraints['width']
return [
{"top": t, "left": l, "height": h, "width": w}
for t in tops for l in lefts for h in heights for w in widths
]
elif operation == 'pad':
out_h = range(5, 20) if 'out_h' not in constraints else constraints['out_h']
out_w = range(5, 20) if 'out_w' not in constraints else constraints['out_w']
return [{"out_h": h, "out_w": w} for h in out_h for w in out_w]
elif operation == 'recolor':
mappings = []
colors = range(10)
for src in colors:
for dst in colors:
if src != dst:
mappings.append({"mapping": {src: dst}})
return mappings
else:
return [{}] # Default empty params for other operations
return [{}] # Default empty params for unknown operations
34 changes: 29 additions & 5 deletions arc_solver/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,36 @@ def enumerate_programs(depth: int = 2) -> List[List[Tuple[str, Dict[str, int]]]]
for dy in range(-2, 3)
for dx in range(-2, 3)
],
"recolor": [], # Recolor handled separately via heuristics
"crop": [],
"pad": [],
"crop": [
{"top": t, "left": l, "height": h, "width": w}
for t in range(3)
for l in range(3)
for h in range(1, 4)
for w in range(1, 4)
],
"pad": [
{"out_h": h, "out_w": w}
for h in range(5, 20)
for w in range(5, 20)
],
"recolor": [
{"mapping": {i: j}}
for i in range(10)
for j in range(10)
if i != j
],
"identity": [{}],
}
base_ops = ["rotate", "flip", "transpose", "translate", "identity"]
base_ops = [
"rotate",
"flip",
"transpose",
"translate",
"crop",
"pad",
"recolor",
"identity",
]
# Generate all length-1 programs
if depth == 1:
return [[(op, params)] for op in base_ops for params in param_space[op]]
Expand Down Expand Up @@ -108,4 +132,4 @@ def predict_two(
except Exception:
outs.append(ti)
attempts.append(outs)
return attempts
return attempts
Loading