Stream probe training to eliminate hidden-state corpus retention #57
Merged
thomas-schweich merged 1 commit into main on Apr 12, 2026
Conversation
The previous probe pipeline extracted all-layer hidden states for every train/val position up front (fp32 CPU tensors, one per layer) before running any probe training. On a 512-ctx base model with a 4096-game train set this came to ~33 GB of CPU RAM, OOMing 27 GB local systems well before training began.

Rewrite `train_all_probes` as a streaming trainer: for each epoch, iterate over game batches, run one forward pass, train every probe on the batch's valid positions, then discard the activations. Validation follows the same pattern and accumulates per-layer loss / accuracy / R² / MAE in a single sweep. Peak RSS for the same 1024/256-game run drops from 26+ GB (OOM) to ~8 GB.

Also compact the probe-data cache: boards are now int8 (was int64), the side / is-check flags bool, ep_square int8, halfmove_clock uint8, legal_move_counts uint16. Promotion happens at use sites in `get_probe_targets`.

`scripts/eval_probes.py`: bound the probe-data cache to exactly one (max_ply, prepend_outcome) entry, and load checkpoints on CPU before moving the final model to the target device to avoid a transient 2x VRAM peak during load.

The legacy helpers (`_extract_hidden_states`, `_extract_all_hidden_states`, `_train_probe_all_layers`, `train_single_probe`) remain intact for tests; they are no longer called by the main pipeline.
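A minimal sketch of the streaming pattern, with a toy NumPy stand-in for the model's forward pass. All names here (`forward_all_layers`, `train_probes_streaming`) and the MSE probe loss are illustrative, not the pipeline's actual API — the point is that peak memory is one batch of hidden states rather than the whole corpus:

```python
import numpy as np

def forward_all_layers(x, weights):
    """Toy stand-in for one forward pass returning per-layer hidden states."""
    hiddens, h = [], x
    for w in weights:
        h = np.tanh(h @ w)
        hiddens.append(h)
    return hiddens

def train_probes_streaming(weights, probes, batches, lr=0.1):
    """Streaming pattern: per batch, one forward pass feeds every layer's
    linear probe; activations are discarded before the next batch."""
    for x, y in batches:                       # y: regression target per position
        hiddens = forward_all_layers(x, weights)
        for layer, probe in enumerate(probes):
            feats = hiddens[layer]
            pred = feats @ probe
            probe -= lr * (feats.T @ (pred - y)) / len(y)  # MSE gradient step
        del hiddens                            # free activations for this batch
    return probes
```

Validation would follow the same per-batch loop, accumulating per-layer metrics instead of taking gradient steps.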
Summary
The old probe pipeline extracted all-layer hidden states for every train/val position upfront (fp32 CPU tensors, one per layer) before any probe training began. On a 512-ctx base model with 4096 train games this came to ~33 GB of CPU RAM, OOMing 27 GB local systems before training could start. A 1024/256-game run at max_ply=512 also OOM'd at ~26 GB RSS peak.
Rewrite `train_all_probes` as a streaming trainer:

- For each epoch, iterate over game batches, run one forward pass, train every probe on the batch's valid positions, then discard the activations.
- Validation follows the same pattern, accumulating per-layer loss / accuracy / R² / MAE in a single sweep.
- Peak RSS for the same 1024/256-game run drops from 26+ GB (OOM) to ~8 GB.
Also compact the probe data cache:

- boards: int8 (was int64)
- side / is-check flags: bool
- ep_square: int8
- halfmove_clock: uint8
- legal_move_counts: uint16
Promotion to the dtype each loss needs happens at use sites in `get_probe_targets`.
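A sketch of what the compact cache layout might look like in NumPy (shapes and field names like `side_to_move` / `is_check` are assumptions for illustration; the actual record layout lives in the probe-data code):

```python
import numpy as np

# Hypothetical compact cache record: games x plies (x squares for boards).
n_games, max_ply = 8, 16                                    # placeholder sizes
boards = np.zeros((n_games, max_ply, 64), dtype=np.int8)    # piece codes fit in int8
side_to_move = np.zeros((n_games, max_ply), dtype=bool)     # assumed flag name
is_check = np.zeros((n_games, max_ply), dtype=bool)         # assumed flag name
ep_square = np.full((n_games, max_ply), -1, dtype=np.int8)  # -1 = no e.p. square
halfmove_clock = np.zeros((n_games, max_ply), dtype=np.uint8)
legal_move_counts = np.zeros((n_games, max_ply), dtype=np.uint16)

# int8 boards are 8x smaller than the old int64 layout.
assert np.dtype(np.int64).itemsize // np.dtype(np.int8).itemsize == 8

# Promotion happens only where a loss needs it, e.g. a regression target to float32.
clock_targets = halfmove_clock.astype(np.float32)
```

Keeping the cache in the narrowest dtype that holds the value and promoting lazily at use sites means the widening cost is paid per batch, never for the whole corpus at once.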
`scripts/eval_probes.py`:

- Bound the probe-data cache to exactly one (max_ply, prepend_outcome) entry.
- Load checkpoints on CPU before moving the final model to the target device, avoiding a transient 2x VRAM peak during load.
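The CPU-first load is the standard `map_location="cpu"` pattern; a minimal sketch (the helper name is hypothetical, not the script's API):

```python
import torch

def load_checkpoint_cpu_first(path, model, device):
    # Deserialize onto CPU first so the target device never holds two
    # copies of the weights at once; then move the final model exactly once.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model.to(device)
```

Loading directly onto the GPU would briefly hold both the deserialized state dict and the module's freshly initialized parameters in VRAM, which is the transient 2x peak this avoids.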
Validation
A full per-layer sweep on the 300K / 512-ctx base checkpoint at 4096 train / 1024 val games runs end-to-end on the local 27 GB RAM / 20 GB VRAM machine.
The 4x-smaller run lands within ~0.01 of the full run on every probe, consistent with linear probes on 512-dim features saturating quickly with data. Halfmove clock shows the biggest movement (0.507 → 0.558), since it has the most variance to capture.
Memory stayed at ~6-10 GB RSS throughout. No OOM anywhere.
This addresses the primary recommendation in `local/PROBE_SCRIPT_REVIEW.md` (stream probe training instead of retaining the full all-layer corpus), plus the compact-dtype, bounded-cache, and CPU-first-load suggestions from the same review.
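The bounded cache amounts to a cache that holds exactly one entry and evicts on key change; a plain-Python sketch (class and method names are illustrative, not the script's actual API):

```python
class SingleEntryCache:
    """At most one cached value, keyed on (max_ply, prepend_outcome);
    building under a new key drops the previous entry."""

    def __init__(self):
        self._key = None
        self._value = None

    def get_or_build(self, max_ply, prepend_outcome, build):
        key = (max_ply, prepend_outcome)
        if key != self._key:
            self._value = build()  # previous entry is released here
            self._key = key
        return self._value
```

This keeps repeated evaluations under the same settings fast while guaranteeing the cache never accumulates more than one corpus worth of probe data.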
Test plan