Skip to content

Commit d42ec6f

Browse files
tylerbessire and claude
committed
Add checkpoint saving and logging toggle for memory optimization
- Added checkpoint_path and enable_logging parameters to ARCSolver
- Implemented save_checkpoint/load_checkpoint with automatic saves every 10 tasks
- Added submission_results tracking for progress recovery
- Created logging toggle controlled by ARC_ENABLE_LOGGING environment variable
- All logging statements now respect enable_logging flag to reduce memory usage
- Checkpoint saves task results, stats, and progress metadata
- Supports resuming from checkpoint after memory crashes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent dcba83a commit d42ec6f

File tree

3 files changed

+660
-30
lines changed

3 files changed

+660
-30
lines changed

arc_solver/solver.py

Lines changed: 108 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import os
1313
import logging
14+
import json
15+
from pathlib import Path
1416

1517
from .grid import to_array, to_list, Array
1618
from .search import (
@@ -34,25 +36,40 @@ class ARCSolver:
3436

3537
def __init__(self, use_enhancements: bool = True,
3638
guidance_model_path: str = None,
37-
episode_db_path: str = "episodes.json"):
39+
episode_db_path: str = "episodes.json",
40+
enable_logging: bool = None,
41+
checkpoint_path: str = None):
3842
self.use_enhancements = use_enhancements
3943
self.guidance_model_path = guidance_model_path
4044
self.episode_db_path = episode_db_path
45+
self.checkpoint_path = checkpoint_path or "checkpoint.json"
46+
47+
# Logging control - check environment variable or parameter
48+
if enable_logging is None:
49+
enable_logging = os.environ.get('ARC_ENABLE_LOGGING', 'true').lower() in ('1', 'true', 'yes')
50+
self.enable_logging = enable_logging
51+
4152
self.stats = {
4253
'tasks_solved': 0,
4354
'total_tasks': 0,
4455
'enhancement_success_rate': 0.0,
4556
'fallback_used': 0,
4657
}
58+
self.submission_results = {} # For checkpoint saving
4759

48-
# Structured logger for observability
60+
# Structured logger for observability - controlled by enable_logging
4961
self.logger = logging.getLogger(self.__class__.__name__)
5062
if not self.logger.handlers:
5163
handler = logging.StreamHandler()
5264
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')
5365
handler.setFormatter(formatter)
5466
self.logger.addHandler(handler)
55-
self.logger.setLevel(logging.INFO)
67+
68+
# Set logging level based on enable_logging flag
69+
if self.enable_logging:
70+
self.logger.setLevel(logging.INFO)
71+
else:
72+
self.logger.setLevel(logging.CRITICAL) # Only show critical errors
5673
self._last_outputs: Optional[Tuple[List[List[List[int]]], List[List[List[int]]]]] = None
5774
# Continuous memory and hypotheses
5875
self.self_memory = ContinuousSelfMemory()
@@ -104,7 +121,8 @@ def solve_task(self, task: Dict[str, List[Dict[str, List[List[int]]]]]) -> Dict[
104121
else:
105122
# Inconsistent outputs - let enhanced search detect from test input
106123
expected_shape = None
107-
self.logger.info(f"Inconsistent output shapes detected: {output_shapes}, enabling dynamic detection")
124+
if self.enable_logging:
125+
self.logger.info(f"Inconsistent output shapes detected: {output_shapes}, enabling dynamic detection")
108126

109127
# Generate and store hypotheses about the transformation.
110128
self._last_hypotheses = self.hypothesis_engine.generate_hypotheses(train_pairs)
@@ -168,7 +186,8 @@ def _get_predictions(
168186
enhanced: List[List[Array]] = []
169187
if self.use_enhancements:
170188
try:
171-
self.logger.info("Using enhanced search for prediction")
189+
if self.enable_logging:
190+
self.logger.info("Using enhanced search for prediction")
172191
progs = synthesize_with_enhancements(train_pairs, expected_shape=expected_shape, test_input=test_input)
173192

174193
# Import human reasoner for enhanced prediction
@@ -179,19 +198,22 @@ def _get_predictions(
179198
human_reasoner=human_reasoner,
180199
train_pairs=train_pairs)
181200
except Exception as e:
182-
self.logger.exception("Enhanced prediction error: %s", e)
201+
if self.enable_logging:
202+
self.logger.exception("Enhanced prediction error: %s", e)
183203

184204
# Baseline predictions for ensemble
185205
progs_base = synth_baseline(train_pairs, expected_shape=expected_shape)
186206
baseline = predict_two_baseline(progs_base, [test_input])
187207

188208
# Validate enhanced prediction
189209
if enhanced and self._validate_solution(enhanced, [test_input]):
190-
self.logger.info(f"Enhanced prediction valid - shape: {enhanced[0][0].shape}")
210+
if self.enable_logging:
211+
self.logger.info(f"Enhanced prediction valid - shape: {enhanced[0][0].shape}")
191212
return [enhanced[0], baseline[0]]
192213

193214
self.stats['fallback_used'] += 1
194-
self.logger.info("Using baseline prediction")
215+
if self.enable_logging:
216+
self.logger.info("Using baseline prediction")
195217
return baseline
196218

197219
def _postprocess_predictions(
@@ -317,11 +339,12 @@ def add_template(template: PlaceholderTemplate) -> None:
317339

318340
if templates:
319341
episodic_count = max(0, len(templates) - len(new_templates))
320-
self.logger.info(
321-
"Loaded %d placeholder template(s) (%d from episodic memory)",
322-
len(templates),
323-
episodic_count,
324-
)
342+
if self.enable_logging:
343+
self.logger.info(
344+
"Loaded %d placeholder template(s) (%d from episodic memory)",
345+
len(templates),
346+
episodic_count,
347+
)
325348

326349
def _persist_placeholder_templates(
327350
self, train_pairs: List[Tuple[Array, Array]]
@@ -356,12 +379,14 @@ def _persist_placeholder_templates(
356379
metadata={"placeholder_templates": payloads},
357380
)
358381
self.episodic_retrieval.save()
359-
self.logger.info(
360-
"Persisted %d placeholder template(s) to episodic memory",
361-
len(payloads),
362-
)
382+
if self.enable_logging:
383+
self.logger.info(
384+
"Persisted %d placeholder template(s) to episodic memory",
385+
len(payloads),
386+
)
363387
except Exception as exc:
364-
self.logger.debug("Failed to persist placeholder templates: %s", exc)
388+
if self.enable_logging:
389+
self.logger.debug("Failed to persist placeholder templates: %s", exc)
365390

366391
def _compute_training_stats(
367392
self, train_pairs: List[Tuple[Array, Array]]
@@ -860,7 +885,8 @@ def _record_continuous_experience(
860885
try:
861886
self.self_memory.record_experience(task_id, train_pairs, transformation, solved, meta)
862887
except Exception as exc:
863-
self.logger.debug("Continuous memory record failed: %s", exc)
888+
if self.enable_logging:
889+
self.logger.debug("Continuous memory record failed: %s", exc)
864890

865891
def _validate_solution(self, attempts: List[List[Array]], test_inputs: List[Array]) -> bool:
866892
"""Basic validation to check if solution seems reasonable."""
@@ -893,6 +919,69 @@ def get_statistics(self) -> Dict[str, float]:
893919
def get_persona_summary(self) -> Dict[str, Any]:
894920
"""Expose the continuous self model summary."""
895921
return self.self_memory.persona_summary()
922+
923+
def save_checkpoint(self, task_id: str = None, force: bool = False) -> None:
    """Persist current solving progress to ``self.checkpoint_path``.

    Parameters
    ----------
    task_id : str, optional
        Identifier of the most recently completed task; stored as
        ``last_task`` so a resumed run knows where it left off.
    force : bool
        When True, write a checkpoint even when no results have been
        recorded yet.

    Failures are logged (when logging is enabled) but never raised, so a
    checkpoint problem cannot abort a solving run.
    """
    import time  # local import: module top-level does not import time

    if not self.submission_results and not force:
        return

    try:
        checkpoint_data = {
            'submission_results': self.submission_results,
            'stats': self.stats,
            'completed_tasks': list(self.submission_results.keys()),
            'last_task': task_id,
            # BUG FIX: the original stamped the checkpoint with the
            # *source file's* mtime (constant between runs and unrelated
            # to when the checkpoint was written). Record wall-clock time.
            'timestamp': str(time.time()),
        }

        # Write to a temp file then atomically rename, so a crash
        # mid-write cannot leave a truncated/corrupt checkpoint behind.
        tmp_path = f"{self.checkpoint_path}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)
        os.replace(tmp_path, self.checkpoint_path)

        if self.enable_logging:
            self.logger.info(f"Checkpoint saved: {len(self.submission_results)} tasks completed")
    except Exception as exc:
        # Best-effort persistence: never let checkpointing kill the run.
        if self.enable_logging:
            self.logger.error(f"Failed to save checkpoint: {exc}")
945+
946+
def load_checkpoint(self) -> Dict[str, Any]:
    """Restore solving progress from the checkpoint file, if one exists.

    Re-populates ``self.submission_results`` and any matching entries in
    ``self.stats``.  Returns the full checkpoint dictionary, or an empty
    dict when the file is absent or unreadable.
    """
    try:
        if not Path(self.checkpoint_path).exists():
            return {}

        with open(self.checkpoint_path, 'r') as fh:
            data = json.load(fh)

        self.submission_results = data.get('submission_results', {})

        # Only restore stat keys the solver already tracks; ignore any
        # stale/unknown keys a checkpoint might carry.
        for stat_name, stat_value in data.get('stats', {}).items():
            if stat_name in self.stats:
                self.stats[stat_name] = stat_value

        if self.enable_logging:
            completed = len(self.submission_results)
            last_task = data.get('last_task', 'unknown')
            self.logger.info(f"Checkpoint loaded: {completed} tasks completed, last: {last_task}")

        return data
    except Exception as exc:
        # A damaged checkpoint should degrade to a fresh start, not crash.
        if self.enable_logging:
            self.logger.error(f"Failed to load checkpoint: {exc}")
        return {}
973+
974+
def add_submission_result(self, task_id: str, result: Dict[str, List[List[List[int]]]]) -> None:
    """Record *result* under *task_id* for later submission export.

    A checkpoint is written after every 10th recorded task so that
    progress survives a crash without checkpointing on every task.
    """
    self.submission_results[task_id] = result
    completed = len(self.submission_results)
    # Periodic checkpoint: fires exactly when the count hits a multiple of 10.
    if completed % 10 == 0:
        self.save_checkpoint(task_id)
981+
982+
def get_submission_results(self) -> Dict[str, Dict[str, List[List[List[int]]]]]:
    """Return a shallow copy of all recorded task results for final export.

    Copying prevents callers from mutating the solver's internal tracking
    dict by accident.
    """
    return dict(self.submission_results)
896985

897986

898987
# Global solver instance (for backwards compatibility)

checkpoint.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"submission_results": {
3+
"test_task_2": {
4+
"attempt_1": [
5+
[
6+
[
7+
5,
8+
6
9+
]
10+
]
11+
],
12+
"attempt_2": [
13+
[
14+
[
15+
7,
16+
8
17+
]
18+
]
19+
]
20+
}
21+
},
22+
"stats": {
23+
"tasks_solved": 0,
24+
"total_tasks": 0,
25+
"enhancement_success_rate": 0.0,
26+
"fallback_used": 0
27+
},
28+
"completed_tasks": [
29+
"test_task_2"
30+
],
31+
"last_task": "test_task_2",
32+
"timestamp": "1758424315.4060667"
33+
}

0 commit comments

Comments
 (0)