
Commit 2e2f8d8

Merge pull request #3 from tylerbessire/codex/continue-implementation-for-sota-model

Implement numerical features and episodic memory

2 parents: ebbecb0 + 7e33527

4 files changed: +484 additions, -244 deletions


arc_solver/features.py

Lines changed: 114 additions & 43 deletions
@@ -18,51 +18,57 @@
 
 def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, Any]:
     """Extract a comprehensive feature vector from training pairs.
-
+
     These features capture task-level properties that can help predict which
     DSL operations are likely to be relevant for solving the task.
     """
+
+    # Ensure arrays are integer typed for canonicalisation
     try:
-        train_pairs = [
-            (canonicalize_D4(inp), canonicalize_D4(out))
+        original_pairs = [
+            (np.asarray(inp, dtype=int), np.asarray(out, dtype=int))
             for inp, out in train_pairs
         ]
-    except TypeError as exc:
+        canonical_pairs = [
+            (canonicalize_D4(inp), canonicalize_D4(out))
+            for inp, out in original_pairs
+        ]
+    except Exception as exc:
         raise ValueError(f"invalid grid in train_pairs: {exc}") from exc
 
     features: Dict[str, Any] = {}
-
-    # Basic grid statistics
-    input_shapes = [inp.shape for inp, _ in train_pairs]
-    output_shapes = [out.shape for _, out in train_pairs]
-
+
+    # Basic grid statistics using original shapes
+    input_shapes = [inp.shape for inp, _ in original_pairs]
+    output_shapes = [out.shape for _, out in original_pairs]
+
     features.update({
-        'num_train_pairs': len(train_pairs),
+        'num_train_pairs': len(original_pairs),
         'input_height_mean': np.mean([s[0] for s in input_shapes]),
         'input_width_mean': np.mean([s[1] for s in input_shapes]),
         'output_height_mean': np.mean([s[0] for s in output_shapes]),
         'output_width_mean': np.mean([s[1] for s in output_shapes]),
-        'shape_preserved': all(inp.shape == out.shape for inp, out in train_pairs),
+        'shape_preserved': all(inp.shape == out.shape for inp, out in original_pairs),
         'size_ratio_mean': np.mean([
             (out.shape[0] * out.shape[1]) / (inp.shape[0] * inp.shape[1])
-            for inp, out in train_pairs
+            for inp, out in original_pairs
         ]),
     })
-
-    # Color analysis
-    input_colors = []
-    output_colors = []
-    color_mappings = []
-
-    for inp, out in train_pairs:
+
+    # Color analysis on canonical pairs
+    input_colors: List[int] = []
+    output_colors: List[int] = []
+    color_mappings: List[int] = []
+
+    for inp, out in canonical_pairs:
         inp_hist = histogram(inp)
         out_hist = histogram(out)
         input_colors.append(len(inp_hist))
         output_colors.append(len(out_hist))
-
+
         # Try to detect color mappings
         if inp.shape == out.shape:
-            mapping = {}
+            mapping: Dict[int, int] = {}
             valid_mapping = True
             for i_val, o_val in zip(inp.flatten(), out.flatten()):
                 if i_val in mapping and mapping[i_val] != o_val:
@@ -71,49 +77,51 @@ def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, Any]:
                     mapping[i_val] = o_val
             if valid_mapping:
                 color_mappings.append(len(mapping))
-
+
     features.update({
         'input_colors_mean': np.mean(input_colors),
         'output_colors_mean': np.mean(output_colors),
-        'background_color_consistent': len(set(bg_color(inp) for inp, _ in train_pairs)) == 1,
+        'background_color_consistent': len(set(bg_color(inp) for inp, _ in canonical_pairs)) == 1,
         'has_color_mapping': len(color_mappings) > 0,
         'color_mapping_size': np.mean(color_mappings) if color_mappings else 0,
     })
-
-    # Object analysis
-    input_obj_counts = []
-    output_obj_counts = []
-
-    for inp, out in train_pairs:
+
+    # Object analysis on canonical pairs
+    input_obj_counts: List[int] = []
+    output_obj_counts: List[int] = []
+
+    for inp, out in canonical_pairs:
         inp_objects = connected_components(inp)
         out_objects = connected_components(out)
         input_obj_counts.append(len(inp_objects))
         output_obj_counts.append(len(out_objects))
-
+
     features.update({
         'input_objects_mean': np.mean(input_obj_counts),
         'output_objects_mean': np.mean(output_obj_counts),
-        'object_count_preserved': np.mean([
+        'object_count_preserved': all(
             len(connected_components(inp)) == len(connected_components(out))
-            for inp, out in train_pairs
-        ]),
+            for inp, out in canonical_pairs
+        ),
     })
-
-    # Transformation hints
+
+    # Transformation hints from original pairs
     features.update({
-        'likely_rotation': _detect_rotation_patterns(train_pairs),
-        'likely_reflection': _detect_reflection_patterns(train_pairs),
-        'likely_translation': _detect_translation_patterns(train_pairs),
-        'likely_recolor': _detect_recolor_patterns(train_pairs),
-        'likely_crop': _detect_crop_patterns(train_pairs),
-        'likely_pad': _detect_pad_patterns(train_pairs),
+        'likely_rotation': _detect_rotation_patterns(original_pairs),
+        'likely_reflection': _detect_reflection_patterns(original_pairs),
+        'likely_translation': _detect_translation_patterns(original_pairs),
+        'likely_recolor': _detect_recolor_patterns(original_pairs),
+        'likely_crop': _detect_crop_patterns(original_pairs),
+        'likely_pad': _detect_pad_patterns(original_pairs),
     })
-
+
     return features
 
 
 def _detect_rotation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if rotation transformations are likely."""
+    if not train_pairs:
+        return 0.0
     rotation_score = 0.0
     for inp, out in train_pairs:
         if inp.shape[0] == inp.shape[1] and out.shape[0] == out.shape[1]:
@@ -127,6 +135,8 @@ def _detect_rotation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_reflection_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if reflection transformations are likely."""
+    if not train_pairs:
+        return 0.0
     reflection_score = 0.0
     for inp, out in train_pairs:
         if inp.shape == out.shape:
@@ -139,7 +149,7 @@ def _detect_reflection_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_translation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if translation transformations are likely."""
-    if not all(inp.shape == out.shape for inp, out in train_pairs):
+    if not train_pairs or not all(inp.shape == out.shape for inp, out in train_pairs):
         return 0.0
 
     translation_score = 0.0
@@ -156,6 +166,8 @@ def _detect_translation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_recolor_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if recoloring transformations are likely."""
+    if not train_pairs:
+        return 0.0
     recolor_score = 0.0
     for inp, out in train_pairs:
         if inp.shape == out.shape:
@@ -174,6 +186,8 @@ def _detect_recolor_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_crop_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if cropping transformations are likely."""
+    if not train_pairs:
+        return 0.0
     crop_score = 0.0
     for inp, out in train_pairs:
         if (inp.shape[0] > out.shape[0] or inp.shape[1] > out.shape[1]):
@@ -183,6 +197,8 @@ def _detect_crop_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_pad_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if padding transformations are likely."""
+    if not train_pairs:
+        return 0.0
     pad_score = 0.0
     for inp, out in train_pairs:
         if (inp.shape[0] < out.shape[0] or inp.shape[1] < out.shape[1]):
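These empty-input guards make every detector total: with no training pairs, each likely_* score is simply 0.0 rather than the scoring loop running on nothing. A minimal check of the shared pattern, assuming the module is importable as arc_solver.features:

    from arc_solver.features import _detect_rotation_patterns, _detect_pad_patterns

    # With an empty training set, the detectors now short-circuit
    # to a neutral score instead of entering their scoring loops.
    assert _detect_rotation_patterns([]) == 0.0
    assert _detect_pad_patterns([]) == 0.0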
@@ -225,3 +241,58 @@ def _operation_hints(features: Dict[str, Any]) -> str:
         hints.append('P')
 
     return "".join(hints) if hints else "U"  # U for unknown
+
+
+def compute_numerical_features(train_pairs: List[Tuple[Array, Array]]) -> np.ndarray:
+    """Convert task features to a numerical vector.
+
+    This utility is primarily used by learning components that expect a fixed
+    numeric representation. The order of features is deterministic to ensure
+    reproducibility across runs.
+
+    Args:
+        train_pairs: List of training input/output grid pairs.
+
+    Returns:
+        A 1-D numpy array of feature values. Boolean features are encoded as
+        ``0.0`` or ``1.0``.
+    """
+
+    features = extract_task_features(train_pairs)
+
+    numerical_keys = [
+        'num_train_pairs',
+        'input_height_mean',
+        'input_width_mean',
+        'output_height_mean',
+        'output_width_mean',
+        'shape_preserved',
+        'size_ratio_mean',
+        'input_colors_mean',
+        'output_colors_mean',
+        'background_color_consistent',
+        'has_color_mapping',
+        'color_mapping_size',
+        'input_objects_mean',
+        'output_objects_mean',
+        'object_count_preserved',
+        'likely_rotation',
+        'likely_reflection',
+        'likely_translation',
+        'likely_recolor',
+        'likely_crop',
+        'likely_pad',
+    ]
+
+    values: List[float] = []
+    for key in numerical_keys:
+        val = features.get(key, 0)
+        if isinstance(val, bool):
+            values.append(1.0 if val else 0.0)
+        else:
+            try:
+                values.append(float(val))
+            except (TypeError, ValueError):  # pragma: no cover - defensive path
+                values.append(0.0)
+
+    return np.array(values, dtype=float)
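For reference, a minimal usage sketch of the new helper; the toy grid below is hypothetical, and the import path assumes the arc_solver package layout shown above:

    import numpy as np
    from arc_solver.features import compute_numerical_features

    # One toy training pair: a 2x2 grid mapped to itself (hypothetical data).
    grid = np.array([[1, 0], [0, 1]])
    train_pairs = [(grid, grid.copy())]

    vec = compute_numerical_features(train_pairs)
    # 21 features in a fixed order; booleans arrive encoded as 0.0/1.0.
    assert vec.shape == (21,)
    assert vec.dtype == float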

arc_solver/guidance.py

Lines changed: 6 additions & 4 deletions
@@ -38,20 +38,22 @@ def __init__(self, input_dim: int, hidden_dim: int = 32):
 
     def forward(self, x: np.ndarray) -> np.ndarray:
         """Forward pass through the network."""
+        if x.ndim == 1:
+            x = x.reshape(1, -1)
         # First layer
         h = np.maximum(0, np.dot(x, self.weights1) + self.bias1)  # ReLU
         # Output layer with sigmoid
         out = 1.0 / (1.0 + np.exp(-(np.dot(h, self.weights2) + self.bias2)))
-        return out
+        return out.squeeze()
 
     def predict_operations(self, features: Dict[str, Any], threshold: float = 0.5) -> List[str]:
         """Predict which operations are likely relevant."""
         feature_vector = self._features_to_vector(features)
-        probabilities = self.forward(feature_vector)
-
+        probabilities = self.forward(feature_vector).ravel()
+
         relevant_ops = []
         for i, prob in enumerate(probabilities):
-            if prob > threshold:
+            if float(prob) > threshold:
                 relevant_ops.append(self.operations[i])
 
         return relevant_ops if relevant_ops else ['identity']
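The reshape/squeeze pair makes forward batch-agnostic: a single feature vector is promoted to a batch of one, and the singleton batch axis is dropped on the way out. A standalone sketch of the same shape handling (the weight shapes and names here are illustrative, not taken from guidance.py):

    import numpy as np

    def forward(x, w1, b1, w2, b2):
        if x.ndim == 1:
            x = x.reshape(1, -1)                    # promote vector to batch of one
        h = np.maximum(0, x @ w1 + b1)              # ReLU hidden layer
        out = 1.0 / (1.0 + np.exp(-(h @ w2 + b2)))  # sigmoid output
        return out.squeeze()                        # drop singleton batch axis

    rng = np.random.default_rng(0)
    w1, b1 = rng.normal(size=(21, 32)), np.zeros(32)
    w2, b2 = rng.normal(size=(32, 5)), np.zeros(5)

    forward(rng.normal(size=21), w1, b1, w2, b2).shape       # (5,)
    forward(rng.normal(size=(4, 21)), w1, b1, w2, b2).shape  # (4, 5)

Calling .ravel() on the squeezed output then guarantees predict_operations always iterates over a 1-D probability array, even in the single-sample case.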
