Skip to content

Commit e73f6ff

Browse files
Fix legal grid alignment, add game completion eval, benchmark improvements (#54)
* Fix legal grid off-by-one, add game completion eval, benchmark improvements

Legal grid alignment fix:
- legal_grid from compute_legal_move_masks is aligned with move_ids (legal moves at position *before* each move), but the trainer checks it against targets which are shifted by one (target[ply] = move_ids[ply+1]). Shift the grid by one ply in create_validation_set so it aligns with targets. This was causing legal_move_rate to always report 0%.

Game completion eval:
- New compute_game_completion() walks each game ply-by-ply checking whether the model's argmax prediction is legal. Reports: game_completion_rate (fraction of games without any illegal move), avg_pct_completion (mean fraction completed before forfeit), avg_plies_to_forfeit.
- Computed on 64 val games at each eval_interval using dense token masks.

Benchmark improvements:
- CPU/RAM reporting now checks cgroup limits (v1 and v2) before falling back to /proc, so containers report their actual allocation instead of the host's full resources.
- Default warmup iterations bumped from 3 to 10 — torch.compile needs more iterations to fully optimize, inflating timed results otherwise.

Theoretical ceiling script:
- Add --max-ply flag (was hardcoded to 255).

* Fix test to use targets as ground-truth preds after legal grid shift

The legal grid in create_validation_set is now shifted by one ply to align with targets. The test was using input_ids as predictions, which matched the old unshifted grid. Switch to targets.

* Address PR review feedback

- Remove dead gc_targets variable and unused n_checked counter
- Rename avg_plies_to_forfeit → avg_plies_completed (completed games contribute their full game_length to the average)
- Free all GPU tensors in game completion eval cleanup
- Move chess_engine import to top of trainer.py
- Extract shift_legal_mask() into pawn/data.py to deduplicate the np.roll + zero-fill pattern between data.py and trainer.py
- Use math.ceil for fractional CPU counts in cgroup detection
1 parent 46a8c3e commit e73f6ff

File tree

5 files changed

+233
-30
lines changed

5 files changed

+233
-30
lines changed

pawn/data.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,22 @@ def _watchdog():
240240
step += 1
241241

242242

243+
def shift_legal_mask(mask: np.ndarray) -> np.ndarray:
    """Shift a legal move mask forward by one ply to align with CLM targets.

    The engine's legal masks are indexed by position *before* each move:
    mask[ply] = legal moves at position ply. But CLM targets[ply] is the
    *next* move (= move_ids[ply+1]), so we need legal moves at position
    ply+1. The result drops the first ply's entry and leaves the final
    ply zeroed (there is no next move at the final ply).

    Works for any mask shape (B, T, ...).
    """
    # Equivalent to np.roll(mask, -1, axis=1) followed by zeroing the
    # wrapped-around last slot; slicing skips the pointless wrap copy.
    aligned = np.zeros_like(mask)
    aligned[:, :-1] = mask[:, 1:]
    return aligned
257+
258+
243259
def create_validation_set(
244260
n_games: int, max_ply: int, seed: int,
245261
discard_ply_limit: bool = False,
@@ -266,9 +282,11 @@ def create_validation_set(
266282
"loss_mask": torch.from_numpy(loss_mask),
267283
}
268284

269-
# Compute legal move masks for evaluating legal move rate
285+
# Compute legal move masks for evaluating legal move rate.
286+
# Shift by one ply so legal_grid[ply] aligns with targets[ply]
287+
# (see shift_legal_mask docstring).
270288
legal_grid, _legal_promo = engine.compute_legal_move_masks(move_ids, game_lengths)
271-
batch["legal_grid"] = torch.from_numpy(legal_grid).long()
289+
batch["legal_grid"] = torch.from_numpy(shift_legal_mask(legal_grid)).long()
272290
batch["game_lengths"] = torch.from_numpy(game_lengths).long()
273291

274292
if no_outcome and prepend_outcome:

pawn/trainer.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
import time
1515
from datetime import datetime, timezone
1616

17+
import numpy as np
1718
import psutil
1819
import torch
1920
import torch.nn as nn
2021
import torch.nn.functional as F
2122
from torch.utils.data import DataLoader
2223

24+
import chess_engine as engine
2325
from pawn.config import CLMConfig, TrainingConfig
2426
from pawn.model import PAWNCLM
25-
from pawn.data import CLMDataset, create_validation_set
27+
from pawn.data import CLMDataset, create_validation_set, shift_legal_mask
2628
from pawn.logging import MetricsLogger
2729

2830
from pawn.data_utils import unpack_grid
@@ -239,6 +241,68 @@ def compute_legal_move_rate_from_preds(
239241

240242

241243

244+
def compute_game_completion(
    preds: torch.Tensor,
    legal_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    game_lengths: torch.Tensor,
) -> dict[str, float]:
    """Measure how often the model gets through a full game without illegal moves.

    For each game, walks ply-by-ply checking whether the argmax prediction is
    legal. The first illegal move is a "forfeit".

    Args:
        preds: (B, T) argmax token predictions (aligned with targets)
        legal_mask: (B, T, V) bool — legal token mask (shifted to align with targets)
        loss_mask: (B, T) bool — which positions are valid
        game_lengths: (B,) int — number of valid plies per game

    Returns dict with:
        game_completion_rate: fraction of games with zero illegal moves
        avg_pct_completion: mean fraction of game completed before forfeit
        avg_plies_completed: mean plies completed before first illegal move.
            Games with no illegal moves contribute their full game_length.
    """
    n_games, seq_len = preds.shape
    vocab = legal_mask.shape[2]

    with torch.no_grad():
        n_clean = 0
        fractions: list[float] = []
        ply_counts: list[float] = []

        for g in range(n_games):
            limit = min(int(game_lengths[g].item()), seq_len)
            first_illegal = -1
            for ply in range(limit):
                # Only score positions the loss mask marks as valid.
                if not loss_mask[g, ply]:
                    continue
                # Skip plies with no legal moves (end-of-game padding).
                if not legal_mask[g, ply].any():
                    continue
                choice = int(preds[g, ply].item())
                # Out-of-range tokens and tokens outside the legal mask
                # both count as a forfeit (short-circuit avoids OOB index).
                if choice >= vocab or not legal_mask[g, ply, choice]:
                    first_illegal = ply
                    break

            if first_illegal < 0:
                n_clean += 1
                fractions.append(1.0)
                ply_counts.append(float(limit))
            else:
                fractions.append(first_illegal / limit if limit > 0 else 0.0)
                ply_counts.append(float(first_illegal))

    return {
        "game_completion_rate": n_clean / n_games if n_games > 0 else 0.0,
        "avg_pct_completion": sum(fractions) / len(fractions) if fractions else 0.0,
        "avg_plies_completed": sum(ply_counts) / len(ply_counts) if ply_counts else 0.0,
    }
304+
305+
242306
def _get_grad_norm(model: nn.Module) -> float:
243307
grads = [p.grad.data for p in model.parameters() if p.grad is not None]
244308
if not grads:
@@ -513,6 +577,39 @@ def evaluate(self) -> dict[str, float]:
513577
torch.cuda.empty_cache()
514578
avg = {f"val/{k}": v / n_batches for k, v in total_metrics.items()}
515579
avg["val/perplexity"] = math.exp(min(avg["val/loss"], 20.0))
580+
581+
# Game completion eval: can the model get through entire games
582+
# without picking an illegal move? Uses a small subset to avoid
583+
# materializing a large dense (B, T, vocab) token mask.
584+
if "game_lengths" in self.val_data:
585+
gc_n = min(64, n)
586+
gc_input = self.val_data["input_ids"][:gc_n].to(self.device)
587+
gc_loss_mask = self.val_data["loss_mask"][:gc_n].to(self.device)
588+
gc_game_lengths = self.val_data["game_lengths"][:gc_n].to(self.device)
589+
move_ids = self.val_data["input_ids"][:gc_n].numpy().astype(np.int16)
590+
gl_np = self.val_data["game_lengths"][:gc_n].numpy().astype(np.int16)
591+
vocab_size = self.model_cfg.vocab_size
592+
593+
with torch.no_grad():
594+
with torch.amp.autocast(self.device, enabled=self.cfg.use_amp):
595+
hidden = model.forward_eval(gc_input, gc_loss_mask)
596+
gc_logits = model.lm_head(hidden)
597+
gc_preds = gc_logits.argmax(dim=-1)
598+
599+
legal_tokens = engine.compute_legal_token_masks(move_ids, gl_np, vocab_size)
600+
legal_mask_t = torch.from_numpy(
601+
shift_legal_mask(legal_tokens)
602+
).to(self.device)
603+
604+
gc = compute_game_completion(gc_preds, legal_mask_t, gc_loss_mask, gc_game_lengths)
605+
avg["val/game_completion_rate"] = gc["game_completion_rate"]
606+
avg["val/avg_pct_completion"] = gc["avg_pct_completion"]
607+
avg["val/avg_plies_completed"] = gc["avg_plies_completed"]
608+
609+
del gc_input, gc_loss_mask, gc_game_lengths, legal_mask_t, gc_logits, gc_preds
610+
if self.device != "cpu" and torch.cuda.is_available():
611+
torch.cuda.empty_cache()
612+
516613
return avg
517614

518615
def train(self):
@@ -602,6 +699,11 @@ def _graceful_exit(signum, frame):
602699
val_msg += f" | legal {val_metrics['val/legal_move_rate']:.3f}"
603700
if "val/late_legal_move_rate" in val_metrics:
604701
val_msg += f" | late_legal {val_metrics['val/late_legal_move_rate']:.3f}"
702+
if "val/game_completion_rate" in val_metrics:
703+
val_msg += (
704+
f" | complete {val_metrics['val/game_completion_rate']:.3f}"
705+
f" | avg_ply {val_metrics['val/avg_plies_completed']:.0f}"
706+
)
605707

606708
# Compound early stopping
607709
extra_log: dict[str, object] = {}

scripts/benchmark.py

Lines changed: 102 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,8 +1093,78 @@ def _collect_cpu_cache() -> dict[str, str]:
10931093
return cache
10941094

10951095

1096+
def _cgroup_cpu_count() -> int | None:
    """Return container CPU limit from cgroups, or None if unconstrained.

    Checks, in order: cgroup v2 quota (cpu.max), cgroup v1 quota
    (cfs_quota_us / cfs_period_us), then the effective cpuset. Fractional
    quotas round up so a 1.5-CPU container reports 2.
    """
    import math as _math

    # cgroup v2: cpu.max contains "quota period" (e.g. "200000 100000" = 2 CPUs);
    # the literal "max" means no limit.
    try:
        quota_s, period_s = Path("/sys/fs/cgroup/cpu.max").read_text().strip().split()
        if quota_s != "max":
            return max(1, _math.ceil(int(quota_s) / int(period_s)))
    except (OSError, ValueError):
        pass

    # cgroup v1: cpu.cfs_quota_us / cpu.cfs_period_us; quota <= 0 means no limit.
    try:
        v1_quota = int(Path("/sys/fs/cgroup/cpu/cpu.cfs_quota_us").read_text().strip())
        if v1_quota > 0:
            v1_period = int(
                Path("/sys/fs/cgroup/cpu/cpu.cfs_period_us").read_text().strip()
            )
            return max(1, _math.ceil(v1_quota / v1_period))
    except (OSError, ValueError):
        pass

    # cpuset: count CPUs in the effective cpuset (ranges like "0-3,8-11").
    for cpuset_path in ("/sys/fs/cgroup/cpuset.cpus.effective",  # v2
                        "/sys/fs/cgroup/cpuset/cpuset.cpus"):    # v1
        try:
            spec = Path(cpuset_path).read_text().strip()
        except OSError:
            continue
        if not spec:
            continue
        try:
            total = 0
            for chunk in spec.split(","):
                lo, _, hi = chunk.partition("-")
                total += (int(hi) - int(lo) + 1) if hi else 1
            return total
        except ValueError:
            continue

    return None
1136+
1137+
1138+
def _cgroup_memory_bytes() -> int | None:
    """Return container memory limit from cgroups, or None if unconstrained.

    cgroup v2 reports the literal "max" when unlimited; cgroup v1 reports
    a huge sentinel (~2**63), which is treated as no limit.
    """
    # cgroup v2: memory.max holds a byte count or the literal "max"
    try:
        raw = Path("/sys/fs/cgroup/memory.max").read_text().strip()
        if raw != "max":
            return int(raw)
    except (OSError, ValueError):
        pass

    # cgroup v1: memory.limit_in_bytes; the kernel uses a ~2**63 sentinel
    # when unconstrained, so anything at or above 2**62 is treated as "no limit".
    try:
        v1_limit = int(
            Path("/sys/fs/cgroup/memory/memory.limit_in_bytes").read_text().strip()
        )
        if v1_limit < 2**62:
            return v1_limit
    except (OSError, ValueError):
        pass

    return None
1159+
1160+
10961161
def _collect_system_info() -> dict:
1097-
"""Collect CPU, RAM, and cache info."""
1162+
"""Collect CPU, RAM, and cache info.
1163+
1164+
In containers (RunPod, Docker), /proc/cpuinfo and /proc/meminfo report
1165+
host-level resources. We check cgroup limits first and prefer those
1166+
when they indicate a constrained environment.
1167+
"""
10981168
import multiprocessing
10991169

11001170
cpu_name = ""
@@ -1122,25 +1192,34 @@ def _collect_system_info() -> dict:
11221192
except (OSError, ValueError):
11231193
pass # keep /proc/cpuinfo MHz if available
11241194

1125-
try:
1126-
cpu_count = len(os.sched_getaffinity(0))
1127-
except (AttributeError, OSError):
1128-
cpu_count = multiprocessing.cpu_count() or 0
1129-
1130-
# System RAM
1131-
ram_gb = 0.0
1132-
try:
1133-
import psutil
1134-
ram_gb = psutil.virtual_memory().total / (1024**3)
1135-
except ImportError:
1195+
# CPU count: prefer cgroup limit over host-visible CPUs
1196+
cg_cpus = _cgroup_cpu_count()
1197+
if cg_cpus is not None:
1198+
cpu_count = cg_cpus
1199+
else:
11361200
try:
1137-
with open("/proc/meminfo") as f:
1138-
for line in f:
1139-
if line.startswith("MemTotal:"):
1140-
ram_gb = int(line.split()[1]) / (1024**2)
1141-
break
1142-
except OSError:
1143-
pass
1201+
cpu_count = len(os.sched_getaffinity(0))
1202+
except (AttributeError, OSError):
1203+
cpu_count = multiprocessing.cpu_count() or 0
1204+
1205+
# System RAM: prefer cgroup limit over host total
1206+
cg_mem = _cgroup_memory_bytes()
1207+
if cg_mem is not None:
1208+
ram_gb = cg_mem / (1024**3)
1209+
else:
1210+
ram_gb = 0.0
1211+
try:
1212+
import psutil
1213+
ram_gb = psutil.virtual_memory().total / (1024**3)
1214+
except ImportError:
1215+
try:
1216+
with open("/proc/meminfo") as f:
1217+
for line in f:
1218+
if line.startswith("MemTotal:"):
1219+
ram_gb = int(line.split()[1]) / (1024**2)
1220+
break
1221+
except OSError:
1222+
pass
11441223

11451224
info: dict = {
11461225
"python": sys.version.split()[0],
@@ -1411,8 +1490,10 @@ def main():
14111490
# Iteration control
14121491
parser.add_argument("--n-iter", type=int, default=10,
14131492
help="Timed iterations per benchmark (default: 10)")
1414-
parser.add_argument("--n-warmup", type=int, default=3,
1415-
help="Warmup iterations per benchmark (default: 3)")
1493+
parser.add_argument("--n-warmup", type=int, default=10,
1494+
help="Warmup iterations per benchmark (default: 10). "
1495+
"torch.compile may need 5-10+ steps to fully optimize; "
1496+
"too few warmup steps inflates timed results.")
14161497

14171498
# Output
14181499
parser.add_argument("--json", type=str, default=None,

scripts/compute_theoretical_ceiling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def main():
7777
parser.add_argument("--output", type=str, default="data/theoretical_ceiling.json")
7878
parser.add_argument("--model-accuracy", type=float, default=None,
7979
help="Model top-1 accuracy to compute adjusted score")
80+
parser.add_argument("--max-ply", type=int, default=255,
81+
help="Maximum game length in plies (default: 255)")
8082
parser.add_argument("--bootstrap", type=int, default=2000,
8183
help="Number of bootstrap resamples for CIs (0 to skip)")
8284
args = parser.parse_args()
@@ -116,7 +118,7 @@ def main():
116118
bt = time.time()
117119
result = engine.compute_accuracy_ceiling(
118120
n_games=batch_n,
119-
max_ply=255,
121+
max_ply=args.max_ply,
120122
n_rollouts=args.rollouts,
121123
sample_rate=args.sample_rate,
122124
seed=batch_seed,

tests/model/test_512_token.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,13 @@ class TestPlyRangeFilter:
378378
def setup(self):
    """Generate val data with known legal grids at seq_len=64.

    Uses ``targets`` as predictions — legal_grid is shifted by one ply
    in create_validation_set to align with targets (target[p] is the move
    at ply p+1). This gives a legal rate of ~1.0 for move positions.
    """
    val = create_validation_set(n_games=16, max_ply=64, seed=42)
    # targets[p] = the next move (ply p+1), aligned with legal_grid[p]
    val["preds"] = val["targets"].clone()
    return val
389389

390390
@pytest.mark.integration

0 commit comments

Comments
 (0)