1111import numpy as np
1212import os
1313import logging
14+ import json
15+ from pathlib import Path
1416
1517from .grid import to_array , to_list , Array
1618from .search import (
@@ -34,25 +36,40 @@ class ARCSolver:
3436
3537 def __init__ (self , use_enhancements : bool = True ,
3638 guidance_model_path : str = None ,
37- episode_db_path : str = "episodes.json" ):
39+ episode_db_path : str = "episodes.json" ,
40+ enable_logging : bool = None ,
41+ checkpoint_path : str = None ):
3842 self .use_enhancements = use_enhancements
3943 self .guidance_model_path = guidance_model_path
4044 self .episode_db_path = episode_db_path
45+ self .checkpoint_path = checkpoint_path or "checkpoint.json"
46+
47+ # Logging control - check environment variable or parameter
48+ if enable_logging is None :
49+ enable_logging = os .environ .get ('ARC_ENABLE_LOGGING' , 'true' ).lower () in ('1' , 'true' , 'yes' )
50+ self .enable_logging = enable_logging
51+
4152 self .stats = {
4253 'tasks_solved' : 0 ,
4354 'total_tasks' : 0 ,
4455 'enhancement_success_rate' : 0.0 ,
4556 'fallback_used' : 0 ,
4657 }
58+ self .submission_results = {} # For checkpoint saving
4759
48- # Structured logger for observability
60+ # Structured logger for observability - controlled by enable_logging
4961 self .logger = logging .getLogger (self .__class__ .__name__ )
5062 if not self .logger .handlers :
5163 handler = logging .StreamHandler ()
5264 formatter = logging .Formatter ('%(asctime)s %(name)s %(levelname)s: %(message)s' )
5365 handler .setFormatter (formatter )
5466 self .logger .addHandler (handler )
55- self .logger .setLevel (logging .INFO )
67+
68+ # Set logging level based on enable_logging flag
69+ if self .enable_logging :
70+ self .logger .setLevel (logging .INFO )
71+ else :
72+ self .logger .setLevel (logging .CRITICAL ) # Only show critical errors
5673 self ._last_outputs : Optional [Tuple [List [List [List [int ]]], List [List [List [int ]]]]] = None
5774 # Continuous memory and hypotheses
5875 self .self_memory = ContinuousSelfMemory ()
@@ -104,7 +121,8 @@ def solve_task(self, task: Dict[str, List[Dict[str, List[List[int]]]]]) -> Dict[
104121 else :
105122 # Inconsistent outputs - let enhanced search detect from test input
106123 expected_shape = None
107- self .logger .info (f"Inconsistent output shapes detected: { output_shapes } , enabling dynamic detection" )
124+ if self .enable_logging :
125+ self .logger .info (f"Inconsistent output shapes detected: { output_shapes } , enabling dynamic detection" )
108126
109127 # Generate and store hypotheses about the transformation.
110128 self ._last_hypotheses = self .hypothesis_engine .generate_hypotheses (train_pairs )
@@ -168,7 +186,8 @@ def _get_predictions(
168186 enhanced : List [List [Array ]] = []
169187 if self .use_enhancements :
170188 try :
171- self .logger .info ("Using enhanced search for prediction" )
189+ if self .enable_logging :
190+ self .logger .info ("Using enhanced search for prediction" )
172191 progs = synthesize_with_enhancements (train_pairs , expected_shape = expected_shape , test_input = test_input )
173192
174193 # Import human reasoner for enhanced prediction
@@ -179,19 +198,22 @@ def _get_predictions(
179198 human_reasoner = human_reasoner ,
180199 train_pairs = train_pairs )
181200 except Exception as e :
182- self .logger .exception ("Enhanced prediction error: %s" , e )
201+ if self .enable_logging :
202+ self .logger .exception ("Enhanced prediction error: %s" , e )
183203
184204 # Baseline predictions for ensemble
185205 progs_base = synth_baseline (train_pairs , expected_shape = expected_shape )
186206 baseline = predict_two_baseline (progs_base , [test_input ])
187207
188208 # Validate enhanced prediction
189209 if enhanced and self ._validate_solution (enhanced , [test_input ]):
190- self .logger .info (f"Enhanced prediction valid - shape: { enhanced [0 ][0 ].shape } " )
210+ if self .enable_logging :
211+ self .logger .info (f"Enhanced prediction valid - shape: { enhanced [0 ][0 ].shape } " )
191212 return [enhanced [0 ], baseline [0 ]]
192213
193214 self .stats ['fallback_used' ] += 1
194- self .logger .info ("Using baseline prediction" )
215+ if self .enable_logging :
216+ self .logger .info ("Using baseline prediction" )
195217 return baseline
196218
197219 def _postprocess_predictions (
@@ -317,11 +339,12 @@ def add_template(template: PlaceholderTemplate) -> None:
317339
318340 if templates :
319341 episodic_count = max (0 , len (templates ) - len (new_templates ))
320- self .logger .info (
321- "Loaded %d placeholder template(s) (%d from episodic memory)" ,
322- len (templates ),
323- episodic_count ,
324- )
342+ if self .enable_logging :
343+ self .logger .info (
344+ "Loaded %d placeholder template(s) (%d from episodic memory)" ,
345+ len (templates ),
346+ episodic_count ,
347+ )
325348
326349 def _persist_placeholder_templates (
327350 self , train_pairs : List [Tuple [Array , Array ]]
@@ -356,12 +379,14 @@ def _persist_placeholder_templates(
356379 metadata = {"placeholder_templates" : payloads },
357380 )
358381 self .episodic_retrieval .save ()
359- self .logger .info (
360- "Persisted %d placeholder template(s) to episodic memory" ,
361- len (payloads ),
362- )
382+ if self .enable_logging :
383+ self .logger .info (
384+ "Persisted %d placeholder template(s) to episodic memory" ,
385+ len (payloads ),
386+ )
363387 except Exception as exc :
364- self .logger .debug ("Failed to persist placeholder templates: %s" , exc )
388+ if self .enable_logging :
389+ self .logger .debug ("Failed to persist placeholder templates: %s" , exc )
365390
366391 def _compute_training_stats (
367392 self , train_pairs : List [Tuple [Array , Array ]]
@@ -860,7 +885,8 @@ def _record_continuous_experience(
860885 try :
861886 self .self_memory .record_experience (task_id , train_pairs , transformation , solved , meta )
862887 except Exception as exc :
863- self .logger .debug ("Continuous memory record failed: %s" , exc )
888+ if self .enable_logging :
889+ self .logger .debug ("Continuous memory record failed: %s" , exc )
864890
865891 def _validate_solution (self , attempts : List [List [Array ]], test_inputs : List [Array ]) -> bool :
866892 """Basic validation to check if solution seems reasonable."""
@@ -893,6 +919,69 @@ def get_statistics(self) -> Dict[str, float]:
893919 def get_persona_summary (self ) -> Dict [str , Any ]:
894920 """Expose the continuous self model summary."""
895921 return self .self_memory .persona_summary ()
922+
923+ def save_checkpoint (self , task_id : str = None , force : bool = False ) -> None :
924+ """Save current progress to checkpoint file."""
925+ if not self .submission_results and not force :
926+ return
927+
928+ try :
929+ checkpoint_data = {
930+ 'submission_results' : self .submission_results ,
931+ 'stats' : self .stats ,
932+ 'completed_tasks' : list (self .submission_results .keys ()),
933+ 'last_task' : task_id ,
934+ 'timestamp' : str (Path (__file__ ).stat ().st_mtime )
935+ }
936+
937+ with open (self .checkpoint_path , 'w' ) as f :
938+ json .dump (checkpoint_data , f , indent = 2 )
939+
940+ if self .enable_logging :
941+ self .logger .info (f"Checkpoint saved: { len (self .submission_results )} tasks completed" )
942+ except Exception as exc :
943+ if self .enable_logging :
944+ self .logger .error (f"Failed to save checkpoint: { exc } " )
945+
946+ def load_checkpoint (self ) -> Dict [str , Any ]:
947+ """Load progress from checkpoint file."""
948+ try :
949+ if not Path (self .checkpoint_path ).exists ():
950+ return {}
951+
952+ with open (self .checkpoint_path , 'r' ) as f :
953+ checkpoint_data = json .load (f )
954+
955+ self .submission_results = checkpoint_data .get ('submission_results' , {})
956+
957+ # Update stats if they exist
958+ saved_stats = checkpoint_data .get ('stats' , {})
959+ for key , value in saved_stats .items ():
960+ if key in self .stats :
961+ self .stats [key ] = value
962+
963+ if self .enable_logging :
964+ completed = len (self .submission_results )
965+ last_task = checkpoint_data .get ('last_task' , 'unknown' )
966+ self .logger .info (f"Checkpoint loaded: { completed } tasks completed, last: { last_task } " )
967+
968+ return checkpoint_data
969+ except Exception as exc :
970+ if self .enable_logging :
971+ self .logger .error (f"Failed to load checkpoint: { exc } " )
972+ return {}
973+
974+ def add_submission_result (self , task_id : str , result : Dict [str , List [List [List [int ]]]]) -> None :
975+ """Add a task result to submission tracking."""
976+ self .submission_results [task_id ] = result
977+
978+ # Save checkpoint every 10 tasks to prevent memory buildup
979+ if len (self .submission_results ) % 10 == 0 :
980+ self .save_checkpoint (task_id )
981+
982+ def get_submission_results (self ) -> Dict [str , Dict [str , List [List [List [int ]]]]]:
983+ """Get all submission results for final export."""
984+ return self .submission_results .copy ()
896985
897986
898987# Global solver instance (for backwards compatibility)
0 commit comments