"""Hypothesis generation and testing for fluid intelligence in ARC tasks.

This module implements a lightweight hypothesis generation framework. It is
not meant to be an exhaustive reasoning system but rather a scaffold that can
be extended in later phases. The engine analyses training pairs and proposes
plausible transformations together with simple confidence estimates.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .grid import Array

@dataclass
class Hypothesis:
    """Represents a hypothesis about task transformation."""

    description: str
    transformation_type: str  # "rotation", "color_swap", "pattern_fill", etc.
    confidence: float
    evidence: List[Dict[str, Any]]
    program_sketch: Optional[List[Tuple[str, Dict[str, Any]]]] = None

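# Example (hypothetical values): a fully confident 90-degree rotation
# hypothesis derived from two training pairs might look like
#   Hypothesis(
#       description="Rotate input by 90 degrees",
#       transformation_type="rotation",
#       confidence=1.0,
#       evidence=[{"pair": 0, "rotation": 90, "match": True},
#                 {"pair": 1, "rotation": 90, "match": True}],
#       program_sketch=[("rotate", {"k": 1})],
#   )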

class HypothesisEngine:
    """Generates and tests hypotheses about ARC task transformations."""

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate_hypotheses(self, train_pairs: List[Tuple[Array, Array]]) -> List[Hypothesis]:
        """Generate multiple competing hypotheses about the task transformation."""
        if not train_pairs:
            # The generators below divide by len(train_pairs); bail out early.
            return []
        hypotheses: List[Hypothesis] = []

        # 1. Geometric transformation hypotheses
        hypotheses.extend(self._generate_geometric_hypotheses(train_pairs))

        # 2. Color transformation hypotheses
        hypotheses.extend(self._generate_color_hypotheses(train_pairs))

        # 3. Pattern completion hypotheses
        hypotheses.extend(self._generate_pattern_hypotheses(train_pairs))

        # 4. Object manipulation hypotheses
        hypotheses.extend(self._generate_object_hypotheses(train_pairs))

        return sorted(hypotheses, key=lambda h: h.confidence, reverse=True)

    def test_hypothesis(self, hypothesis: Hypothesis, train_pairs: List[Tuple[Array, Array]]) -> float:
        """Test hypothesis validity against training data.

        Returns a confidence score between 0 and 1 based on how many training
        pairs are perfectly explained by the hypothesis.
        """
        if not train_pairs:
            return 0.0
        matches = 0
        for inp, out in train_pairs:
            pred = self.apply(hypothesis, inp)
            if pred is not None and pred.shape == out.shape and np.array_equal(pred, out):
                matches += 1
        return matches / len(train_pairs)

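    # Worked example for ``test_hypothesis`` above (hypothetical numbers): if a
    # hypothesis reproduces 2 of 3 training outputs exactly, the returned score
    # is 2 / 3 ~= 0.67; a hypothesis whose ``apply`` yields ``None`` or a wrong
    # grid on every pair scores 0.0.
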
    def refine_hypothesis(self, hypothesis: Hypothesis, feedback: Dict[str, Any]) -> Hypothesis:
        """Refine a hypothesis based on test results.

        Only a very small refinement mechanism is provided for now: the
        confidence is updated if ``feedback`` contains a ``confidence`` field,
        and any additional evidence is appended to the hypothesis's evidence
        list.
        """
        new_conf = float(feedback.get("confidence", hypothesis.confidence))
        extra_evidence = feedback.get("evidence")
        # Only append evidence when the caller actually supplied some.
        new_evidence = hypothesis.evidence + ([extra_evidence] if extra_evidence else [])
        return Hypothesis(
            description=hypothesis.description,
            transformation_type=hypothesis.transformation_type,
            confidence=new_conf,
            evidence=new_evidence,
            program_sketch=hypothesis.program_sketch,
        )

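    # Example ``feedback`` dict for ``refine_hypothesis`` above (hypothetical):
    #   {"confidence": 0.5, "evidence": {"pair": 2, "match": False}}
    # lowers the confidence to 0.5 and records the failing pair as evidence.
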
    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def apply(self, hypothesis: Hypothesis, grid: Array) -> Optional[Array]:
        """Apply the hypothesis to a grid, returning the transformed grid."""
        try:
            if hypothesis.transformation_type == "rotation" and hypothesis.program_sketch:
                k = int(hypothesis.program_sketch[0][1].get("k", 0))
                return np.rot90(grid, k)
            if hypothesis.transformation_type == "color_swap" and hypothesis.program_sketch:
                mapping = hypothesis.program_sketch[0][1].get("mapping", {})
                # Recolour against the original grid so chained mappings such as
                # {1: 2, 2: 3} do not cascade within a single application.
                result = grid.copy()
                for src, dst in mapping.items():
                    result[grid == src] = dst
                return result
            if hypothesis.transformation_type == "pattern_fill" and hypothesis.program_sketch:
                color = int(hypothesis.program_sketch[0][1].get("color", 0))
                return np.full_like(grid, color)
            if hypothesis.transformation_type == "object_translation" and hypothesis.program_sketch:
                params = hypothesis.program_sketch[0][1]
                dy = int(params.get("dy", 0))
                dx = int(params.get("dx", 0))
                h, w = grid.shape
                result = np.zeros_like(grid)
                ys, xs = np.nonzero(grid)
                ys_new = ys + dy
                xs_new = xs + dx
                # Reject translations that would push any cell out of bounds.
                if (
                    (ys_new < 0).any()
                    or (ys_new >= h).any()
                    or (xs_new < 0).any()
                    or (xs_new >= w).any()
                ):
                    return None
                result[ys_new, xs_new] = grid[ys, xs]
                return result
        except Exception:
            # Malformed sketches or unexpected grids simply yield "no prediction".
            return None
        return None

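    # Example for ``apply`` above (hypothetical): with a sketch of
    # [("translate", {"dy": 0, "dx": 1})] and grid [[5, 0]], the result is
    # [[0, 5]]; with dx=2 the single cell would leave the grid, so None is
    # returned instead.
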
    # Hypothesis generation subroutines ---------------------------------
    def _generate_geometric_hypotheses(self, train_pairs: List[Tuple[Array, Array]]) -> List[Hypothesis]:
        hyps: List[Hypothesis] = []
        rotations = [1, 2, 3]  # 90, 180, 270 degrees
        for k in rotations:
            evidence: List[Dict[str, Any]] = []
            matches = 0
            for idx, (inp, out) in enumerate(train_pairs):
                rotated = np.rot90(inp, k)
                match = rotated.shape == out.shape and np.array_equal(rotated, out)
                evidence.append({"pair": idx, "rotation": k * 90, "match": match})
                if match:
                    matches += 1
            confidence = matches / len(train_pairs)
            if confidence > 0:
                hyps.append(
                    Hypothesis(
                        description=f"Rotate input by {k * 90} degrees",
                        transformation_type="rotation",
                        confidence=confidence,
                        evidence=evidence,
                        program_sketch=[("rotate", {"k": k})],
                    )
                )
        return hyps

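    # Example for the geometric search above (hypothetical): if every training
    # output equals np.rot90(input, 2), a single "Rotate input by 180 degrees"
    # hypothesis is produced with confidence 1.0; rotations that match only
    # some pairs still appear, but with proportionally lower confidence.
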
    def _generate_color_hypotheses(self, train_pairs: List[Tuple[Array, Array]]) -> List[Hypothesis]:
        hyps: List[Hypothesis] = []
        # Colours observed across all pairs (including colours that stay the
        # same) are tracked for the consistency check; only non-identity
        # entries end up in the final recolouring map.
        observed: Dict[int, int] = {}
        evidence: List[Dict[str, Any]] = []
        consistent = True
        for idx, (inp, out) in enumerate(train_pairs):
            if inp.shape != out.shape:
                consistent = False
                break
            mapping: Dict[int, int] = {}
            for ci, co in zip(inp.flat, out.flat):
                ci, co = int(ci), int(co)
                if ci in observed and observed[ci] != co:
                    consistent = False
                    break
                observed[ci] = co
                if ci != co:
                    mapping[ci] = co
            evidence.append({"pair": idx, "mapping": mapping})
            if not consistent:
                break
        global_mapping = {k: v for k, v in observed.items() if k != v}
        if consistent and global_mapping:
            hyps.append(
                Hypothesis(
                    description=f"Recolor using mapping {global_mapping}",
                    transformation_type="color_swap",
                    confidence=1.0,
                    evidence=evidence,
                    program_sketch=[("recolor", {"mapping": global_mapping})],
                )
            )
        return hyps

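    # Example for the colour-swap search above (hypothetical): if every
    # training pair turns colour 1 into 2 and colour 3 into 4 while leaving all
    # other cells untouched, the resulting sketch is
    # [("recolor", {"mapping": {1: 2, 3: 4}})] with confidence 1.0.
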
    def _generate_pattern_hypotheses(self, train_pairs: List[Tuple[Array, Array]]) -> List[Hypothesis]:
        hyps: List[Hypothesis] = []
        colors = [np.unique(out) for _, out in train_pairs]
        # Every output must be uniform, and uniform in the *same* colour;
        # otherwise a single fill hypothesis cannot explain the task.
        if colors and all(len(c) == 1 for c in colors):
            color = int(colors[0][0])
            if all(int(c[0]) == color for c in colors):
                evidence = [{"pair": idx, "color": int(c[0])} for idx, c in enumerate(colors)]
                hyps.append(
                    Hypothesis(
                        description=f"Fill grid with color {color}",
                        transformation_type="pattern_fill",
                        confidence=1.0,
                        evidence=evidence,
                        program_sketch=[("fill", {"color": color})],
                    )
                )
        return hyps

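    # Example for the pattern-fill search above (hypothetical): if every
    # training output is a uniform grid of colour 4, the sketch
    # [("fill", {"color": 4})] is proposed; note that ``apply`` fills a grid of
    # the *input's* shape, so shape-changing fills are rejected at test time.
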
    def _find_translation(self, inp: Array, out: Array) -> Optional[Tuple[int, int]]:
        """Return the (dy, dx) shift mapping ``inp`` onto ``out``, or None."""
        if inp.shape != out.shape:
            return None
        coords_in = np.argwhere(inp != 0)
        coords_out = np.argwhere(out != 0)
        if len(coords_in) == 0 or len(coords_out) == 0 or len(coords_in) != len(coords_out):
            return None
        # Hypothesise the shift from the first foreground cell of each grid and
        # verify it by actually translating the input.
        shift = coords_out[0] - coords_in[0]
        h, w = inp.shape
        translated = np.zeros_like(inp)
        ys = coords_in[:, 0] + shift[0]
        xs = coords_in[:, 1] + shift[1]
        if (ys < 0).any() or (ys >= h).any() or (xs < 0).any() or (xs >= w).any():
            return None
        translated[ys, xs] = inp[coords_in[:, 0], coords_in[:, 1]]
        if np.array_equal(translated, out):
            return int(shift[0]), int(shift[1])
        return None

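    # Illustration for ``_find_translation`` above (hypothetical grids): for
    # inp = [[1, 0], [0, 0]] and out = [[0, 0], [0, 1]], the first foreground
    # cells are (0, 0) and (1, 1), so the candidate shift is (1, 1); applying
    # it reproduces ``out`` and (1, 1) is returned.
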
    def _generate_object_hypotheses(self, train_pairs: List[Tuple[Array, Array]]) -> List[Hypothesis]:
        hyps: List[Hypothesis] = []
        shifts: List[Tuple[int, int]] = []
        evidence: List[Dict[str, Any]] = []
        for idx, (inp, out) in enumerate(train_pairs):
            trans = self._find_translation(inp, out)
            evidence.append({"pair": idx, "shift": trans})
            if trans is not None:
                shifts.append(trans)
        if not shifts:
            return hyps
        common = shifts[0]
        if any(s != common for s in shifts):
            return hyps
        confidence = len(shifts) / len(train_pairs)
        hyps.append(
            Hypothesis(
                description=f"Translate object by {common}",
                transformation_type="object_translation",
                confidence=confidence,
                evidence=evidence,
                program_sketch=[("translate", {"dy": common[0], "dx": common[1]})],
            )
        )
        return hyps
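

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The synthetic task below is an
# assumption for demonstration purposes; ``Array`` is treated as a plain
# ``numpy.ndarray``, and because of the relative ``.grid`` import the module
# must be run as part of its package (e.g. ``python -m <package>.<module>``).
# ----------------------------------------------------------------------
if __name__ == "__main__":
    engine = HypothesisEngine()

    # A toy "rotate by 90 degrees" task with a single training pair. A single
    # pair is often explained by several competing hypotheses (here an object
    # translation fits as well), which is exactly what the sorted output shows.
    grid_in = np.array([[1, 0], [0, 0]])
    grid_out = np.rot90(grid_in, 1)
    pairs = [(grid_in, grid_out)]

    for hyp in engine.generate_hypotheses(pairs):
        score = engine.test_hypothesis(hyp, pairs)
        print(f"{hyp.description}: prior={hyp.confidence:.2f}, tested={score:.2f}")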