 
 def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, Any]:
     """Extract a comprehensive feature vector from training pairs.
-
+
     These features capture task-level properties that can help predict which
     DSL operations are likely to be relevant for solving the task.
     """
+
+    # Ensure arrays are integer typed for canonicalisation
     try:
-        train_pairs = [
-            (canonicalize_D4(inp), canonicalize_D4(out))
+        original_pairs = [
+            (np.asarray(inp, dtype=int), np.asarray(out, dtype=int))
             for inp, out in train_pairs
         ]
-    except TypeError as exc:
+        canonical_pairs = [
+            (canonicalize_D4(inp), canonicalize_D4(out))
+            for inp, out in original_pairs
+        ]
+    except Exception as exc:
         raise ValueError(f"invalid grid in train_pairs: {exc}") from exc
 
     features: Dict[str, Any] = {}
-
-    # Basic grid statistics
-    input_shapes = [inp.shape for inp, _ in train_pairs]
-    output_shapes = [out.shape for _, out in train_pairs]
-
+
+    # Basic grid statistics using original shapes
+    input_shapes = [inp.shape for inp, _ in original_pairs]
+    output_shapes = [out.shape for _, out in original_pairs]
+
     features.update({
-        'num_train_pairs': len(train_pairs),
+        'num_train_pairs': len(original_pairs),
         'input_height_mean': np.mean([s[0] for s in input_shapes]),
         'input_width_mean': np.mean([s[1] for s in input_shapes]),
         'output_height_mean': np.mean([s[0] for s in output_shapes]),
         'output_width_mean': np.mean([s[1] for s in output_shapes]),
-        'shape_preserved': all(inp.shape == out.shape for inp, out in train_pairs),
+        'shape_preserved': all(inp.shape == out.shape for inp, out in original_pairs),
         'size_ratio_mean': np.mean([
             (out.shape[0] * out.shape[1]) / (inp.shape[0] * inp.shape[1])
-            for inp, out in train_pairs
+            for inp, out in original_pairs
         ]),
     })
-
-    # Color analysis
-    input_colors = []
-    output_colors = []
-    color_mappings = []
-
-    for inp, out in train_pairs:
+
+    # Color analysis on canonical pairs
+    input_colors: List[int] = []
+    output_colors: List[int] = []
+    color_mappings: List[int] = []
+
+    for inp, out in canonical_pairs:
         inp_hist = histogram(inp)
         out_hist = histogram(out)
         input_colors.append(len(inp_hist))
         output_colors.append(len(out_hist))
-
+
         # Try to detect color mappings
         if inp.shape == out.shape:
-            mapping = {}
+            mapping: Dict[int, int] = {}
             valid_mapping = True
             for i_val, o_val in zip(inp.flatten(), out.flatten()):
                 if i_val in mapping and mapping[i_val] != o_val:
@@ -71,49 +77,51 @@ def extract_task_features(train_pairs: List[Tuple[Array, Array]]) -> Dict[str, A
                 mapping[i_val] = o_val
             if valid_mapping:
                 color_mappings.append(len(mapping))
-
+
     features.update({
         'input_colors_mean': np.mean(input_colors),
         'output_colors_mean': np.mean(output_colors),
-        'background_color_consistent': len(set(bg_color(inp) for inp, _ in train_pairs)) == 1,
+        'background_color_consistent': len(set(bg_color(inp) for inp, _ in canonical_pairs)) == 1,
         'has_color_mapping': len(color_mappings) > 0,
         'color_mapping_size': np.mean(color_mappings) if color_mappings else 0,
     })
-
-    # Object analysis
-    input_obj_counts = []
-    output_obj_counts = []
-
-    for inp, out in train_pairs:
+
+    # Object analysis on canonical pairs
+    input_obj_counts: List[int] = []
+    output_obj_counts: List[int] = []
+
+    for inp, out in canonical_pairs:
         inp_objects = connected_components(inp)
         out_objects = connected_components(out)
         input_obj_counts.append(len(inp_objects))
         output_obj_counts.append(len(out_objects))
-
+
     features.update({
         'input_objects_mean': np.mean(input_obj_counts),
         'output_objects_mean': np.mean(output_obj_counts),
-        'object_count_preserved': np.mean([
+        'object_count_preserved': all(
             len(connected_components(inp)) == len(connected_components(out))
-            for inp, out in train_pairs
-        ]),
+            for inp, out in canonical_pairs
+        ),
     })
-
-    # Transformation hints
+
+    # Transformation hints from original pairs
     features.update({
-        'likely_rotation': _detect_rotation_patterns(train_pairs),
-        'likely_reflection': _detect_reflection_patterns(train_pairs),
-        'likely_translation': _detect_translation_patterns(train_pairs),
-        'likely_recolor': _detect_recolor_patterns(train_pairs),
-        'likely_crop': _detect_crop_patterns(train_pairs),
-        'likely_pad': _detect_pad_patterns(train_pairs),
+        'likely_rotation': _detect_rotation_patterns(original_pairs),
+        'likely_reflection': _detect_reflection_patterns(original_pairs),
+        'likely_translation': _detect_translation_patterns(original_pairs),
+        'likely_recolor': _detect_recolor_patterns(original_pairs),
+        'likely_crop': _detect_crop_patterns(original_pairs),
+        'likely_pad': _detect_pad_patterns(original_pairs),
     })
-
+
     return features
 
 
 def _detect_rotation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if rotation transformations are likely."""
+    if not train_pairs:
+        return 0.0
     rotation_score = 0.0
     for inp, out in train_pairs:
         if inp.shape[0] == inp.shape[1] and out.shape[0] == out.shape[1]:
@@ -127,6 +135,8 @@ def _detect_rotation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_reflection_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if reflection transformations are likely."""
+    if not train_pairs:
+        return 0.0
     reflection_score = 0.0
     for inp, out in train_pairs:
         if inp.shape == out.shape:
@@ -139,7 +149,7 @@ def _detect_reflection_patterns(train_pairs: List[Tuple[Array, Array]]) -> float
 
 def _detect_translation_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if translation transformations are likely."""
-    if not all(inp.shape == out.shape for inp, out in train_pairs):
+    if not train_pairs or not all(inp.shape == out.shape for inp, out in train_pairs):
         return 0.0
 
     translation_score = 0.0
@@ -156,6 +166,8 @@ def _detect_translation_patterns(train_pairs: List[Tuple[Array, Array]]) -> floa
 
 def _detect_recolor_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if recoloring transformations are likely."""
+    if not train_pairs:
+        return 0.0
     recolor_score = 0.0
     for inp, out in train_pairs:
         if inp.shape == out.shape:
@@ -174,6 +186,8 @@ def _detect_recolor_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_crop_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if cropping transformations are likely."""
+    if not train_pairs:
+        return 0.0
     crop_score = 0.0
     for inp, out in train_pairs:
         if (inp.shape[0] > out.shape[0] or inp.shape[1] > out.shape[1]):
@@ -183,6 +197,8 @@ def _detect_crop_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
 
 def _detect_pad_patterns(train_pairs: List[Tuple[Array, Array]]) -> float:
     """Detect if padding transformations are likely."""
+    if not train_pairs:
+        return 0.0
     pad_score = 0.0
     for inp, out in train_pairs:
         if (inp.shape[0] < out.shape[0] or inp.shape[1] < out.shape[1]):
@@ -225,3 +241,58 @@ def _operation_hints(features: Dict[str, Any]) -> str:
         hints.append('P')
 
     return "".join(hints) if hints else "U"  # U for unknown
+
+
+def compute_numerical_features(train_pairs: List[Tuple[Array, Array]]) -> np.ndarray:
+    """Convert task features to a numerical vector.
+
+    This utility is primarily used by learning components that expect a fixed
+    numeric representation. The order of features is deterministic to ensure
+    reproducibility across runs.
+
+    Args:
+        train_pairs: List of training input/output grid pairs.
+
+    Returns:
+        A 1-D numpy array of feature values. Boolean features are encoded as
+        ``0.0`` or ``1.0``.
+    """
+
+    features = extract_task_features(train_pairs)
+
+    numerical_keys = [
+        'num_train_pairs',
+        'input_height_mean',
+        'input_width_mean',
+        'output_height_mean',
+        'output_width_mean',
+        'shape_preserved',
+        'size_ratio_mean',
+        'input_colors_mean',
+        'output_colors_mean',
+        'background_color_consistent',
+        'has_color_mapping',
+        'color_mapping_size',
+        'input_objects_mean',
+        'output_objects_mean',
+        'object_count_preserved',
+        'likely_rotation',
+        'likely_reflection',
+        'likely_translation',
+        'likely_recolor',
+        'likely_crop',
+        'likely_pad',
+    ]
+
+    values: List[float] = []
+    for key in numerical_keys:
+        val = features.get(key, 0)
+        if isinstance(val, bool):
+            values.append(1.0 if val else 0.0)
+        else:
+            try:
+                values.append(float(val))
+            except (TypeError, ValueError):  # pragma: no cover - defensive path
+                values.append(0.0)
+
+    return np.array(values, dtype=float)
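
Usage sketch (not part of the diff above): a minimal call to extract_task_features, assuming this file is importable as task_features (a hypothetical module name) and that grids are small NumPy integer arrays.

import numpy as np
from task_features import extract_task_features  # hypothetical module name

# Toy training pair: the output is the input rotated by 90 degrees.
inp = np.array([[1, 0], [2, 0]])
out = np.rot90(inp)

feats = extract_task_features([(inp, out)])
print(feats['shape_preserved'], feats['likely_rotation'])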
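
Similarly, a sketch of stacking compute_numerical_features output into a fixed-width feature matrix for a downstream learner; the two toy tasks are illustrative only, and the module name is again an assumption.

import numpy as np
from task_features import compute_numerical_features  # hypothetical module name

tasks = [
    [(np.array([[1, 0], [0, 1]]), np.array([[2, 0], [0, 2]]))],         # recolor-like toy task
    [(np.array([[3, 3], [3, 3]]), np.array([[3, 3, 3], [3, 3, 3]]))],   # size-changing toy task
]
X = np.stack([compute_numerical_features(pairs) for pairs in tasks])
print(X.shape)  # (2, 21): one row per task, one column per feature key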