1919from .neural .episodic import EpisodicRetrieval
2020from .neural .sketches import SketchMiner , generate_parameter_grid
2121from .ttt import TestTimeTrainer , DataAugmentation
22+ from .beam_search import beam_search
23+ from .mcts_search import mcts_search
2224
2325
2426class EnhancedSearch :
2527 """Enhanced program synthesis search with neural guidance and episodic retrieval."""
2628
27- def __init__ (self , guidance_model_path : Optional [str ] = None ,
28- episode_db_path : str = "episodes.json" ):
29+ def __init__ (self , guidance_model_path : Optional [str ] = None ,
30+ episode_db_path : str = "episodes.json" ,
31+ enable_beam_search : bool = True ):
2932 self .neural_guidance = NeuralGuidance (guidance_model_path )
3033 self .episodic_retrieval = EpisodicRetrieval (episode_db_path )
3134 self .sketch_miner = SketchMiner ()
3235 self .test_time_trainer = TestTimeTrainer ()
3336 self .search_stats = {}
37+ self .enable_beam_search = enable_beam_search
3438
3539 # Load any existing sketches
3640 try :
@@ -44,6 +48,9 @@ def synthesize_enhanced(self, train_pairs: List[Tuple[Array, Array]],
4448 self .search_stats = {
4549 'episodic_candidates' : 0 ,
4650 'heuristic_candidates' : 0 ,
51+ 'beam_candidates' : 0 ,
52+ 'beam_nodes_expanded' : 0 ,
53+ 'mcts_candidates' : 0 ,
4754 'sketch_candidates' : 0 ,
4855 'neural_guided_candidates' : 0 ,
4956 'ttt_adapted' : False ,
@@ -61,19 +68,32 @@ def synthesize_enhanced(self, train_pairs: List[Tuple[Array, Array]],
6168 all_candidates .extend (heuristic_candidates )
6269 self .search_stats ['heuristic_candidates' ] = len (heuristic_candidates )
6370
64- # Step 3: Neural-guided search if we need more candidates
71+ # Step 3: Beam search for deeper exploration
72+ if self .enable_beam_search and len (all_candidates ) < max_programs :
73+ beam_programs , stats = beam_search (train_pairs , beam_width = 16 , depth = 3 )
74+ all_candidates .extend (beam_programs )
75+ self .search_stats ['beam_candidates' ] = len (beam_programs )
76+ self .search_stats ['beam_nodes_expanded' ] = stats ['nodes_expanded' ]
77+
78+ # Step 4: Monte Carlo Tree Search if still limited
79+ if self .enable_beam_search and len (all_candidates ) < max_programs // 2 :
80+ mcts_programs = mcts_search (train_pairs , iterations = 200 , max_depth = 2 , seed = 0 )
81+ all_candidates .extend (mcts_programs )
82+ self .search_stats ['mcts_candidates' ] = len (mcts_programs )
83+
84+ # Step 5: Neural-guided search if we need more candidates
6585 if len (all_candidates ) < max_programs // 4 :
6686 neural_candidates = self ._neural_guided_search (train_pairs , max_programs // 2 )
6787 all_candidates .extend (neural_candidates )
6888 self .search_stats ['neural_guided_candidates' ] = len (neural_candidates )
69-
70- # Step 4 : Sketch-based search if still need more
89+
90+ # Step 6 : Sketch-based search if still need more
7191 if len (all_candidates ) < max_programs // 2 :
7292 sketch_candidates = self ._sketch_based_search (train_pairs , max_programs // 3 )
7393 all_candidates .extend (sketch_candidates )
7494 self .search_stats ['sketch_candidates' ] = len (sketch_candidates )
75-
76- # Step 5 : Test-time adaptation if we have candidates
95+
96+ # Step 7 : Test-time adaptation if we have candidates
7797 if all_candidates :
7898 all_candidates = self ._apply_test_time_adaptation (train_pairs , all_candidates )
7999 self .search_stats ['ttt_adapted' ] = True
0 commit comments