11"""Episodic memory and retrieval for the ARC solver.
22
33This module implements a lightweight yet fully functional episodic memory
4- system. Previously solved tasks (episodes) are stored together with the
5- programs that solved them and rich feature representations. At inference
6- time the solver can query this database for tasks with similar signatures or
7- feature vectors and reuse their solutions as candidates.
4+ system. Previously solved tasks (episodes) are stored together with the
5+ programs that solved them and rich feature representations. A hierarchical
6+ index organises episodes into coarse feature buckets while repeated solutions
7+ are consolidated to avoid unbounded growth. At inference time the solver can
8+ query this database for tasks with similar signatures or feature vectors and
9+ reuse their solutions as candidates.
810
911The implementation is intentionally deterministic and avoids any external
1012dependencies so that it remains compatible with the Kaggle competition
@@ -110,6 +112,10 @@ def __init__(self, db_path: Optional[str] = None) -> None:
110112 self .episodes : Dict [int , Episode ] = {}
111113 self .signature_index : Dict [str , List [int ]] = defaultdict (list )
112114 self .program_index : Dict [str , List [int ]] = defaultdict (list )
115+ # Hierarchical index groups episodes by coarse feature buckets.
116+ # This enables fast retrieval of structurally similar tasks while
117+ # keeping the system deterministic and lightweight.
118+ self .hierarchy_index : Dict [str , List [int ]] = defaultdict (list )
113119 self .db_path = db_path
114120 self ._next_id = 1
115121
@@ -127,6 +133,20 @@ def _program_key(program: Program) -> str:
127133 ]
128134 return json .dumps (normalised )
129135
136+ def _hierarchy_key (self , features : Dict [str , Any ]) -> str :
137+ """Return a coarse key used for hierarchical organisation.
138+
139+ The key buckets episodes by basic properties such as number of
140+ training pairs, average input colours and whether recolouring is
141+ likely. These buckets act as top-level memory regions that group
142+ broadly similar tasks.
143+ """
144+
145+ num_pairs = int (features .get ("num_train_pairs" , 0 ))
146+ colours = int (features .get ("input_colors_mean" , 0 ))
147+ recolor = int (bool (features .get ("likely_recolor" , False )))
148+ return f"{ num_pairs } :{ colours } :{ recolor } "
149+
130150 def _compute_similarity (self , f1 : Dict [str , Any ], f2 : Dict [str , Any ]) -> float :
131151 """Compute cosine similarity between two feature dictionaries."""
132152 numerical_keys = [
@@ -186,7 +206,6 @@ def store_episode(
186206 metadata : Optional [Dict [str , Any ]] = None ,
187207 ) -> int :
188208 """Store a solved episode and return its identifier."""
189-
190209 episode = Episode (
191210 task_signature = task_signature ,
192211 programs = programs ,
@@ -203,6 +222,8 @@ def store_episode(
203222 for program in programs :
204223 key = self ._program_key (program )
205224 self .program_index [key ].append (episode_id )
225+ hier_key = self ._hierarchy_key (episode .features )
226+ self .hierarchy_index [hier_key ].append (episode_id )
206227
207228 return episode_id
208229
@@ -235,12 +256,42 @@ def query_by_similarity(
235256 results .sort (key = lambda x : x [1 ], reverse = True )
236257 return results [:max_results ]
237258
def query_hierarchy(
    self,
    train_pairs: List[Tuple[Array, Array]],
    similarity_threshold: float = 0.5,
    max_results: int = 5,
) -> List[Tuple[Episode, float]]:
    """Return episodes from the same hierarchical bucket.

    Performs a two-level lookup: first the query task's coarse bucket is
    located via its hierarchy key, then the episodes inside that bucket
    are ranked by detailed feature similarity.  Only matches scoring at
    least ``similarity_threshold`` are returned, best first, capped at
    ``max_results``.
    """
    if not train_pairs:
        return []
    query_features = extract_task_features(train_pairs)
    # Restrict the similarity scan to the query's own coarse bucket.
    bucket = self.hierarchy_index.get(self._hierarchy_key(query_features), [])
    scored: List[Tuple[Episode, float]] = []
    for episode_id in bucket:
        candidate = self.episodes[episode_id]
        score = self._compute_similarity(query_features, candidate.features)
        if score >= similarity_threshold:
            scored.append((candidate, score))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:max_results]
285+
238286 def get_candidate_programs (
239287 self , train_pairs : List [Tuple [Array , Array ]], max_programs : int = 10
240288 ) -> List [Program ]:
241289 """Return programs from similar episodes for reuse."""
242290 candidates : List [Program ] = []
243- for episode , _ in self .query_by_similarity (train_pairs , 0.0 , max_programs ):
291+ results = self .query_hierarchy (train_pairs , 0.0 , max_programs )
292+ if not results :
293+ results = self .query_by_similarity (train_pairs , 0.0 , max_programs )
294+ for episode , _ in results :
244295 for program in episode .programs :
245296 candidates .append (program )
246297 if len (candidates ) >= max_programs :
@@ -260,6 +311,30 @@ def remove_episode(self, episode_id: int) -> None:
260311 self .program_index [key ] = [
261312 i for i in self .program_index [key ] if i != episode_id
262313 ]
314+ hier_key = self ._hierarchy_key (episode .features )
315+ self .hierarchy_index [hier_key ] = [
316+ i for i in self .hierarchy_index [hier_key ] if i != episode_id
317+ ]
318+
def consolidate(self) -> None:
    """Merge episodes with identical signature and program set.

    Duplicate episodes (same task signature and the same normalised set
    of programs) are folded into the first one encountered: their
    ``success_count`` values are accumulated on the keeper and the
    redundant entries are removed from memory and all indexes.  This
    keeps memory growth bounded across repeated solves.
    """
    seen: Dict[Tuple[str, str], int] = {}
    duplicates: List[int] = []
    for episode_id, episode in self.episodes.items():
        # Sort the per-program keys so the identity is order-independent.
        programs_key = json.dumps(
            sorted(self._program_key(p) for p in episode.programs)
        )
        identity = (episode.task_signature, programs_key)
        keeper = seen.get(identity)
        if keeper is None:
            seen[identity] = episode_id
        else:
            self.episodes[keeper].success_count += episode.success_count
            duplicates.append(episode_id)
    # Remove after iteration so the dict is not mutated mid-loop.
    for episode_id in duplicates:
        self.remove_episode(episode_id)
263338
264339 # ------------------------------------------------------------------
265340 # Persistence
@@ -296,11 +371,14 @@ def load(self, filepath: Optional[str] = None) -> None:
296371 # Rebuild indexes deterministically
297372 self .signature_index .clear ()
298373 self .program_index .clear ()
374+ self .hierarchy_index .clear ()
299375 for eid , episode in self .episodes .items ():
300376 self .signature_index [episode .task_signature ].append (eid )
301377 for program in episode .programs :
302378 key = self ._program_key (program )
303379 self .program_index [key ].append (eid )
380+ hier_key = self ._hierarchy_key (episode .features )
381+ self .hierarchy_index [hier_key ].append (eid )
304382
305383 # ------------------------------------------------------------------
306384 # Statistics
0 commit comments