diff --git a/arc_solver/dsl.py b/arc_solver/dsl.py index 7c3e80e..57ca245 100644 --- a/arc_solver/dsl.py +++ b/arc_solver/dsl.py @@ -190,8 +190,13 @@ def pad(a: Array, top: int, bottom: int, left: int, right: int, fill_value: int def apply_program(a: Array, program: List[Tuple[str, Dict[str, Any]]]) -> Array: """Apply a sequence of operations to the input grid.""" out = a - for name, params in program: - out = apply_op(out, name, params) + for idx, (name, params) in enumerate(program): + try: + out = apply_op(out, name, params) + except Exception as exc: + raise ValueError( + f"Failed to apply operation '{name}' at position {idx} with params {params}" + ) from exc return out diff --git a/arc_solver/neural/guidance.py b/arc_solver/neural/guidance.py index b1a37d8..9d33c41 100644 --- a/arc_solver/neural/guidance.py +++ b/arc_solver/neural/guidance.py @@ -144,12 +144,17 @@ class NeuralGuidance: def __init__(self, model_path: Optional[str] = None): self.heuristic_guidance = HeuristicGuidance() self.neural_model = None - + expected_dim = 17 + if model_path and os.path.exists(model_path): - self.load_model(model_path) + try: + self.load_model(model_path) + if getattr(self.neural_model, "input_dim", expected_dim) != expected_dim: + raise ValueError("model input dimension mismatch") + except Exception: + self.neural_model = SimpleClassifier(expected_dim) else: - # For now, create a dummy neural model - self.neural_model = SimpleClassifier(17) # 17 features + self.neural_model = SimpleClassifier(expected_dim) def predict_operations(self, train_pairs: List[Tuple[Array, Array]]) -> List[str]: """Predict which operations are likely relevant for the task.""" diff --git a/arc_solver/neural/sketches.py b/arc_solver/neural/sketches.py index f752b13..48c50cb 100644 --- a/arc_solver/neural/sketches.py +++ b/arc_solver/neural/sketches.py @@ -209,5 +209,26 @@ def generate_parameter_grid(operation: str, constraints: Dict[str, Any] = None) return [{'dy': dy, 'dx': dx} for dy in dy_range for dx in dx_range] elif operation == 'identity': return [{}] + elif operation == 'crop': + tops = range(3) if 'top' not in constraints else constraints['top'] + lefts = range(3) if 'left' not in constraints else constraints['left'] + heights = range(1, 4) if 'height' not in constraints else constraints['height'] + widths = range(1, 4) if 'width' not in constraints else constraints['width'] + return [ + {"top": t, "left": l, "height": h, "width": w} + for t in tops for l in lefts for h in heights for w in widths + ] + elif operation == 'pad': + out_h = range(5, 20) if 'out_h' not in constraints else constraints['out_h'] + out_w = range(5, 20) if 'out_w' not in constraints else constraints['out_w'] + return [{"out_h": h, "out_w": w} for h in out_h for w in out_w] + elif operation == 'recolor': + mappings = [] + colors = range(10) + for src in colors: + for dst in colors: + if src != dst: + mappings.append({"mapping": {src: dst}}) + return mappings else: - return [{}] # Default empty params for other operations + return [{}] # Default empty params for unknown operations diff --git a/arc_solver/search.py b/arc_solver/search.py index 80d62ba..23cbf73 100644 --- a/arc_solver/search.py +++ b/arc_solver/search.py @@ -32,12 +32,36 @@ def enumerate_programs(depth: int = 2) -> List[List[Tuple[str, Dict[str, int]]]] for dy in range(-2, 3) for dx in range(-2, 3) ], - "recolor": [], # Recolor handled separately via heuristics - "crop": [], - "pad": [], + "crop": [ + {"top": t, "left": l, "height": h, "width": w} + for t in range(3) + for l in range(3) + for h in range(1, 4) + for w in range(1, 4) + ], + "pad": [ + {"out_h": h, "out_w": w} + for h in range(5, 20) + for w in range(5, 20) + ], + "recolor": [ + {"mapping": {i: j}} + for i in range(10) + for j in range(10) + if i != j + ], "identity": [{}], } - base_ops = ["rotate", "flip", "transpose", "translate", "identity"] + base_ops = [ + "rotate", + "flip", + "transpose", + "translate", + "crop", + "pad", + "recolor", + "identity", + ] # Generate all length-1 programs if depth == 1: return [[(op, params)] for op in base_ops for params in param_space[op]] @@ -108,4 +132,4 @@ def predict_two( except Exception: outs.append(ti) attempts.append(outs) - return attempts \ No newline at end of file + return attempts