1+ """
2+ Object-Agentic solver implementation.
3+
4+ This module implements the main agentic solver that coordinates multiple
5+ object-focused agents using a beam search with blackboard architecture.
6+ """
7+
8+ from typing import List , Tuple , Dict , Any , Optional , Set
9+ import numpy as np
10+ import heapq
11+ from dataclasses import dataclass
12+ from collections import defaultdict
13+
14+ from ..grid import Array , to_array , eq
15+ from ..common .objects import connected_components
16+ from ..common .invariants import (
17+ palette_equiv , object_count_invariant , evaluate_invariants
18+ )
19+ from ..common .mdl import program_length_from_ops , Operation
20+ from ..common .eval_utils import exact_match
21+ from .ops import Op , execute_program_on_grid , propose_ops_for_object , propose_global_ops , COST
22+
23+
24+ @dataclass
25+ class SearchState :
26+ """Represents a state in the beam search."""
27+ program : List [Op ]
28+ score : float
29+ train_successes : int
30+ mdl_cost : float
31+
32+ def __lt__ (self , other ):
33+ # Higher score is better, but heapq is a min-heap
34+ return self .score > other .score
35+
36+
37+ class AgentBlackboard :
38+ """
39+ Blackboard for coordinating agents and managing program construction.
40+ """
41+
42+ def __init__ (self ):
43+ self .active_programs : List [List [Op ]] = []
44+ self .constraints : Dict [str , Any ] = {}
45+ self .object_assignments : Dict [int , List [Op ]] = defaultdict (list )
46+
47+ def add_constraint (self , name : str , value : Any ):
48+ """Add a constraint that programs must satisfy."""
49+ self .constraints [name ] = value
50+
51+ def check_constraints (self , program : List [Op ], test_grid : Array ,
52+ expected_grid : Array ) -> bool :
53+ """Check if a program satisfies all constraints."""
54+ result = execute_program_on_grid (program , test_grid )
55+ if result is None :
56+ return False
57+
58+ # Check invariants
59+ invariants = evaluate_invariants (test_grid , result )
60+
61+ # Must preserve shape unless explicitly allowed
62+ if not self .constraints .get ('allow_shape_change' , False ):
63+ if not invariants ['shape_preserved' ]:
64+ return False
65+
66+ # Check object count constraints
67+ if not invariants ['object_count_stable' ]:
68+ if not self .constraints .get ('allow_object_count_change' , False ):
69+ return False
70+
71+ # Check palette constraints
72+ if self .constraints .get ('preserve_palette' , True ):
73+ if not invariants ['palette_permutation' ]:
74+ return False
75+
76+ return True
77+
78+ def score_program (self , program : List [Op ], train_pairs : List [Tuple [Array , Array ]]) -> Tuple [float , int ]:
79+ """
80+ Score a program based on training performance and MDL.
81+ Returns (score, num_successes).
82+ """
83+ if not program :
84+ return (0.0 , 0 )
85+
86+ successes = 0
87+ total_pairs = len (train_pairs )
88+
89+ for input_grid , expected_grid in train_pairs :
90+ result = execute_program_on_grid (program , input_grid )
91+ if result is not None and exact_match (result , expected_grid ):
92+ successes += 1
93+
94+ accuracy = successes / max (1 , total_pairs )
95+
96+ # Compute MDL cost
97+ ops_dict = [{'kind' : op .kind , 'params' : op .params } for op in program ]
98+ mdl_cost = program_length_from_ops (ops_dict )
99+
100+ # Score combines accuracy and simplicity
101+ # Perfect accuracy gets base score of 100, then subtract MDL cost
102+ score = accuracy * 100.0 - mdl_cost
103+
104+ return (score , successes )
105+
106+
107+ def beam_search_agentic (train_pairs : List [Tuple [Array , Array ]],
108+ beam_width : int = 64 , max_depth : int = 4 ) -> Tuple [List [Op ], float ]:
109+ """
110+ Perform beam search to find the best program for the training pairs.
111+ """
112+ if not train_pairs :
113+ return ([], 0.0 )
114+
115+ # Initialize blackboard
116+ blackboard = AgentBlackboard ()
117+
118+ # Analyze training data to set constraints
119+ first_input , first_output = train_pairs [0 ]
120+
121+ # Check if shape changes
122+ shape_changes = any (inp .shape != out .shape for inp , out in train_pairs )
123+ blackboard .add_constraint ('allow_shape_change' , shape_changes )
124+
125+ # Check if object count changes
126+ obj_count_changes = any (
127+ not object_count_invariant (inp , out ) for inp , out in train_pairs
128+ )
129+ blackboard .add_constraint ('allow_object_count_change' , obj_count_changes )
130+
131+ # Check if palette changes
132+ palette_changes = any (
133+ not palette_equiv (inp , out ) for inp , out in train_pairs
134+ )
135+ blackboard .add_constraint ('preserve_palette' , not palette_changes )
136+
137+ # Initialize beam with empty program
138+ beam = [SearchState ([], 0.0 , 0 , 0.0 )]
139+
140+ best_program = []
141+ best_score = - float ('inf' )
142+
143+ for depth in range (max_depth ):
144+ next_beam = []
145+
146+ for state in beam :
147+ # Generate successor states
148+ successors = generate_successors (state , train_pairs , blackboard )
149+ next_beam .extend (successors )
150+
151+ # Keep top beam_width states
152+ next_beam .sort (reverse = True , key = lambda s : s .score )
153+ beam = next_beam [:beam_width ]
154+
155+ # Update best program
156+ for state in beam :
157+ if state .score > best_score and state .train_successes > 0 :
158+ best_score = state .score
159+ best_program = state .program .copy ()
160+
161+ # Early stopping if we find a perfect solution
162+ if any (state .train_successes == len (train_pairs ) for state in beam ):
163+ break
164+
165+ return (best_program , best_score )
166+
167+
168+ def generate_successors (state : SearchState , train_pairs : List [Tuple [Array , Array ]],
169+ blackboard : AgentBlackboard ) -> List [SearchState ]:
170+ """Generate successor states by adding one more operation."""
171+ successors = []
172+
173+ if not train_pairs :
174+ return successors
175+
176+ # Analyze first training example to propose operations
177+ input_grid , expected_grid = train_pairs [0 ]
178+ objects = connected_components (input_grid )
179+
180+ # Propose operations for each object
181+ all_proposals = []
182+
183+ # Object-specific operations
184+ for obj_idx , obj in enumerate (objects ):
185+ context = {
186+ 'obj_idx' : obj_idx ,
187+ 'grid' : input_grid ,
188+ 'all_objects' : objects ,
189+ 'common_colors' : list (range (10 )) # Standard ARC colors
190+ }
191+ proposals = propose_ops_for_object (obj , context )
192+ all_proposals .extend (proposals )
193+
194+ # Global operations
195+ global_proposals = propose_global_ops (input_grid , objects )
196+ all_proposals .extend (global_proposals )
197+
198+ # Limit total proposals to prevent explosion
199+ all_proposals = all_proposals [:100 ]
200+
201+ for op in all_proposals :
202+ new_program = state .program + [op ]
203+
204+ # Quick constraint check
205+ test_result = execute_program_on_grid (new_program , input_grid )
206+ if test_result is None :
207+ continue
208+
209+ if not blackboard .check_constraints (new_program , input_grid , expected_grid ):
210+ continue
211+
212+ # Score the new program
213+ score , successes = blackboard .score_program (new_program , train_pairs )
214+
215+ # Compute MDL cost
216+ ops_dict = [{'kind' : op .kind , 'params' : op .params } for op in new_program ]
217+ mdl_cost = program_length_from_ops (ops_dict )
218+
219+ new_state = SearchState (new_program , score , successes , mdl_cost )
220+ successors .append (new_state )
221+
222+ return successors
223+
224+
225+ def solve_task_agentic (task : Dict [str , Any ], beam_width : int = 64 ,
226+ max_depth : int = 4 ) -> List [Array ]:
227+ """
228+ Solve an ARC task using the agentic approach.
229+
230+ Args:
231+ task: ARC task dictionary with 'train' and 'test' keys
232+ beam_width: Width of the beam search
233+ max_depth: Maximum depth of the search
234+
235+ Returns:
236+ List of predicted grids for test cases
237+ """
238+ # Extract training pairs
239+ train_pairs = []
240+ for pair in task .get ("train" , []):
241+ try :
242+ input_grid = to_array (pair ["input" ])
243+ output_grid = to_array (pair ["output" ])
244+ train_pairs .append ((input_grid , output_grid ))
245+ except Exception :
246+ continue
247+
248+ if not train_pairs :
249+ # No valid training data, return identity for test cases
250+ return [to_array (test_case ["input" ]) for test_case in task .get ("test" , [])]
251+
252+ # Find best program using beam search
253+ best_program , best_score = beam_search_agentic (
254+ train_pairs , beam_width = beam_width , max_depth = max_depth
255+ )
256+
257+ # Apply program to test cases
258+ predictions = []
259+ for test_case in task .get ("test" , []):
260+ try :
261+ test_input = to_array (test_case ["input" ])
262+ prediction = execute_program_on_grid (best_program , test_input )
263+
264+ if prediction is not None :
265+ predictions .append (prediction )
266+ else :
267+ # Fallback to identity
268+ predictions .append (test_input )
269+ except Exception :
270+ # Error in processing, use identity
271+ predictions .append (to_array ([[0 ]]))
272+
273+ return predictions
274+
275+
276+ def solve_task_agentic_dict (task : Dict [str , Any ], beam_width : int = 64 ,
277+ max_depth : int = 4 ) -> Dict [str , List [List [List [int ]]]]:
278+ """
279+ Solve an ARC task and return in the standard dictionary format.
280+ This matches the interface expected by the solver registry.
281+ """
282+ predictions = solve_task_agentic (task , beam_width , max_depth )
283+
284+ # Convert to list format
285+ pred_lists = []
286+ for pred in predictions :
287+ if isinstance (pred , np .ndarray ):
288+ pred_lists .append (pred .astype (int ).tolist ())
289+ else :
290+ pred_lists .append ([[0 ]]) # Fallback
291+
292+ return {
293+ "attempt_1" : pred_lists ,
294+ "attempt_2" : pred_lists # For now, return same predictions
295+ }
0 commit comments