Commit 061b508

Adopt searchless_chess 1968-action vocabulary (#52)
* Adopt searchless_chess 1968-action vocabulary as primary token representation

  Replace PAWN's dense 4284-token vocabulary (64×64 grid + promotions + outcomes) with DeepMind's searchless_chess 1968-action vocabulary. Actions 0-1967 are 1:1 with searchless_chess — no offset, no remapping.

  New token layout:
  - 0-1967: searchless_chess actions (reachable moves only)
  - 1968: PAD
  - 1969-1979: outcome tokens (11 total)

  Key changes:
  - Generate engine/src/searchless_vocab.rs from canonical JSON via codegen script
  - Rewrite vocab.rs: lookup-table tokenization replaces formula-based approach
  - board.rs: move_to_token builds UCI string and looks up action index
  - batch.rs: PAD initialization uses 1968 instead of 0
  - Add pawn_to_searchless/searchless_to_pawn conversion functions
  - Update all test assertions for new token ranges

  All 304 Rust tests pass.

* Add prepend_outcome flag and fix PAD initialization in Rust engine

  - generate_clm_batch() gains prepend_outcome parameter (default false)
    - false: pure moves [m1, m2, ..., mN, PAD], max_ply = seq_len
    - true: [outcome, m1, ..., mN, PAD], max_ply = seq_len - 1
  - Fix PAD initialization: input_ids/targets use PAD_TOKEN (1968) instead of 0 (which is now action "a1b1")
  - Fix sparse legal mask PAD insertion to use vocab::PAD_TOKEN
  - Expose pawn_to_searchless/searchless_to_pawn to Python
  - Update all test assertions for new token ranges

  304 Rust tests pass. Python smoke test verified.

* Update Python config and model for searchless_chess vocabulary

  config.py:
  - New defaults: PAD_TOKEN=1968, OUTCOME_TOKEN_BASE=1969, vocab_size=1980, max_seq_len=512
  - Add LegacyVocab class for loading old checkpoints (vocab_size=4284)
  - TrainingConfig.max_ply defaults to 512 (matches seq_len)

  model.py:
  - CLMEmbedding derives pad_token, outcome_base, n_actions from cfg.vocab_size — works with both new (1980) and legacy (4284) vocab
  - _build_decomposition_table(n_actions) is parameterized: builds from engine vocab for new, from the old formula for legacy
  - Factored embeddings produce identical (src, dst, promo) decomposition for the same UCI move regardless of vocab

  Both PAWNCLM(CLMConfig()) and PAWNCLM(CLMConfig(vocab_size=4284)) produce correct output shapes and decomposition tables.

* Update Python data pipeline for searchless_chess vocabulary

  data.py:
  - pack_clm_sequences uses torch.full(PAD_TOKEN) instead of torch.zeros
  - CLMDataset passes prepend_outcome flag to engine (default false)
  - create_validation_set gains prepend_outcome parameter

  lichess_data.py:
  - LegalMaskBuilder default vocab_size: 4278 → 1968
  - LegalMaskCollate default vocab_size: 4278 → 1968
  - compute_legal_indices default vocab_size: 4278 → 1968

  generation.py:
  - Use model.cfg.vocab_size instead of the class default for legal mask size

* Update all tests and remaining Python code for searchless_chess vocabulary

  Tests (14 files, 1328 pass):
  - test_config.py: new constants (PAD=1968, outcomes at 1969+, vocab_size=1980)
  - test_model.py: decomp table shape [1968,3], token ranges [0,1967]
  - test_data.py: PAD=1968 in pack_clm_sequences assertions
  - test_clm_format.py: engine calls use prepend_outcome=True for outcome-format tests, new token ranges for all assertions
  - test_512_token.py: updated for new vocab + no-outcome default
  - test_rosa.py, test_specialized_clm.py: token range fixes
  - test_lichess_data.py: remove hardcoded vocab_size=4278
  - test_trainer.py: updated for new defaults

  Python source:
  - probes.py: no_outcome_token defaults to True (no stripping needed)
  - specialized_clm.py: handle PAD_TOKEN > vocab_size gracefully
  - generation.py: get_legal_token_masks_batch default vocab_size=1980
  - trainer.py: updated by background agent for new vocab defaults

* Add --legacy-vocab flag and resume validation to training scripts

  train.py:
  - PretrainConfig.max_seq_len defaults to 512 (was 256)
  - --legacy-vocab flag sets vocab_size=4284, max_seq_len=256
  - Resume path validates that checkpoint vocab_size matches the model

  train_all.py:
  - --legacy-vocab flag applies to all variant configs
  - max_ply derived from model_cfg.max_seq_len (512 or 256)

* Add backward compatibility tests and update CLAUDE.md

  tests/model/test_backward_compat.py (10 tests):
  - Old/new model instantiation and forward pass
  - Factored embedding equivalence across vocabs
  - pawn_to_searchless/searchless_to_pawn roundtrips
  - Impossible move conversion returns -1
  - CLM batch format verification for both vocab modes

  CLAUDE.md:
  - Token vocabulary: 1968 actions + 1 PAD + 11 outcomes = 1980
  - Sequence format: pure moves (512 tokens), outcome prefix optional
  - max_seq_len: 256 → 512
  - Document prepend_outcome flag and conversion functions

* Remove VOCAB_TRANSITION.md planning document

* Fix review issues: legacy decomp off-by-one, data corruption, legality metric

  Bug fixes:
  - Legacy decomp table: allocate n_actions+1 rows so token 4272 (last promo) is reachable. Fix clamp in forward() to use decomp_table.shape[0]-1. Add test for token 4272 vs 4271 aliasing.
  - strip_outcome_token: guard with `prepend_outcome` check to prevent silently dropping the first move when --no-outcome-token is used with the new default (prepend_outcome=False).
  - Legality metric: parameterize compute_legal_move_rate_from_preds by n_actions so legacy-vocab predictions use the correct grid index mapping. Cache grid index tensors per (n_actions, device).
  - --legacy-vocab: set config.max_seq_len=256 before the override block so it doesn't get clobbered to 512.
  - SpecializedCLM: derive padding_idx from vocab_size (0 for legacy, 1968 for new, None for toy).

  Performance:
  - move_to_token: use a stack [u8; 5] buffer instead of a String allocation. Eliminates ~1.3M heap allocs/sec on the hot game-generation path.

  Rust:
  - Add test_clm_batch_seq_len_consistency_no_outcome asserting max_ply == seq_len when prepend_outcome=false.

  305 Rust tests, 1339 Python tests pass.

* Pass n_actions to legality metric in train_all.py and backfill_metrics.py

  - train_all.py: pass model.embed.n_actions to compute_legal_move_rate_from_preds
  - backfill_metrics.py: pass model.embed.n_actions to compute_legal_move_rate
  - compute_legal_move_rate wrapper: forward the n_actions parameter
  - train.py: fix misleading comment on the --legacy-vocab max_seq_len override

* Deprecate no_outcome_token flag, fix stale docstring

  - CLMTrainer: emit DeprecationWarning when no_outcome_token is set, and stop passing it to CLMDataset/create_validation_set (it was a no-op since prepend_outcome defaults to False)
  - pack_clm_sequences docstring: remove the hardcoded "(256)" from seq_len
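The token layout described above can be summarized in a few constants. This is an illustrative Python sketch of the ranges the commit describes, not the project's actual config module; the constant names mirror those mentioned in the commit message.

```python
# Sketch of the new token layout (illustrative, not the project's code):
# actions 0..1967 map 1:1 to searchless_chess, then PAD, then 11 outcomes.
NUM_ACTIONS = 1968            # searchless_chess actions, IDs 0..1967
PAD_TOKEN = NUM_ACTIONS       # 1968
OUTCOME_BASE = PAD_TOKEN + 1  # 1969; outcome tokens occupy 1969..1979
NUM_OUTCOMES = 11
VOCAB_SIZE = NUM_ACTIONS + 1 + NUM_OUTCOMES  # 1980

def token_kind(tok: int) -> str:
    """Classify a token ID under the new layout."""
    if 0 <= tok < NUM_ACTIONS:
        return "action"
    if tok == PAD_TOKEN:
        return "pad"
    if OUTCOME_BASE <= tok < OUTCOME_BASE + NUM_OUTCOMES:
        return "outcome"
    raise ValueError(f"token {tok} outside vocab of size {VOCAB_SIZE}")
```

Because the layout is contiguous (actions, then PAD, then outcomes), all the range checks in the updated tests reduce to comparisons against these three boundaries.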
1 parent 805548c commit 061b508

40 files changed (+8360, -1078 lines)

CLAUDE.md

Lines changed: 6 additions & 5 deletions

@@ -10,7 +10,7 @@ pawn/
 ├── pawn/             # Core Python package
 │   ├── config.py     # CLMConfig (small/base/large), TrainingConfig
 │   ├── model.py      # PAWNCLM transformer (RMSNorm, SwiGLU, RoPE, factored embeddings)
-│   ├── data.py       # On-the-fly random game data pipeline
+│   ├── data.py       # On-the-fly random game data pipeline (prepend_outcome flag)
 │   ├── lichess_data.py # Lichess PGN data pipeline + legal mask computation
 │   ├── trainer.py    # Pretraining loop
 │   ├── gpu.py        # GPU auto-detection (compile/AMP/SDPA backend)
@@ -55,14 +55,15 @@ The only extras are GPU backends (`rocm` or `cu128`). Everything else (pytest, s
 - Uses rayon for parallel game generation (~43K games/sec, 150M+/hr)
 - PyO3 bindings expose `chess_engine` module to Python
 - Key functions: `generate_random_games()`, `parse_pgn_file()`, `compute_legal_token_masks_sparse()`, `extract_board_states()`, `export_move_vocabulary()`, `compute_accuracy_ceiling()`
+- `export_move_vocabulary()` returns 1,968-entry maps (searchless_chess compatible). Conversion functions `pawn_to_searchless()` and `searchless_to_pawn()` bridge between legacy PAWN token IDs and searchless_chess action indices.
 
 ## Model
 
 ### Architecture
-- Decoder-only transformer, next-token prediction over 4,278 tokens
-- Token vocabulary: 1 PAD + 4,096 grid (64x64 src/dst) + 176 promotions + 5 outcomes
+- Decoder-only transformer, next-token prediction over 1,968 move tokens (1,980 total vocab)
+- Token vocabulary: 1,968 searchless_chess actions (0-1967) + 1 PAD (1968) + 11 outcomes (1969-1979) = 1,980 total
 - Factored embeddings: `src_embed[s] + dst_embed[d] + promo_embed[p]`
-- Sequence format: `[outcome] [ply_1] ... [ply_N] [PAD] ... [PAD]` (256 tokens)
+- Sequence format: `[ply_1] ... [ply_N] [PAD] ... [PAD]` (512 tokens) — outcome prefix is optional via `prepend_outcome` flag
 
 ### Variants
 - `CLMConfig.small()`: d=256, 8 layers, 4 heads, ~9.5M params
@@ -373,7 +374,7 @@ Supports all adapter types + architecture search. GPU affinity assigns `CUDA_VIS
 - **DataLoader workers must use `multiprocessing_context='spawn'`** — the Rust engine uses rayon, and fork after rayon init causes deadlocks.
 - **`SDPA_BACKEND` must be set before `torch.compile()`** — compiled code captures the backend at trace time. `apply_gpu_config()` handles this.
 - **ROCm works**: The only known ROCm issue is a stride mismatch in flash attention backward when combined with `torch.compile` + AMP. The workaround is `--sdpa-math` (use the MATH SDPA backend instead of flash), which `configure_gpu()` applies automatically on AMD GPUs. Everything else — training, eval, adapters, data loading — works identically on ROCm and CUDA. **Do not assume bugs are ROCm-specific.** Every other time something has failed on AMD it turned out to be a bug in our code (wrong torch version installed, stale lockfile, missing dependency, etc.), not a ROCm issue.
-- **Sparse logit projection**: `forward_hidden()` returns `(B,T,d_model)`, then only loss-masked positions project through `lm_head` — avoids full `(B,T,V)` materialization.
+- **Sparse logit projection**: `forward_hidden()` returns `(B,T,d_model)`, then only loss-masked positions project through `lm_head` — avoids full `(B,T,1980)` materialization.
 - **Legal mask via Rust**: `LegalMaskBuilder` replays games in Rust, returns sparse indices (~2 MB) scattered into a pre-allocated GPU buffer (vs ~70 MB dense).
 - **GPU auto-detection**: `pawn.gpu.configure_gpu()` selects compile/AMP/SDPA settings. `apply_gpu_config()` applies them. NVIDIA uses flash attention + compile; AMD uses MATH SDPA + compile. Both paths are tested and production-validated.
 - **Factored embeddings**: each move token decomposes into `src_embed[s] + dst_embed[d] + promo_embed[p]`, reducing embedding parameters by ~32x.
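The factored-embedding bullet above can be illustrated with a small numpy sketch. The table sizes (64 squares for src and dst, 5 promotion slots including "no promotion") and d_model = 256 are assumptions for illustration, not the project's exact values; the real (src, dst, promo) decomposition table is built by the engine.

```python
import numpy as np

# Illustrative sketch of factored move embeddings: each move token
# decomposes into (src, dst, promo), and its embedding is the sum of
# one row from each of three small tables.
d_model = 256
rng = np.random.default_rng(0)
src_embed = rng.normal(size=(64, d_model))    # 64 source squares
dst_embed = rng.normal(size=(64, d_model))    # 64 destination squares
promo_embed = rng.normal(size=(5, d_model))   # assumed: none + N/B/R/Q

def embed_move(src: int, dst: int, promo: int) -> np.ndarray:
    # Factored lookup: three small-table rows summed, instead of one
    # row from a dense per-token table.
    return src_embed[src] + dst_embed[dst] + promo_embed[promo]

# Rough parameter saving vs a dense per-token table (legacy 4,278-token vocab):
factored_rows = 64 + 64 + 5   # 133 embedding rows total
dense_rows = 4278             # one row per token
saving = dense_rows / factored_rows  # roughly 32, matching the "~32x" note
```

The same factorization applies under the new 1,968-action vocabulary; only the decomposition table mapping token ID to (src, dst, promo) changes.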

engine/python/chess_engine/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -38,6 +38,8 @@
     generate_engine_games_py as generate_engine_games,
     # Vocabulary
     export_move_vocabulary,
+    pawn_to_searchless,
+    searchless_to_pawn,
     # Interactive game state (for RL)
     PyGameState,
     PyBatchRLEnv,
@@ -73,6 +75,8 @@
     "pgn_to_uci",
     "generate_engine_games",
     "export_move_vocabulary",
+    "pawn_to_searchless",
+    "searchless_to_pawn",
     "PyGameState",
     "PyBatchRLEnv",
     "compute_accuracy_ceiling",

engine/src/batch.rs

Lines changed: 74 additions & 50 deletions

@@ -281,58 +281,71 @@ pub struct CLMBatch {
 
 /// Generate a CLM training batch: random games packed into model-ready format.
 ///
-/// `seq_len` is the total sequence length (256). Games are generated with up to
-/// `seq_len - 1` plies, leaving position 0 for the outcome token.
+/// When `prepend_outcome` is false (default), sequences are pure moves:
+/// `[move_1, move_2, ..., move_N, PAD, ...]` and `max_ply = seq_len`.
+///
+/// When `prepend_outcome` is true, position 0 is the outcome token:
+/// `[outcome, move_1, ..., move_N, PAD, ...]` and `max_ply = seq_len - 1`.
 pub fn generate_clm_batch(
     batch_size: usize,
     seq_len: usize,
     seed: u64,
     discard_ply_limit: bool,
     mate_boost: f64,
+    prepend_outcome: bool,
 ) -> CLMBatch {
-    let max_ply = seq_len - 1;
+    let max_ply = if prepend_outcome { seq_len - 1 } else { seq_len };
 
     let game_batch = {
         generate_random_games(batch_size, max_ply, seed, mate_boost, discard_ply_limit)
     };
 
-    let mut input_ids = vec![0i16; batch_size * seq_len];
-    let mut targets = vec![0i16; batch_size * seq_len];
+    let pad = vocab::PAD_TOKEN as i16;
+    let mut input_ids = vec![pad; batch_size * seq_len];
+    let mut targets = vec![pad; batch_size * seq_len];
     let mut loss_mask = vec![false; batch_size * seq_len];
 
     for b in 0..batch_size {
         let gl = game_batch.game_lengths[b] as usize;
-        let term = match game_batch.termination_codes[b] {
-            0 => Termination::Checkmate,
-            1 => Termination::Stalemate,
-            2 => Termination::SeventyFiveMoveRule,
-            3 => Termination::FivefoldRepetition,
-            4 => Termination::InsufficientMaterial,
-            _ => Termination::PlyLimit,
-        };
-        let outcome = vocab::termination_to_outcome(term, game_batch.game_lengths[b] as u16);
-
         let row = b * seq_len;
 
-        // Position 0: outcome token
-        input_ids[row] = outcome as i16;
+        if prepend_outcome {
+            // Outcome-prefixed format: [outcome, m1, ..., mN, PAD, ...]
+            let term = match game_batch.termination_codes[b] {
+                0 => Termination::Checkmate,
+                1 => Termination::Stalemate,
+                2 => Termination::SeventyFiveMoveRule,
+                3 => Termination::FivefoldRepetition,
+                4 => Termination::InsufficientMaterial,
+                _ => Termination::PlyLimit,
+            };
+            let outcome = vocab::termination_to_outcome(term, game_batch.game_lengths[b] as u16);
+            input_ids[row] = outcome as i16;
+
+            for t in 0..gl {
+                input_ids[row + 1 + t] = game_batch.move_ids[b * max_ply + t];
+            }
+
+            // Loss mask: positions 0..=gl are true
+            for t in 0..=gl {
+                loss_mask[row + t] = true;
+            }
+        } else {
+            // Pure moves format: [m1, m2, ..., mN, PAD, ...]
+            for t in 0..gl {
+                input_ids[row + t] = game_batch.move_ids[b * max_ply + t];
+            }
 
-        // Positions 1..=gl: move tokens
-        for t in 0..gl {
-            input_ids[row + 1 + t] = game_batch.move_ids[b * max_ply + t];
+            // Loss mask: positions 0..gl-1 are true (gl positions predict gl targets)
+            for t in 0..gl {
+                loss_mask[row + t] = true;
+            }
         }
-        // Remaining positions are already 0 (PAD)
 
         // Targets: input_ids shifted left by 1
        for t in 0..(seq_len - 1) {
            targets[row + t] = input_ids[row + t + 1];
        }
-        // targets[row + seq_len - 1] is already 0
-
-        // Loss mask: positions 0..=gl are true
-        for t in 0..=gl {
-            loss_mask[row + t] = true;
-        }
     }
 
     CLMBatch {
@@ -408,7 +421,7 @@ mod tests {
     #[test]
     fn test_clm_batch_format() {
         let seq_len = 256;
-        let batch = generate_clm_batch(8, seq_len, 42, false, 0.0);
+        let batch = generate_clm_batch(8, seq_len, 42, false, 0.0, true);
         assert_eq!(batch.input_ids.len(), 8 * seq_len);
         assert_eq!(batch.targets.len(), 8 * seq_len);
         assert_eq!(batch.loss_mask.len(), 8 * seq_len);
@@ -419,21 +432,23 @@
         let gl = batch.game_lengths[b] as usize;
         let row = b * seq_len;
 
-        // Position 0: outcome token (4273-4277)
+        let pad = vocab::PAD_TOKEN as i16;
+
+        // Position 0: outcome token
         let outcome = batch.input_ids[row];
         assert!(outcome >= vocab::OUTCOME_BASE as i16 && outcome <= vocab::PLY_LIMIT as i16,
             "Position 0 should be outcome token, got {}", outcome);
 
-        // Positions 1..=gl: move tokens (1-4272)
+        // Positions 1..=gl: move tokens (action IDs 0..1967)
         for t in 1..=gl {
             let tok = batch.input_ids[row + t];
-            assert!(tok >= 1 && tok <= 4272,
-                "Position {} should be move token, got {}", t, tok);
+            assert!(tok >= 0 && tok < vocab::NUM_ACTIONS as i16,
+                "Position {} should be move token (0-1967), got {}", t, tok);
         }
 
-        // Positions gl+1..seq_len: PAD (0)
+        // Positions gl+1..seq_len: PAD
         for t in (gl + 1)..seq_len {
-            assert_eq!(batch.input_ids[row + t], 0,
+            assert_eq!(batch.input_ids[row + t], pad,
                 "Position {} should be PAD, got {}", t, batch.input_ids[row + t]);
         }
 
@@ -442,10 +457,10 @@
         assert_eq!(batch.targets[row + t], batch.input_ids[row + t + 1],
             "targets[{}] should equal input_ids[{}]", t, t + 1);
         }
-        assert_eq!(batch.targets[row + seq_len - 1], 0, "Last target should be PAD");
+        assert_eq!(batch.targets[row + seq_len - 1], pad, "Last target should be PAD");
 
         // Target at position gl is PAD (end of game)
-        assert_eq!(batch.targets[row + gl], 0, "Target at game_length should be PAD");
+        assert_eq!(batch.targets[row + gl], pad, "Target at game_length should be PAD");
 
         // Loss mask: true for 0..=gl, false after
         for t in 0..=gl {
@@ -461,8 +476,8 @@
 
     #[test]
     fn test_clm_batch_deterministic() {
-        let b1 = generate_clm_batch(4, 256, 99, false, 0.0);
-        let b2 = generate_clm_batch(4, 256, 99, false, 0.0);
+        let b1 = generate_clm_batch(4, 256, 99, false, 0.0, true);
+        let b2 = generate_clm_batch(4, 256, 99, false, 0.0, true);
         assert_eq!(b1.input_ids, b2.input_ids);
         assert_eq!(b1.targets, b2.targets);
         assert_eq!(b1.loss_mask, b2.loss_mask);
@@ -471,7 +486,7 @@
 
     #[test]
     fn test_clm_batch_outcome_correctness() {
-        let batch = generate_clm_batch(32, 256, 42, false, 0.0);
+        let batch = generate_clm_batch(32, 256, 42, false, 0.0, true);
         for b in 0..32 {
             let gl = batch.game_lengths[b] as usize;
             let tc = batch.termination_codes[b];
@@ -503,8 +518,8 @@
         let gl = batch.game_lengths[b] as usize;
         for t in 0..gl {
             let tok = batch.move_ids[b * 64 + t];
-            // Tokens should be valid move tokens (1..=4272)
-            assert!(tok >= 1 && tok <= 4272,
+            // Tokens should be valid action IDs (0..1967)
+            assert!(tok >= 0 && tok < vocab::NUM_ACTIONS as i16,
                 "Invalid token at b={} t={}: {}", b, t, tok);
         }
     }
@@ -685,18 +700,27 @@
     }
 
     #[test]
-    fn test_clm_batch_seq_len_consistency() {
-        let batch = generate_clm_batch(4, 32, 42, false, 0.0);
+    fn test_clm_batch_seq_len_consistency_with_outcome() {
+        let batch = generate_clm_batch(4, 32, 42, false, 0.0, true);
         assert_eq!(batch.seq_len, 32);
-        assert_eq!(batch.max_ply, 31); // seq_len - 1
+        assert_eq!(batch.max_ply, 31); // seq_len - 1 when outcome prepended
         assert_eq!(batch.input_ids.len(), 4 * 32);
         assert_eq!(batch.move_ids.len(), 4 * 31);
     }
 
+    #[test]
+    fn test_clm_batch_seq_len_consistency_no_outcome() {
+        let batch = generate_clm_batch(4, 32, 42, false, 0.0, false);
+        assert_eq!(batch.seq_len, 32);
+        assert_eq!(batch.max_ply, 32); // max_ply == seq_len when no outcome
+        assert_eq!(batch.input_ids.len(), 4 * 32);
+        assert_eq!(batch.move_ids.len(), 4 * 32);
+    }
+
     #[test]
     fn test_clm_batch_shift_by_one() {
         // Verify targets[t] == input_ids[t+1] for all t < seq_len-1
-        let batch = generate_clm_batch(4, 64, 42, false, 0.0);
+        let batch = generate_clm_batch(4, 64, 42, false, 0.0, true);
         for b in 0..4 {
             let row = b * 64;
             for t in 0..63 {
@@ -708,7 +732,7 @@
 
     #[test]
     fn test_clm_batch_loss_mask_covers_outcome_and_moves() {
-        let batch = generate_clm_batch(4, 64, 42, false, 0.0);
+        let batch = generate_clm_batch(4, 64, 42, false, 0.0, true);
         for b in 0..4 {
             let gl = batch.game_lengths[b] as usize;
             let row = b * 64;
@@ -726,7 +750,7 @@
     #[test]
     fn test_clm_batch_move_ids_copied_correctly() {
         // move_ids in CLMBatch is the raw game moves (seq_len-1 wide)
-        let batch = generate_clm_batch(2, 32, 42, false, 0.0);
+        let batch = generate_clm_batch(2, 32, 42, false, 0.0, true);
         for b in 0..2 {
             let gl = batch.game_lengths[b] as usize;
             let max_ply = 31;
@@ -794,16 +818,16 @@
 
     #[test]
     fn test_generate_clm_batch_mate_boost_deterministic() {
-        let b1 = generate_clm_batch(4, 64, 42, false, 0.5);
-        let b2 = generate_clm_batch(4, 64, 42, false, 0.5);
+        let b1 = generate_clm_batch(4, 64, 42, false, 0.5, true);
+        let b2 = generate_clm_batch(4, 64, 42, false, 0.5, true);
        assert_eq!(b1.input_ids, b2.input_ids);
        assert_eq!(b1.game_lengths, b2.game_lengths);
     }
 
     #[test]
     fn test_clm_batch_discard_no_ply_limit_outcomes() {
         // With discard_ply_limit=true, outcome token at pos 0 is never PLY_LIMIT (4277)
-        let batch = generate_clm_batch(8, 40, 42, true, 0.0);
+        let batch = generate_clm_batch(8, 40, 42, true, 0.0, true);
         for b in 0..8 {
             let outcome = batch.input_ids[b * 40] as u16;
             assert_ne!(outcome, vocab::PLY_LIMIT,
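The packing semantics of `generate_clm_batch` can be mirrored in a short Python sketch (illustration only; the real implementation is the Rust above). It shows both sequence formats and the shift-by-one target and loss-mask invariants that the tests assert.

```python
# Sketch of the CLM batch packing for a single game, under the new
# token layout (actions 0..1967, PAD = 1968). Illustrative, not the
# project's actual Python code.
PAD = 1968

def pack_game(moves, outcome, seq_len, prepend_outcome):
    """Pack one game into (input_ids, targets, loss_mask)."""
    ids = [PAD] * seq_len
    if prepend_outcome:
        ids[0] = outcome                  # position 0: outcome token
        ids[1:1 + len(moves)] = moves     # positions 1..=gl: moves
        mask_len = len(moves) + 1         # loss on positions 0..=gl
    else:
        ids[:len(moves)] = moves          # pure-moves format
        mask_len = len(moves)             # loss on positions 0..gl-1
    # Targets are input_ids shifted left by one; the last target is PAD.
    targets = ids[1:] + [PAD]
    loss_mask = [t < mask_len for t in range(seq_len)]
    return ids, targets, loss_mask
```

With `prepend_outcome=True` the game may span at most `seq_len - 1` plies (position 0 is the outcome), which is exactly why `max_ply` depends on the flag in the Rust code.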
