
Stream probe training to eliminate hidden-state corpus retention#57

Merged
thomas-schweich merged 1 commit into main from streaming-probe-training on Apr 12, 2026

Conversation

@thomas-schweich
Owner

Summary

The old probe pipeline extracted all-layer hidden states for every train/val position upfront (fp32 CPU tensors, one per layer) before any probe training began. On a 512-ctx base model with 4096 train games this came to ~33 GB of CPU RAM, OOMing 27 GB local systems before training could start. A 1024/256-game run at max_ply=512 also OOM'd at ~26 GB RSS peak.

Rewrite `train_all_probes` as a streaming trainer:

  • Per-batch forward + probe SGD + discard. No hidden-state corpus is ever materialized.
  • Validation runs as a single streaming sweep that accumulates loss / accuracy / R² / MAE per layer across the full val set.
  • Peak RSS for the same workload drops from OOM (26+ GB) to ~8 GB.
  • Legacy helpers (`_extract_hidden_states`, `_extract_all_hidden_states`, `_train_probe_all_layers`, `train_single_probe`) remain in place for existing tests; the main pipeline no longer calls them.
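The per-batch forward/train/discard pattern can be sketched as follows. This is a minimal illustration with made-up names (`fake_hidden_states` stands in for the real frozen-model forward pass; the probe/optimizer layout is illustrative, not the repo's actual API):

```python
import torch

torch.manual_seed(0)
n_layers, d_model, n_classes = 3, 16, 4

def fake_hidden_states(tokens):
    # Stand-in for the frozen base model's forward pass
    # (one hidden-state tensor per layer).
    return [torch.randn(tokens.shape[0], d_model) for _ in range(n_layers)]

# One linear probe (and optimizer) per layer.
probes = [torch.nn.Linear(d_model, n_classes) for _ in range(n_layers)]
opts = [torch.optim.SGD(p.parameters(), lr=0.1) for p in probes]

for step in range(10):                       # iterate over game batches
    tokens = torch.randint(0, 100, (32, 8))
    targets = torch.randint(0, n_classes, (32,))
    with torch.no_grad():                    # base model is frozen
        hiddens = fake_hidden_states(tokens)
    for layer, h in enumerate(hiddens):
        logits = probes[layer](h)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        opts[layer].zero_grad()
        loss.backward()
        opts[layer].step()
    del hiddens  # activations discarded; no corpus is ever materialized
```

Only one batch of activations is alive at a time, which is why peak RSS becomes independent of the number of games.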

Also compact the probe data cache:

  • `boards`: int64 → int8 (8× saving)
  • `side_to_move`, `is_check`: fp32 → bool
  • `ep_square`: int64 → int8
  • `halfmove_clock`: fp32 → uint8
  • `legal_move_counts`: fp32 → uint16

Promotion to the dtype each loss needs happens at use sites in `get_probe_targets`.

`scripts/eval_probes.py`:

  • Probe-data cache bounded to a single `(max_ply, prepend_outcome)` entry with a `gc.collect()` before regeneration.
  • CPU-first checkpoint load: weights + model instantiated on CPU, `load_state_dict`, then `.to(device)` — removes the transient 2x VRAM peak during load.
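The CPU-first load pattern amounts to a few lines; `model_cls` and the checkpoint layout below are placeholders, but the ordering is the point — weights never exist in two device copies at once:

```python
import torch

def load_checkpoint_cpu_first(path, model_cls, device):
    # Deserialize weights onto CPU, never directly onto the GPU.
    state = torch.load(path, map_location="cpu")
    model = model_cls()            # instantiate on CPU as well
    model.load_state_dict(state)
    return model.to(device)        # a single copy moves to the device
```

Loading straight onto the GPU would briefly hold both the freshly initialized model and the deserialized state dict in VRAM, which is the transient 2x peak this avoids.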

Validation

A full per-layer sweep on the 300K / 512-ctx base checkpoint at 4096 train / 1024 val games runs end-to-end on the local machine (27 GB RAM / 20 GB VRAM):

| Probe | Best @ layer | 4096/1024 | 1024/256 |
| --- | --- | --- | --- |
| piece_type | L5 | 0.898 | 0.895 |
| side_to_move | L2 | 1.000 | 1.000 |
| is_check | L2 | 0.948 | 0.948 |
| castling_rights | L2 | 0.990 | 0.989 |
| ep_square | L3 | 0.999 | 0.999 |
| material_count | L4 | R² 0.824 | R² 0.818 |
| legal_move_count | L5 | R² 0.661 | R² 0.656 |
| halfmove_clock | L4 | R² 0.558 | R² 0.507 |
| game_phase | L3 | 0.953 | 0.953 |

The 4x-smaller run lands within ~0.01 of the full run on every probe, which is consistent with linear probes on 512-dim features saturating on data quickly. Halfmove clock shows the biggest movement (0.507 → 0.558) since it has the most variance to capture.
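The per-layer R² values above can be accumulated in a single streaming sweep without ever storing predictions, by keeping running sums. This uses the standard identity SST = Σy² − (Σy)²/n; it illustrates the technique, not necessarily the repo's exact implementation:

```python
import numpy as np

class StreamingR2:
    """Accumulate R-squared over batches without retaining predictions."""

    def __init__(self):
        self.n = self.sum_y = self.sum_y2 = self.sse = 0.0

    def update(self, y_true, y_pred):
        self.n += len(y_true)
        self.sum_y += float(y_true.sum())
        self.sum_y2 += float((y_true ** 2).sum())
        self.sse += float(((y_true - y_pred) ** 2).sum())

    def compute(self):
        # Total sum of squares from running sums: SST = sum(y^2) - (sum y)^2 / n
        sst = self.sum_y2 - self.sum_y ** 2 / self.n
        return 1.0 - self.sse / sst
```

Loss, accuracy, and MAE accumulate the same way (running sums and counts), which is what lets validation run as one pass per epoch.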

Memory stayed at ~6-10 GB RSS throughout. No OOM anywhere.

This addresses the primary recommendation in `local/PROBE_SCRIPT_REVIEW.md` (stream probe training instead of retaining the full all-layer corpus), plus the compact-dtype, bounded-cache, and CPU-first-load suggestions from the same review.

Test plan

  • `uv run pytest tests/` — 1350 passed
  • `uv run --with pyright pyright pawn/eval_suite/probes.py scripts/eval_probes.py` — 0 errors
  • Streaming run at 1024/256, max_ply=512, per-layer sweep, 20 epochs — completes, ~8 GB peak RSS
  • Streaming run at 4096/1024, max_ply=512, per-layer sweep, 20 epochs — completes, ~6-10 GB peak RSS, metrics within ~0.01 of the small run

The previous probe pipeline extracted all-layer hidden states for every
train/val position up front (fp32 CPU tensors, one per layer) before
running any probe training. On a 512-ctx base model with a 4096-game
train set this came to ~33 GB of CPU RAM, OOMing 27 GB local systems
well before training began.

Rewrite train_all_probes as a streaming trainer: for each epoch, iterate
over game batches, run one forward pass, train every probe on the
batch's valid positions, then discard activations. Validation follows
the same pattern and accumulates per-layer loss / accuracy / R² / MAE in
a single sweep. Peak RSS for the same 1024/256-game run drops from 26+
GB (OOM) to ~8 GB.

Also compact the probe-data cache: boards now int8 (was int64), side/is-
check bool, ep_square int8, halfmove_clock uint8, legal_move_counts
uint16. Promotion happens at use sites in get_probe_targets.

scripts/eval_probes.py: bound the probe-data cache to exactly one
(max_ply, prepend_outcome) entry, and load checkpoints on CPU before
moving the final model to the target device to avoid a transient 2x
VRAM peak during load.

The legacy helpers (_extract_hidden_states, _extract_all_hidden_states,
_train_probe_all_layers, train_single_probe) are intact for tests; they
are no longer called by the main pipeline.
@thomas-schweich thomas-schweich merged commit 9bf71c9 into main Apr 12, 2026
1 check passed
@thomas-schweich thomas-schweich deleted the streaming-probe-training branch April 12, 2026 01:19