diff --git a/pawn/cotrain.py b/pawn/cotrain.py new file mode 100644 index 0000000..b670927 --- /dev/null +++ b/pawn/cotrain.py @@ -0,0 +1,612 @@ +"""Co-training: train multiple model variants on shared data batches. + +Extracted from ``scripts/train_all.py`` so the lab MCP server and the CLI +script share the same implementation. +""" + +from __future__ import annotations + +import json +import math +import os +import shutil +import signal +import sys +import time +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from pawn.checkpoint import load_pretrain_checkpoint, save_pretrain_checkpoint, push_checkpoint_to_hf +from pawn.config import CLMConfig, LegacyVocab, TrainingConfig +from pawn.data import CLMDataset, create_validation_set +from pawn.gpu import configure_gpu +from pawn.logging import MetricsLogger, random_slug +from pawn.model import PAWNCLM +from pawn.run_config import CotrainConfig, CotrainVariant + + +# --------------------------------------------------------------------------- +# Per-model state +# --------------------------------------------------------------------------- + + +class ModelSlot: + """Holds everything needed to train and checkpoint one model variant.""" + + def __init__( + self, + name: str, + model_cfg: CLMConfig, + train_cfg: TrainingConfig, + device: str, + hf_repo: str | None, + shm_checkpoints: bool = False, + slug: str = "", + resume_path: str | None = None, + ): + self.name = name + self.slug = slug + self.model_cfg = model_cfg + self.train_cfg = train_cfg + self.device = device + self.hf_repo = hf_repo + self.shm_checkpoints = shm_checkpoints + + self.model = PAWNCLM(model_cfg).to(device) + param_count = sum(p.numel() for p in self.model.parameters()) + print(f" {name}: {param_count:,} params ({model_cfg.d_model}d/{model_cfg.n_layers}L)") + + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=train_cfg.lr, + weight_decay=train_cfg.weight_decay, + ) + + from pawn.trainer import CosineWithWarmup + self.scheduler = CosineWithWarmup( + self.optimizer, + warmup_steps=train_cfg.warmup_steps, + total_steps=train_cfg.total_steps, + ) + + self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp) + + # Logger (creates run directory) + self.logger = MetricsLogger( + train_cfg.log_dir, run_prefix="run", device=device, + slug=slug, suffix=name, + ) + self.run_dir = str(self.logger.run_dir) + self.jsonl_path = str(self.logger.metrics_path) + + # Checkpoint directory: /dev/shm if requested, else under run_dir + if shm_checkpoints: + self.checkpoint_dir = f"/dev/shm/pawn_checkpoints/{name}" + else: + self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + + self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None + self.global_step = 0 + self.best_val_step = 0 + self.best_val_loss = float("inf") + self.patience_counter = 0 + self.stopped = False + + # Background HF push (one thread per slot, so pushes don't block training) + from concurrent.futures import ThreadPoolExecutor + self._hf_push_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"hf-{name}") + self._hf_push_future = None + + # Resume from checkpoint if requested + if resume_path: + meta = load_pretrain_checkpoint( + resume_path, self.model, self.optimizer, self.scheduler, + self.scaler, device=device, + ) + self.global_step = meta["global_step"] + if meta.get("best_val_loss") is not None: + self.best_val_loss = meta["best_val_loss"] + if meta.get("patience_counter") is not None: + self.patience_counter = meta["patience_counter"] + print(f" [{name}] Resumed from step {self.global_step} " + f"(checkpoint: {resume_path})") + + self.logger.log_config( + model=model_cfg.__dict__, + training=train_cfg.__dict__, + param_count=param_count, + formulation="clm", + multi_model=True, + variant=name, + ) + self.logger.write_config_json( + model=model_cfg.__dict__, + training=train_cfg.__dict__, + param_count=param_count, + formulation="clm", + multi_model=True, + variant=name, + ) + + def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Forward + backward. Returns raw GPU tensor metrics (no .item() sync).""" + self.model.train() + input_ids = batch["input_ids"].to(self.device, non_blocking=True) + targets = batch["targets"].to(self.device, non_blocking=True) + loss_mask = batch["loss_mask"].to(self.device, non_blocking=True) + + with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp): + loss, metrics = self.model.forward_train(input_ids, loss_mask, targets) + + self.scaler.scale(loss).backward() + return metrics + + def optimizer_step(self) -> float: + self.scaler.unscale_(self.optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.train_cfg.max_grad_norm + ).item() + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + self.scheduler.step() + return grad_norm + + def _unwrapped_model(self) -> PAWNCLM: + """Return the unwrapped model (strips torch.compile wrapper).""" + m: Any = self.model + while hasattr(m, '_orig_mod'): + m = m._orig_mod + return m # type: ignore[return-value] + + def save_checkpoint(self): + path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}") + save_pretrain_checkpoint( + path, self._unwrapped_model(), self.optimizer, self.scheduler, self.scaler, + self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__, # type: ignore[arg-type] + ) + print(f" [{self.name}] Checkpoint saved: {path}") + + if self.hf_repo and self.hf_branch: + self._push_to_hf_async(path, self.global_step) + + def _push_to_hf_async(self, ckpt_path: str, step: int): + """Push checkpoint to HuggingFace in a background thread.""" + # Wait for any previous push to finish before starting a new one + if self._hf_push_future is not None: + self._hf_push_future.result() # blocks until previous push completes + + assert self.hf_repo is not None and self.hf_branch is not None + hf_repo, hf_branch = self.hf_repo, self.hf_branch + + def _push(): + try: + push_checkpoint_to_hf( + ckpt_path, hf_repo, hf_branch, + metrics_path=self.jsonl_path, step=step, + ) + print(f" [{self.name}] Pushed to HF: {hf_repo}@{hf_branch}") + + # On /dev/shm, clean up old checkpoints after successful push. + # Keep the latest (just saved) and the best (for post-training evals). + if self.shm_checkpoints: + keep = {Path(ckpt_path).name, f"step_{self.best_val_step:08d}"} + for old in sorted(Path(self.checkpoint_dir).glob("step_*")): + if old.name not in keep: + shutil.rmtree(old, ignore_errors=True) + except Exception as e: + print(f" [{self.name}] WARNING: HF push failed: {e}") + + self._hf_push_future = self._hf_push_pool.submit(_push) + + def push_metrics_to_hf(self): + """Push just metrics.jsonl to HF (lightweight, no checkpoint).""" + if not self.hf_repo or not self.hf_branch: + return + + hf_repo, hf_branch = self.hf_repo, self.hf_branch + + def _push_metrics(): + try: + from huggingface_hub import HfApi + api = HfApi() + api.create_branch(hf_repo, repo_type="model", + branch=hf_branch, exist_ok=True) + api.upload_file( + path_or_fileobj=self.jsonl_path, + path_in_repo="metrics.jsonl", + repo_id=hf_repo, + repo_type="model", + revision=hf_branch, + commit_message=f"Metrics through step {self.global_step}", + ) + except Exception as e: + print(f" [{self.name}] WARNING: metrics push failed: {e}") + + # Fire and forget on the push pool (queued behind any checkpoint push) + self._hf_push_pool.submit(_push_metrics) + + def wait_for_push(self): + """Block until any in-flight HF push completes.""" + if self._hf_push_future is not None: + self._hf_push_future.result() + self._hf_push_future = None + + @torch.no_grad() + def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]: + from pawn.trainer import compute_legal_move_rate_from_preds + + self.model.eval() + n = val_data["input_ids"].shape[0] + batch_size = self.train_cfg.batch_size + total_metrics: dict[str, float] = {} + n_batches = 0 + has_legal = "legal_grid" in val_data + + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + input_ids = val_data["input_ids"][start:end].to(self.device, non_blocking=True) + targets = val_data["targets"][start:end].to(self.device, non_blocking=True) + loss_mask = val_data["loss_mask"][start:end].to(self.device, non_blocking=True) + + with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp): + # Get hidden states without materializing full (B,T,V) logits + hidden = self.model.forward_eval(input_ids, loss_mask) + + # Sparse projection: only valid positions through lm_head + valid_hidden = hidden[loss_mask] + valid_logits = self.model.lm_head(valid_hidden) + valid_targets = targets[loss_mask] + + loss = F.cross_entropy(valid_logits, valid_targets) + accuracy = (valid_logits.argmax(-1) == valid_targets).float().mean().item() + metrics: dict[str, float] = {"loss": loss.item(), "accuracy": accuracy} + + # Top-5 accuracy + top5 = valid_logits.topk(5, dim=-1).indices + metrics["top5_accuracy"] = ( + (top5 == valid_targets.unsqueeze(-1)).any(dim=-1).float().mean().item() + ) + + # Legal move rate: reuse already-computed valid_logits argmax + if has_legal: + legal_grid = val_data["legal_grid"][start:end].to(self.device, non_blocking=True) + game_lengths = val_data["game_lengths"][start:end].to(self.device, non_blocking=True) + preds = torch.zeros_like(loss_mask, dtype=torch.long) + preds[loss_mask] = valid_logits.argmax(dim=-1) + metrics["legal_move_rate"] = compute_legal_move_rate_from_preds( + preds, legal_grid, loss_mask, game_lengths, + n_actions=self.model.embed.n_actions, + ) + + for k, v in metrics.items(): + total_metrics[k] = total_metrics.get(k, 0.0) + v + n_batches += 1 + + avg = {f"val/{k}": v / n_batches for k, v in total_metrics.items()} + avg["val/perplexity"] = math.exp(min(avg["val/loss"], 20.0)) + return avg + + def close(self): + self.wait_for_push() + self._hf_push_pool.shutdown(wait=True) + self.logger.close() + + +# --------------------------------------------------------------------------- +# Variant config builder +# --------------------------------------------------------------------------- + + +def _build_variant_configs( + variant_spec: CotrainVariant, + shared: CotrainConfig, + device: str, + scaled_lr: float, +) -> tuple[CLMConfig, TrainingConfig]: + """Build internal CLMConfig + TrainingConfig for one variant.""" + variant_factory = { + "small": CLMConfig.small, + "base": CLMConfig.base, + "large": CLMConfig.large, + "toy": CLMConfig.toy, + } + model_cfg = variant_factory[variant_spec.variant]() + + # Architecture overrides from the variant spec + if variant_spec.d_model is not None: + model_cfg.d_model = variant_spec.d_model + if variant_spec.n_layers is not None: + model_cfg.n_layers = variant_spec.n_layers + if variant_spec.n_heads is not None: + model_cfg.n_heads = variant_spec.n_heads + if variant_spec.d_ff is not None: + model_cfg.d_ff = variant_spec.d_ff + model_cfg.max_seq_len = variant_spec.max_seq_len + + if variant_spec.legacy_vocab: + model_cfg.vocab_size = LegacyVocab.VOCAB_SIZE + model_cfg.max_seq_len = 256 + + train_cfg = TrainingConfig() + train_cfg.lr = scaled_lr + train_cfg.total_steps = shared.total_steps or train_cfg.total_steps + train_cfg.batch_size = shared.batch_size + train_cfg.num_workers = shared.num_workers + train_cfg.device = device + train_cfg.log_dir = shared.log_dir or train_cfg.log_dir + train_cfg.log_interval = shared.log_interval + if shared.eval_interval is not None: + train_cfg.eval_interval = shared.eval_interval + train_cfg.checkpoint_interval = shared.checkpoint_interval + train_cfg.discard_ply_limit = shared.discard_ply_limit + train_cfg.no_outcome_token = shared.no_outcome_token + train_cfg.mate_boost = shared.mate_boost + train_cfg.use_wandb = shared.wandb + train_cfg.use_amp = shared.amp_dtype != "none" + train_cfg.max_ply = model_cfg.max_seq_len + train_cfg.weight_decay = shared.weight_decay + train_cfg.max_grad_norm = shared.max_grad_norm + train_cfg.pause_after_steps = shared.pause_after_steps + if shared.warmup_steps is not None: + train_cfg.warmup_steps = shared.warmup_steps + elif shared.total_steps is not None: + train_cfg.warmup_steps = int(shared.warmup_frac * shared.total_steps) + else: + train_cfg.warmup_steps = int(shared.warmup_frac * train_cfg.total_steps) + train_cfg.val_games = shared.val_games + + return model_cfg, train_cfg + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + + +def run_cotrain(config: CotrainConfig) -> list[ModelSlot]: + """Run co-training from a ``CotrainConfig``. Returns the final slots.""" + device = config.device + if device == "cuda": + if not torch.cuda.is_available(): + print("ERROR: CUDA not available", file=sys.stderr) + sys.exit(1) + gpu_cfg = configure_gpu( + device, + no_compile=config.no_compile, + no_amp=(config.amp_dtype == "none"), + sdpa_math=config.sdpa_math, + ) + import pawn.model as model_module + if gpu_cfg.get("sdpa_backend"): + model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"] + + total_steps = config.total_steps or 100_000 + + # Linear LR scaling: lr = base_lr * (batch_size / base_batch_size) + base_batch_size = 256 + base_lr = config.lr + scaled_lr = base_lr * (config.batch_size / base_batch_size) + + slug = random_slug() + variant_names = [v.name for v in config.variants] + + print(f"=== Co-Training [{slug}] ===") + print(f"Device: {device}") + print(f"Batch size: {config.batch_size}") + print(f"Total steps: {total_steps}") + print(f"Variants: {', '.join(variant_names)}") + if config.shm_checkpoints: + print("Checkpoints: /dev/shm (volatile, HF push is durable store)") + if config.no_outcome_token: + print("Outcome token: DISABLED (ablation experiment)") + print(f"LR: {scaled_lr:.2e} (scaled from {base_lr:.2e} for batch {config.batch_size})") + print() + + # Build slots + slots: list[ModelSlot] = [] + for variant_spec in config.variants: + model_cfg, train_cfg = _build_variant_configs( + variant_spec, config, device, scaled_lr, + ) + hf_repo = f"{config.hf_repo}-{variant_spec.name}" if config.hf_repo else None + slots.append(ModelSlot( + variant_spec.name, model_cfg, train_cfg, device, hf_repo, + shm_checkpoints=config.shm_checkpoints, slug=slug, + resume_path=variant_spec.resume, + )) + + # Verify all resumed slots agree on global_step + resumed_steps = {s.global_step for s in slots if s.global_step > 0} + if len(resumed_steps) > 1: + per_slot = {s.name: s.global_step for s in slots} + print(f"ERROR: Resumed variants disagree on global_step: {per_slot}", + file=sys.stderr) + sys.exit(1) + start_step = max(resumed_steps) if resumed_steps else 0 + + # Shared dataset and validation set — use the max_seq_len from the first variant + # All variants must produce compatible sequence lengths for the shared DataLoader. + # Use the maximum max_seq_len across all variants so shorter models can mask off. + max_ply = max(v.max_seq_len for v in config.variants) + any_legacy = any(v.legacy_vocab for v in config.variants) + if any_legacy: + max_ply = 256 + + dataset = CLMDataset( + config.batch_size, max_ply, base_seed=42, + discard_ply_limit=config.discard_ply_limit, + no_outcome=config.no_outcome_token, + ) + + print("\nGenerating shared validation set...") + val_data = create_validation_set( + config.val_games, max_ply, seed=(2**63) - 1, + discard_ply_limit=config.discard_ply_limit, + no_outcome=config.no_outcome_token, + ) + + # Compile models + if device != "cpu" and not config.no_compile: + for slot in slots: + try: + slot.model = torch.compile(slot.model, mode="default") # type: ignore[assignment] + print(f" [{slot.name}] torch.compile enabled") + except Exception: + print(f" [{slot.name}] torch.compile not available") + + loader = DataLoader( + dataset, + batch_size=None, + num_workers=config.num_workers, + pin_memory=(device != "cpu"), + persistent_workers=(config.num_workers > 0), + prefetch_factor=2 if config.num_workers > 0 else None, + multiprocessing_context="spawn" if config.num_workers > 0 else None, + ) + + # Signal handling + _shutdown_requested = False + _shutdown_signal = None + + def _graceful_exit(signum, frame): + nonlocal _shutdown_requested, _shutdown_signal + _shutdown_requested = True + _shutdown_signal = signum + + signal.signal(signal.SIGTERM, _graceful_exit) + signal.signal(signal.SIGINT, _graceful_exit) + + # Training loop + patience = config.patience or 0 + global_step = start_step + step_start = time.time() + + print(f"\nStarting training from step {global_step}", flush=True) + for slot in slots: + print(f" [{slot.name}] JSONL: {slot.jsonl_path}", flush=True) + print() + + active_slots = [s for s in slots if not s.stopped] + log_interval = config.log_interval + eval_interval = slots[0].train_cfg.eval_interval + checkpoint_interval = config.checkpoint_interval + + for batch in loader: + # Forward + backward + optimizer step per model so CUDA can overlap + # Adam updates (memory-bound) with the next model's forward (compute-bound) + all_metrics: dict[str, dict[str, torch.Tensor]] = {} + all_grad_norms: dict[str, float] = {} + for slot in active_slots: + metrics = slot.train_step(batch) + all_metrics[slot.name] = metrics + gn = slot.optimizer_step() + all_grad_norms[slot.name] = gn + + global_step += 1 + for slot in slots: + slot.global_step = global_step + + step_time = time.time() - step_start + games_per_sec = config.batch_size / step_time + + # Logging — .item() sync only at log intervals + if global_step % log_interval == 0: + active_names = ", ".join(s.name for s in active_slots) + print(f"step {global_step:>7d} | {games_per_sec:.0f} g/s | {step_time:.2f}s | active: {active_names}", flush=True) + for slot in active_slots: + m = all_metrics[slot.name] + loss_val = m['loss'].item() + acc_val = m['accuracy'].item() + gn = all_grad_norms[slot.name] + lr = slot.scheduler.get_lr() + print(f" {slot.name:>5s}: loss {loss_val:.4f} | acc {acc_val:.3f} | " + f"lr {lr:.2e} | gn {gn:.2f}", flush=True) + + slot.logger.log_train( + step=global_step, + lr=lr, grad_norm=gn, + step_time=step_time, games_per_sec=games_per_sec, + **{"train/loss": loss_val, "train/accuracy": acc_val}, + ) + + # Eval + if global_step % eval_interval == 0: + for slot in active_slots: + val_metrics = slot.evaluate(val_data) + print(f" {slot.name:>5s} val: loss {val_metrics['val/loss']:.4f} | " + f"acc {val_metrics['val/accuracy']:.3f}", flush=True) + # Track best for eval, /dev/shm cleanup, and patience + vl = val_metrics["val/loss"] + if vl < slot.best_val_loss: + slot.best_val_loss = vl + slot.best_val_step = global_step + slot.patience_counter = 0 + else: + slot.patience_counter += 1 + + slot.logger.log_val( + step=global_step, + patience=slot.patience_counter, + best_val_loss=slot.best_val_loss, + best_val_step=slot.best_val_step, + **val_metrics, + ) + + # Per-model early stopping + if patience > 0 and slot.patience_counter >= patience: + print(f" [{slot.name}] Early stopping — no improvement " + f"for {patience} evals (best step {slot.best_val_step})") + slot.stopped = True + slot.save_checkpoint() + + active_slots = [s for s in active_slots if not s.stopped] + + # Push metrics to HF after eval (lightweight, background) + for slot in slots: + slot.push_metrics_to_hf() + + if not active_slots: + print(f"\nAll models stopped at step {global_step}") + break + + # Checkpoint + if global_step % checkpoint_interval == 0: + for slot in active_slots: + slot.save_checkpoint() + + # Done? + if global_step >= total_steps: + print(f"\nTraining complete at step {global_step}") + for slot in active_slots: + slot.save_checkpoint() + break + + # Pause + if config.pause_after_steps and global_step >= config.pause_after_steps: + print(f"\nPause requested at step {global_step}, saving checkpoints...") + for slot in active_slots: + slot.save_checkpoint() + break + + # Graceful shutdown + if _shutdown_requested: + print(f"\nShutdown requested (signal {_shutdown_signal}), " + f"saving checkpoints at step {global_step}...") + for slot in active_slots: + slot.save_checkpoint() + break + + step_start = time.time() + + # Cleanup + for slot in slots: + slot.close() + + print("\nAll done.") + return slots diff --git a/pawn/lab/monitor.py b/pawn/lab/monitor.py index 10ec033..78823bb 100644 --- a/pawn/lab/monitor.py +++ b/pawn/lab/monitor.py @@ -41,9 +41,31 @@ def is_alive(pid: int) -> tuple[bool, int | None]: def read_metrics( trial: Trial, log_dir: Path, - offsets: dict[int, int], + offsets: dict, ) -> None: - """Read new lines from the trial's metrics.jsonl, updating trial in-place.""" + """Read new lines from the trial's metrics.jsonl, updating trial in-place. + + For cotrain trials, discovers multiple per-variant metrics files and + aggregates them to the trial level while tracking per-variant state in + ``trial.variants``. + + ``offsets`` keys are ``int`` (trial_id) for single-variant trials, or + ``(trial_id, variant_name)`` for cotrain per-variant files. + """ + is_cotrain = (trial.config or {}).get("run_type") == "cotrain" + + if is_cotrain: + _read_cotrain_metrics(trial, log_dir, offsets) + else: + _read_single_metrics(trial, log_dir, offsets) + + +def _read_single_metrics( + trial: Trial, + log_dir: Path, + offsets: dict, +) -> None: + """Read metrics for a single-variant (pretrain/adapter) trial.""" # Find run dir if not yet discovered — pick the most recent if trial.run_dir is None: trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}" @@ -116,6 +138,166 @@ def read_metrics( trial.best_accuracy = acc +def _read_cotrain_metrics( + trial: Trial, + log_dir: Path, + offsets: dict, +) -> None: + """Read metrics for a cotrain trial (multiple per-variant JSONL files).""" + trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}" + + # Discover all per-variant metrics files under the trial dir. + # Each variant's MetricsLogger creates a run dir with suffix=variant_name, + # e.g. run_20260410_151230_zesty-osprey_small/metrics.jsonl + metrics_files = list(trial_log_dir.glob("*/metrics.jsonl")) + if not metrics_files: + return + + # Set trial.run_dir to the parent trial dir (not a specific variant) + if trial.run_dir is None: + trial.run_dir = str(trial_log_dir) + + # Initialize variants dict if needed + if trial.variants is None: + trial.variants = {} + + # Extract variant name from the run dir suffix: run_..._/metrics.jsonl + # The MetricsLogger uses suffix=name, producing dirs like + # run_YYYYMMDD_HHMMSS_slug_variantname/ + for mf in metrics_files: + variant_name = _extract_variant_name(mf.parent.name) + if variant_name is None: + continue + + # Initialize this variant's state dict + if variant_name not in trial.variants: + trial.variants[variant_name] = { + "name": variant_name, + "run_dir": str(mf.parent), + "current_step": 0, + "last_train_loss": None, + "last_train_acc": None, + "best_val_loss": None, + "best_val_step": 0, + "best_accuracy": None, + "actual_param_count": None, + "stopped": False, + "steps_per_sec": 0.0, + } + + vs = trial.variants[variant_name] + offset_key = (trial.trial_id, variant_name) + offset = offsets.get(offset_key, 0) + + try: + with open(mf) as f: + f.seek(offset) + new_lines = f.readlines() + offsets[offset_key] = f.tell() + except OSError: + continue + + for line in new_lines: + try: + rec = json.loads(line) + except (json.JSONDecodeError, ValueError): + continue + + rtype = rec.get("type") + if rtype == "config": + ts = rec.get("total_steps") or (rec.get("training") or {}).get("total_steps") + if ts: + trial.total_steps = ts + pc = rec.get("param_count") + if pc is not None: + vs["actual_param_count"] = pc + + elif rtype == "train": + vs["current_step"] = rec.get("step", vs["current_step"]) + loss = rec.get("train/loss") or rec.get("train_loss") + if loss is not None: + vs["last_train_loss"] = loss + train_acc = rec.get("train/accuracy") or rec.get("train_top1") + if train_acc is not None: + vs["last_train_acc"] = train_acc + st = rec.get("step_time") + if st and st > 0: + vs["steps_per_sec"] = 1.0 / st + elif rec.get("elapsed") and vs["current_step"] > 0: + vs["steps_per_sec"] = vs["current_step"] / rec["elapsed"] + + elif rtype == "val": + vl = rec.get("val/loss") or rec.get("val_loss") or rec.get("loss") + if vl is not None and (vs["best_val_loss"] is None or vl < vs["best_val_loss"]): + vs["best_val_loss"] = vl + vs["best_val_step"] = rec.get("step", vs.get("best_val_step", 0)) + acc = (rec.get("val/accuracy") or rec.get("val_top1") + or rec.get("accuracy")) + if acc is not None: + vs["best_accuracy"] = acc + + # Aggregate to trial level + _aggregate_cotrain_metrics(trial) + + +def _extract_variant_name(run_dir_name: str) -> str | None: + """Extract variant name from a run directory name. + + The MetricsLogger creates dirs like ``run_YYYYMMDD_HHMMSS_variantname_slug``. + The layout is: ``run`` _ ``date`` _ ``time`` _ ``variant`` _ ``slug``. + The variant name may itself contain underscores, but the slug (final segment) + never does (it's two hyphenated words like ``calm-crane``). So we rejoin + everything between parts[3] and parts[-1]. + """ + # Expected: run_YYYYMMDD_HHMMSS_variant_slug (at least 5 parts) + parts = run_dir_name.split("_") + if len(parts) < 5 or parts[0] != "run": + return None + # parts[1]=date, parts[2]=time, parts[-1]=slug, parts[3:-1]=variant + return "_".join(parts[3:-1]) + + +def _aggregate_cotrain_metrics(trial: Trial) -> None: + """Aggregate per-variant metrics to the trial level.""" + if not trial.variants: + return + + variants = list(trial.variants.values()) + + # current_step = min across variants (honest ETA — slowest determines progress) + steps = [v["current_step"] for v in variants if v["current_step"] > 0] + if steps: + trial.current_step = min(steps) + + # best_val_loss = min across variants + val_losses = [v["best_val_loss"] for v in variants if v["best_val_loss"] is not None] + if val_losses: + trial.best_val_loss = min(val_losses) + + # best_accuracy = max across variants + accs = [v["best_accuracy"] for v in variants if v["best_accuracy"] is not None] + if accs: + trial.best_accuracy = max(accs) + + # last_train_loss = mean across active variants + losses = [v["last_train_loss"] for v in variants + if v["last_train_loss"] is not None and not v.get("stopped")] + if losses: + trial.last_train_loss = sum(losses) / len(losses) + + # last_train_acc = mean across active variants + accs_train = [v["last_train_acc"] for v in variants + if v.get("last_train_acc") is not None and not v.get("stopped")] + if accs_train: + trial.last_train_acc = sum(accs_train) / len(accs_train) + + # steps_per_sec from any variant (they share the same step timing) + for v in variants: + if v.get("steps_per_sec", 0) > 0: + trial.steps_per_sec = v["steps_per_sec"] + break + + def read_pretrain_val_summary(trial: Trial) -> dict[str, Any] | None: """Scan the trial's metrics.jsonl for the latest pretraining val record and compute a log-linear fit on forfeit rate over the most recent half diff --git a/pawn/lab/runner.py b/pawn/lab/runner.py index 3ab6ea8..1cfcaf2 100644 --- a/pawn/lab/runner.py +++ b/pawn/lab/runner.py @@ -35,14 +35,19 @@ def _validate_config(config: dict[str, Any]) -> dict[str, Any]: """ from pydantic import TypeAdapter - from pawn.run_config import AdapterConfig, PretrainConfig + from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig run_type = config.get("run_type") - if run_type not in ("pretrain", "adapter"): - raise ValueError(f"run_type must be 'pretrain' or 'adapter', got {run_type!r}") - ta = TypeAdapter( - PretrainConfig if run_type == "pretrain" else AdapterConfig - ) + config_cls = { + "pretrain": PretrainConfig, + "adapter": AdapterConfig, + "cotrain": CotrainConfig, + }.get(run_type) # type: ignore[arg-type] + if config_cls is None: + raise ValueError( + f"run_type must be 'pretrain', 'adapter', or 'cotrain', got {run_type!r}" + ) + ta = TypeAdapter(config_cls) return ta.validate_python(config).model_dump() @@ -266,7 +271,11 @@ async def launch( validated = _validate_config(config) cmd = self._build_command(validated, trial_id) - strategy_display = validated.get("strategy") or validated.get("variant", "pretrain") + if validated.get("run_type") == "cotrain": + variant_names = [v["name"] for v in validated.get("variants", [])] + strategy_display = "cotrain:" + "+".join(variant_names) + else: + strategy_display = validated.get("strategy") or validated.get("variant", "pretrain") trial = Trial( trial_id=trial_id, strategy=strategy_display, @@ -296,34 +305,83 @@ async def resume_trial( total_steps: int | None = None, pause_after_steps: int | None = None, ) -> int: - """Resume a completed/failed trial from its best checkpoint.""" + """Resume a completed/failed trial from its best checkpoint. + + For cotrain trials, discovers per-variant checkpoints and sets + the resume path on each variant in the new config. + """ old = self.trials.get(trial_id) if not old: raise RuntimeError(f"Trial {trial_id} not found") if not old.run_dir: raise RuntimeError(f"Trial {trial_id} has no run directory") - ckpt_base = Path(old.run_dir) / "checkpoints" + new_config = dict(old.config) + new_config.pop("pause_after_steps", None) + + if (old.config or {}).get("run_type") == "cotrain": + self._resolve_cotrain_resume(old, new_config) + else: + ckpt_dir = self._find_latest_checkpoint(Path(old.run_dir)) + new_config["resume"] = str(ckpt_dir) + + if total_steps is not None: + new_config["total_steps"] = total_steps + if pause_after_steps is not None: + new_config["pause_after_steps"] = pause_after_steps + + return await self.launch(new_config, tags=old.tags) + + @staticmethod + def _find_latest_checkpoint(run_dir: Path) -> Path: + """Find the latest checkpoint under a run directory. + + Checks for ``best/`` and ``final/`` symlinks first (adapter runs), + then falls back to the highest-numbered ``step_*`` directory + (pretrain/cotrain runs, which don't create best/final symlinks). + """ + ckpt_base = run_dir / "checkpoints" ckpt_dir = ckpt_base / "best" if not ckpt_dir.exists(): ckpt_dir = ckpt_base / "final" if not ckpt_dir.exists(): - # Pretraining uses step_XXXXXXXX naming — pick the highest step step_dirs = sorted(ckpt_base.glob("step_*")) if step_dirs: ckpt_dir = step_dirs[-1] if not ckpt_dir.exists(): - raise RuntimeError(f"No checkpoint found for trial {trial_id}") + raise RuntimeError(f"No checkpoint found under {run_dir}") + return ckpt_dir - new_config = dict(old.config) - new_config.pop("pause_after_steps", None) - new_config["resume"] = str(ckpt_dir) - if total_steps is not None: - new_config["total_steps"] = total_steps - if pause_after_steps is not None: - new_config["pause_after_steps"] = pause_after_steps + def _resolve_cotrain_resume( + self, old: "Trial", new_config: dict[str, Any], + ) -> None: + """Set per-variant resume paths for a cotrain trial.""" + if not old.variants: + raise RuntimeError( + f"Trial {old.trial_id} is cotrain but has no variant state. " + "Cannot determine per-variant checkpoints." + ) - return await self.launch(new_config, tags=old.tags) + # Deep-copy variants list so we can mutate + import copy + variants = copy.deepcopy(new_config.get("variants", [])) + + for v_cfg in variants: + name = v_cfg.get("name") + if name not in old.variants: + raise RuntimeError( + f"Variant '{name}' not found in trial {old.trial_id} state" + ) + vs = old.variants[name] + v_run_dir = vs.get("run_dir") + if not v_run_dir: + raise RuntimeError( + f"Variant '{name}' in trial {old.trial_id} has no run directory" + ) + ckpt_dir = self._find_latest_checkpoint(Path(v_run_dir)) + v_cfg["resume"] = str(ckpt_dir) + + new_config["variants"] = variants def _build_command( self, config: dict[str, Any], trial_id: int, diff --git a/pawn/lab/server.py b/pawn/lab/server.py index 3898285..292942d 100644 --- a/pawn/lab/server.py +++ b/pawn/lab/server.py @@ -44,7 +44,7 @@ async def lab_status(ctx: Context) -> dict[str, Any]: @mcp.tool async def lab_launch(config: dict[str, Any], ctx: Context, tags: list[str] | None = None) -> dict[str, Any]: - """Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain' or 'adapter'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"]).""" + """Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain', 'adapter', or 'cotrain'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"]).""" try: tid = await _runner(ctx).launch(config, tags=tags) return _runner(ctx).trials[tid].to_dict() @@ -105,10 +105,11 @@ async def lab_set_cost(cost_per_hour: float, ctx: Context) -> dict[str, Any]: @mcp.tool async def lab_schema(ctx: Context) -> dict[str, Any]: - """Return the JSON Schema for RunConfig (PretrainConfig and AdapterConfig). Use this to discover all available parameters before calling lab_launch.""" - from pawn.run_config import AdapterConfig, PretrainConfig + """Return the JSON Schema for RunConfig (PretrainConfig, AdapterConfig, CotrainConfig). Use this to discover all available parameters before calling lab_launch.""" + from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig return { "pretrain": PretrainConfig.model_json_schema(), "adapter": AdapterConfig.model_json_schema(), + "cotrain": CotrainConfig.model_json_schema(), } diff --git a/pawn/lab/state.py b/pawn/lab/state.py index 3d84b5a..78bf658 100644 --- a/pawn/lab/state.py +++ b/pawn/lab/state.py @@ -52,6 +52,8 @@ class Trial: # Agent annotations notes: str = "" tags: list[str] = field(default_factory=list) + # Co-training: per-variant state (None for non-cotrain trials) + variants: dict[str, dict[str, Any]] | None = None def eta_seconds(self) -> float | None: if self.steps_per_sec > 0 and self.total_steps > self.current_step: diff --git a/pawn/logging.py b/pawn/logging.py index 245af91..5d8c4f6 100644 --- a/pawn/logging.py +++ b/pawn/logging.py @@ -215,23 +215,18 @@ def write_config_json(self, **kwargs: object) -> Path: def log_train( self, step: int, - epoch: int | None = None, **metrics: object, ) -> None: """Log a training metrics record. Standard fields (pass as kwargs): - lr, grad_norm, loss, accuracy, step_time, games_per_sec, + epoch, lr, grad_norm, loss, accuracy, step_time, games_per_sec, train_loss, train_top1, etc. Adapter-specific fields are passed through as-is: film/gamma_norm_L0, lora/B_norm_q, adapter/up_norm, etc. """ record: dict[str, object] = {"type": "train", "step": step} - if epoch is not None: - record["epoch"] = epoch - - # Normalize common field names: accept both flat and prefixed for k, v in metrics.items(): record[k] = v @@ -245,19 +240,15 @@ def log_train( def log_val( self, step: int, - epoch: int | None = None, **metrics: object, ) -> None: """Log a validation metrics record. Standard fields (pass as kwargs): - loss (or val/loss), accuracy, top5_accuracy, + epoch, loss (or val/loss), accuracy, top5_accuracy, patience, best_val_loss, best_val_step, etc. """ record: dict[str, object] = {"type": "val", "step": step} - if epoch is not None: - record["epoch"] = epoch - for k, v in metrics.items(): record[k] = v diff --git a/pawn/run_config.py b/pawn/run_config.py index 722ca97..31c8d8f 100644 --- a/pawn/run_config.py +++ b/pawn/run_config.py @@ -142,7 +142,55 @@ class AdapterConfig(BaseRunConfig): val_every: int = 1 +class CotrainVariant(BaseModel): + """Per-variant spec within a co-training run. + + Architecture fields override the preset selected by ``variant``. + ``resume`` is set automatically by ``lab_resume`` — not user-facing. + """ + + name: str + variant: Literal["toy", "small", "base", "large"] = "base" + d_model: int | None = None + n_layers: int | None = None + n_heads: int | None = None + d_ff: int | None = None + max_seq_len: int = 512 + legacy_vocab: bool = False + resume: str | None = None + + +class CotrainConfig(BaseRunConfig): + """Co-training multiple model variants on shared data batches.""" + + run_type: Literal["cotrain"] = "cotrain" + variants: list[CotrainVariant] + checkpoint_interval: int = 5000 + shm_checkpoints: bool = False + val_games: int = 512 # override BaseRunConfig's 50K — pretrain uses on-the-fly data + + @model_validator(mode="after") + def _check_cotrain(self) -> "CotrainConfig": + if not self.variants: + raise ValueError("variants must contain at least one entry") + names = [v.name for v in self.variants] + if len(names) != len(set(names)): + raise ValueError(f"variant names must be unique, got {names}") + if self.shm_checkpoints and not self.hf_repo: + raise ValueError( + "--shm-checkpoints requires --hf-repo " + "(HF is the only durable store)" + ) + if self.resume is not None: + raise ValueError( + "CotrainConfig does not use the top-level 'resume' field. " + "Set 'resume' on each variant in the 'variants' list instead." + ) + return self + + RunConfig = Annotated[ - Union[PretrainConfig, AdapterConfig], Field(discriminator="run_type") + Union[PretrainConfig, AdapterConfig, CotrainConfig], + Field(discriminator="run_type"), ] """Discriminated union of all run config types.""" diff --git a/scripts/train.py b/scripts/train.py index 36415be..3e5e7e5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -30,7 +30,7 @@ import torch.multiprocessing as mp from pydantic import TypeAdapter -from pawn.run_config import AdapterConfig, PretrainConfig +from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig # ----------------------------------------------------------------------- @@ -490,21 +490,27 @@ def main() -> None: raw = _parse_cli() run_type = raw.get("run_type") - if run_type not in ("pretrain", "adapter"): + if run_type not in ("pretrain", "adapter", "cotrain"): _die( - f"run_type must be 'pretrain' or 'adapter', got {run_type!r}. " + f"run_type must be 'pretrain', 'adapter', or 'cotrain', got {run_type!r}. " "Specify via --run-type or in the JSON config." ) - ta = TypeAdapter( - PretrainConfig if run_type == "pretrain" else AdapterConfig - ) + config_cls = { + "pretrain": PretrainConfig, + "adapter": AdapterConfig, + "cotrain": CotrainConfig, + }[run_type] # type: ignore[index] # narrowed by `not in` check above + ta = TypeAdapter(config_cls) config = ta.validate_python(raw) if isinstance(config, PretrainConfig): run_pretrain(config) elif isinstance(config, AdapterConfig): run_adapter(config) + elif isinstance(config, CotrainConfig): + from pawn.cotrain import run_cotrain + run_cotrain(config) else: _die(f"Unknown run_type: {run_type}") diff --git a/scripts/train_all.py b/scripts/train_all.py index 3300f07..ca60442 100644 --- a/scripts/train_all.py +++ b/scripts/train_all.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -"""Train small, base, and large PAWN models simultaneously on shared data. +"""Train multiple PAWN model variants simultaneously on shared data. -All three models see the exact same batches in the same order, eliminating +All models see the exact same batches in the same order, eliminating data generation overhead and ensuring comparable training conditions. Usage: @@ -13,275 +13,16 @@ import argparse import json -import math import os -import shutil -import signal import sys -import time from pathlib import Path import torch import torch.multiprocessing as mp -import torch.nn.functional as F -from torch.utils.data import DataLoader -from pawn.config import CLMConfig, TrainingConfig +from pawn.cotrain import ModelSlot, run_cotrain from pawn.model import PAWNCLM -from pawn.data import CLMDataset, create_validation_set -from pawn.gpu import configure_gpu -from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf -from pawn.logging import MetricsLogger, random_slug - - -# --------------------------------------------------------------------------- -# Per-model state -# --------------------------------------------------------------------------- - -class ModelSlot: - """Holds everything needed to train and checkpoint one model variant.""" - - def __init__( - self, - name: str, - model_cfg: CLMConfig, - train_cfg: TrainingConfig, - device: str, - hf_repo: str | None, - shm_checkpoints: bool = False, - slug: str = "", - ): - self.name = name - self.slug = slug - self.model_cfg = model_cfg - self.train_cfg = train_cfg - self.device = device - self.hf_repo = hf_repo - self.shm_checkpoints = shm_checkpoints - - self.model = PAWNCLM(model_cfg).to(device) - param_count = sum(p.numel() for p in self.model.parameters()) - print(f" {name}: {param_count:,} params ({model_cfg.d_model}d/{model_cfg.n_layers}L)") - - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=train_cfg.lr, - weight_decay=train_cfg.weight_decay, - ) - - from pawn.trainer import CosineWithWarmup - self.scheduler = CosineWithWarmup( - self.optimizer, - warmup_steps=train_cfg.warmup_steps, - total_steps=train_cfg.total_steps, - ) - - self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp) - - # Logger (creates run directory) - self.logger = MetricsLogger( - train_cfg.log_dir, run_prefix="run", device=device, - slug=slug, suffix=name, - ) - self.run_dir = str(self.logger.run_dir) - self.jsonl_path = str(self.logger.metrics_path) - - # Checkpoint directory: /dev/shm if requested, else under run_dir - if shm_checkpoints: - self.checkpoint_dir = f"/dev/shm/pawn_checkpoints/{name}" - else: - self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints") - os.makedirs(self.checkpoint_dir, exist_ok=True) - - self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None - self.global_step = 0 - self.best_val_step = 0 - self.best_val_loss = float("inf") - self.patience_counter = 0 - self.stopped = False - - # Background HF push (one thread per slot, so pushes don't block training) - from concurrent.futures import ThreadPoolExecutor - self._hf_push_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"hf-{name}") - self._hf_push_future = None - - self.logger.log_config( - model=model_cfg.__dict__, - training=train_cfg.__dict__, - param_count=param_count, - formulation="clm", - multi_model=True, - variant=name, - ) - self.logger.write_config_json( - model=model_cfg.__dict__, - training=train_cfg.__dict__, - param_count=param_count, - formulation="clm", - multi_model=True, - variant=name, - ) - - def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """Forward + backward. Returns raw GPU tensor metrics (no .item() sync).""" - self.model.train() - input_ids = batch["input_ids"].to(self.device, non_blocking=True) - targets = batch["targets"].to(self.device, non_blocking=True) - loss_mask = batch["loss_mask"].to(self.device, non_blocking=True) - - with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp): - loss, metrics = self.model.forward_train(input_ids, loss_mask, targets) - - self.scaler.scale(loss).backward() - return metrics - - def optimizer_step(self) -> float: - self.scaler.unscale_(self.optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.train_cfg.max_grad_norm - ).item() - self.scaler.step(self.optimizer) - self.scaler.update() - self.optimizer.zero_grad(set_to_none=True) - self.scheduler.step() - return grad_norm - - def _unwrapped_model(self): - """Return the unwrapped model (strips torch.compile wrapper).""" - m = self.model - while hasattr(m, '_orig_mod'): - m = m._orig_mod - return m - - def save_checkpoint(self): - path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}") - save_pretrain_checkpoint( - path, self._unwrapped_model(), self.optimizer, self.scheduler, self.scaler, - self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__, - ) - print(f" [{self.name}] Checkpoint saved: {path}") - - if self.hf_repo and self.hf_branch: - self._push_to_hf_async(path, self.global_step) - - def _push_to_hf_async(self, ckpt_path: str, step: int): - """Push checkpoint to HuggingFace in a background thread.""" - # Wait for any previous push to finish before starting a new one - if self._hf_push_future is not None: - self._hf_push_future.result() # raises if previous push failed - - def _push(): - try: - push_checkpoint_to_hf( - ckpt_path, self.hf_repo, self.hf_branch, - metrics_path=self.jsonl_path, step=step, - ) - print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}") - - # On /dev/shm, clean up old checkpoints after successful push. - # Keep the latest (just saved) and the best (for post-training evals). - if self.shm_checkpoints: - keep = {Path(ckpt_path).name, f"step_{self.best_val_step:08d}"} - for old in sorted(Path(self.checkpoint_dir).glob("step_*")): - if old.name not in keep: - shutil.rmtree(old, ignore_errors=True) - except Exception as e: - print(f" [{self.name}] WARNING: HF push failed: {e}") - - self._hf_push_future = self._hf_push_pool.submit(_push) - - def push_metrics_to_hf(self): - """Push just metrics.jsonl to HF (lightweight, no checkpoint).""" - if not self.hf_repo or not self.hf_branch: - return - - def _push_metrics(): - try: - from huggingface_hub import HfApi - api = HfApi() - api.create_branch(self.hf_repo, repo_type="model", - branch=self.hf_branch, exist_ok=True) - api.upload_file( - path_or_fileobj=self.jsonl_path, - path_in_repo="metrics.jsonl", - repo_id=self.hf_repo, - repo_type="model", - revision=self.hf_branch, - commit_message=f"Metrics through step {self.global_step}", - ) - except Exception as e: - print(f" [{self.name}] WARNING: metrics push failed: {e}") - - # Fire and forget on the push pool (queued behind any checkpoint push) - self._hf_push_pool.submit(_push_metrics) - - def wait_for_push(self): - """Block until any in-flight HF push completes.""" - if self._hf_push_future is not None: - self._hf_push_future.result() - self._hf_push_future = None - - @torch.no_grad() - def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]: - from pawn.trainer import compute_legal_move_rate_from_preds - - self.model.eval() - n = val_data["input_ids"].shape[0] - batch_size = self.train_cfg.batch_size - total_metrics: dict[str, float] = {} - n_batches = 0 - has_legal = "legal_grid" in val_data - - for start in range(0, n, batch_size): - end = min(start + batch_size, n) - input_ids = val_data["input_ids"][start:end].to(self.device, non_blocking=True) - targets = val_data["targets"][start:end].to(self.device, non_blocking=True) - loss_mask = val_data["loss_mask"][start:end].to(self.device, non_blocking=True) - - with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp): - # Get hidden states without materializing full (B,T,V) logits - hidden = self.model.forward_eval(input_ids, loss_mask) - - # Sparse projection: only valid positions through lm_head - valid_hidden = hidden[loss_mask] - valid_logits = self.model.lm_head(valid_hidden) - valid_targets = targets[loss_mask] - - loss = F.cross_entropy(valid_logits, valid_targets) - accuracy = (valid_logits.argmax(-1) == valid_targets).float().mean().item() - metrics: dict[str, float] = {"loss": loss.item(), "accuracy": accuracy} - - # Top-5 accuracy - top5 = valid_logits.topk(5, dim=-1).indices - metrics["top5_accuracy"] = ( - (top5 == valid_targets.unsqueeze(-1)).any(dim=-1).float().mean().item() - ) - - # Legal move rate: reuse already-computed valid_logits argmax - if has_legal: - legal_grid = val_data["legal_grid"][start:end].to(self.device, non_blocking=True) - game_lengths = val_data["game_lengths"][start:end].to(self.device, non_blocking=True) - preds = torch.zeros_like(loss_mask, dtype=torch.long) - preds[loss_mask] = valid_logits.argmax(dim=-1) - metrics["legal_move_rate"] = compute_legal_move_rate_from_preds( - preds, legal_grid, loss_mask, game_lengths, - n_actions=self.model.embed.n_actions, - ) - - for k, v in metrics.items(): - total_metrics[k] = total_metrics.get(k, 0.0) + v - n_batches += 1 - - avg = {f"val/{k}": v / n_batches for k, v in total_metrics.items()} - avg["val/perplexity"] = math.exp(min(avg["val/loss"], 20.0)) - return avg - - def close(self): - self.wait_for_push() - self._hf_push_pool.shutdown(wait=True) - self.logger.close() - - +from pawn.run_config import CotrainConfig, CotrainVariant # --------------------------------------------------------------------------- @@ -325,6 +66,34 @@ def parse_args(): return p.parse_args() +def _args_to_cotrain_config(args) -> CotrainConfig: + """Build a CotrainConfig from argparse namespace.""" + variants = [ + CotrainVariant(name="small", variant="small", legacy_vocab=args.legacy_vocab), + CotrainVariant(name="base", variant="base", legacy_vocab=args.legacy_vocab), + CotrainVariant(name="large", variant="large", legacy_vocab=args.legacy_vocab), + ] + + return CotrainConfig( + variants=variants, + device=args.device or ("cuda" if torch.cuda.is_available() else "cpu"), + total_steps=args.total_steps, + batch_size=args.batch_size, + num_workers=args.num_workers, + log_dir=args.log_dir, + log_interval=args.log_interval, + eval_interval=args.eval_interval, + checkpoint_interval=args.checkpoint_interval, + discard_ply_limit=args.discard_ply_limit, + no_outcome_token=args.no_outcome_token, + patience=args.patience, + wandb=args.wandb, + hf_repo=args.hf_repo, + local_checkpoints=args.local_checkpoints, + shm_checkpoints=args.shm_checkpoints, + ) + + def _run_post_training_evals(slots: list[ModelSlot], args): """Run probes, diagnostics, and Lichess eval on best checkpoint per variant.""" from pawn.eval_suite.probes import extract_probe_data, train_all_probes @@ -420,239 +189,16 @@ def _run_post_training_evals(slots: list[ModelSlot], args): def main(): args = parse_args() + config = _args_to_cotrain_config(args) + slots = run_cotrain(config) - if args.shm_checkpoints and not args.hf_repo: - print("ERROR: --shm-checkpoints requires --hf-repo (HF is the only durable store)") - sys.exit(1) - - device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") - if device == "cuda": - gpu_cfg = configure_gpu() - import pawn.model as model_module - if gpu_cfg.get("sdpa_backend"): - model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"] - - # Build per-variant configs (shared training hyperparams, different model sizes) - variants = { - "small": CLMConfig.small(), - "base": CLMConfig.base(), - "large": CLMConfig.large(), - } - - slug = random_slug() - - print(f"=== Multi-Model Training [{slug}] ===") - print(f"Device: {device}") - print(f"Batch size: {args.batch_size}") - print(f"Total steps: {args.total_steps}") - if args.shm_checkpoints: - print("Checkpoints: /dev/shm (volatile, HF push is durable store)") - if args.no_outcome_token: - print("Outcome token: DISABLED (ablation experiment)") - print() - - # Linear LR scaling: lr = base_lr * (batch_size / base_batch_size) - base_batch_size = 256 - base_lr = TrainingConfig.lr - scaled_lr = base_lr * (args.batch_size / base_batch_size) - print(f"LR: {scaled_lr:.2e} (scaled from {base_lr:.2e} for batch {args.batch_size})") - - if args.legacy_vocab: - from pawn.config import LegacyVocab - print("Using legacy PAWN vocabulary (4284 tokens, 256 seq_len)") - - slots: list[ModelSlot] = [] - for name, model_cfg in variants.items(): - if args.legacy_vocab: - model_cfg.vocab_size = LegacyVocab.VOCAB_SIZE - model_cfg.max_seq_len = 256 - - train_cfg = TrainingConfig() - train_cfg.lr = scaled_lr - train_cfg.total_steps = args.total_steps - train_cfg.batch_size = args.batch_size - train_cfg.num_workers = args.num_workers - train_cfg.device = device - train_cfg.log_dir = args.log_dir - train_cfg.log_interval = args.log_interval - train_cfg.eval_interval = args.eval_interval - train_cfg.checkpoint_interval = args.checkpoint_interval - train_cfg.discard_ply_limit = args.discard_ply_limit - train_cfg.no_outcome_token = args.no_outcome_token - train_cfg.use_wandb = args.wandb - train_cfg.max_ply = model_cfg.max_seq_len - - hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None - slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo, - shm_checkpoints=args.shm_checkpoints, slug=slug)) - - # Shared dataset and validation set - max_ply = model_cfg.max_seq_len # 512 (new default) or 256 (legacy) - dataset = CLMDataset( - args.batch_size, max_ply, base_seed=42, - discard_ply_limit=args.discard_ply_limit, - no_outcome=args.no_outcome_token, - ) - - print("\nGenerating shared validation set...") - val_data = create_validation_set(512, max_ply, seed=(2**63) - 1, - discard_ply_limit=args.discard_ply_limit, - no_outcome=args.no_outcome_token) - - # Compile models - if device != "cpu": - for slot in slots: - try: - slot.model = torch.compile(slot.model, mode="default") - print(f" [{slot.name}] torch.compile enabled") - except Exception: - print(f" [{slot.name}] torch.compile not available") - - loader = DataLoader( - dataset, - batch_size=None, - num_workers=args.num_workers, - pin_memory=(device != "cpu"), - persistent_workers=(args.num_workers > 0), - prefetch_factor=2 if args.num_workers > 0 else None, - ) - - # Signal handling - _shutdown_requested = False - _shutdown_signal = None - - def _graceful_exit(signum, frame): - nonlocal _shutdown_requested, _shutdown_signal - _shutdown_requested = True - _shutdown_signal = signum - - signal.signal(signal.SIGTERM, _graceful_exit) - signal.signal(signal.SIGINT, _graceful_exit) - - # Training loop - global_step = 0 - step_start = time.time() - - print(f"\nStarting training from step 0", flush=True) - for slot in slots: - print(f" [{slot.name}] JSONL: {slot.jsonl_path}", flush=True) - print() - - active_slots = list(slots) # slots still training - - for batch in loader: - # Forward + backward + optimizer step per model so CUDA can overlap - # Adam updates (memory-bound) with the next model's forward (compute-bound) - all_metrics: dict[str, dict[str, torch.Tensor]] = {} - all_grad_norms: dict[str, float] = {} - for slot in active_slots: - metrics = slot.train_step(batch) - all_metrics[slot.name] = metrics - gn = slot.optimizer_step() - all_grad_norms[slot.name] = gn - - global_step += 1 - for slot in slots: - slot.global_step = global_step - - step_time = time.time() - step_start - games_per_sec = args.batch_size / step_time - - # Logging — .item() sync only at log intervals - if global_step % args.log_interval == 0: - active_names = ", ".join(s.name for s in active_slots) - print(f"step {global_step:>7d} | {games_per_sec:.0f} g/s | {step_time:.2f}s | active: {active_names}", flush=True) - for slot in active_slots: - m = all_metrics[slot.name] - loss_val = m['loss'].item() - acc_val = m['accuracy'].item() - gn = all_grad_norms[slot.name] - lr = slot.scheduler.get_lr() - print(f" {slot.name:>5s}: loss {loss_val:.4f} | acc {acc_val:.3f} | " - f"lr {lr:.2e} | gn {gn:.2f}", flush=True) - - slot.logger.log_train( - step=global_step, - lr=lr, grad_norm=gn, - step_time=step_time, games_per_sec=games_per_sec, - **{"train/loss": loss_val, "train/accuracy": acc_val}, - ) - - # Eval - if global_step % args.eval_interval == 0: - for slot in active_slots: - val_metrics = slot.evaluate(val_data) - print(f" {slot.name:>5s} val: loss {val_metrics['val/loss']:.4f} | " - f"acc {val_metrics['val/accuracy']:.3f}", flush=True) - # Track best for eval, /dev/shm cleanup, and patience - vl = val_metrics["val/loss"] - if vl < slot.best_val_loss: - slot.best_val_loss = vl - slot.best_val_step = global_step - slot.patience_counter = 0 - else: - slot.patience_counter += 1 - - slot.logger.log_val( - step=global_step, - patience=slot.patience_counter, - best_val_loss=slot.best_val_loss, - best_val_step=slot.best_val_step, - **val_metrics, - ) - - # Per-model early stopping - if args.patience > 0 and slot.patience_counter >= args.patience: - print(f" [{slot.name}] Early stopping — no improvement " - f"for {args.patience} evals (best step {slot.best_val_step})") - slot.stopped = True - slot.save_checkpoint() - - active_slots = [s for s in active_slots if not s.stopped] - - # Push metrics to HF after eval (lightweight, background) - for slot in slots: - slot.push_metrics_to_hf() - - if not active_slots: - print(f"\nAll models stopped at step {global_step}") - break - - # Checkpoint - if global_step % args.checkpoint_interval == 0: - for slot in active_slots: - slot.save_checkpoint() - - # Done? - if global_step >= args.total_steps: - print(f"\nTraining complete at step {global_step}") - for slot in active_slots: - slot.save_checkpoint() - break - - # Graceful shutdown - if _shutdown_requested: - print(f"\nShutdown requested (signal {_shutdown_signal}), " - f"saving checkpoints at step {global_step}...") - for slot in active_slots: - slot.save_checkpoint() - break - - step_start = time.time() - - # Cleanup - for slot in slots: - slot.close() - - # Post-training evals + # Post-training evals (CLI-only feature, not available through the lab) if args.run_evals: print("\n" + "=" * 60) print("POST-TRAINING EVALUATION") print("=" * 60) _run_post_training_evals(slots, args) - print("\nAll done.") - if __name__ == "__main__": try: diff --git a/tests/core/test_run_config.py b/tests/core/test_run_config.py index f240d50..d277c7e 100644 --- a/tests/core/test_run_config.py +++ b/tests/core/test_run_config.py @@ -13,6 +13,8 @@ from pawn.run_config import ( AdapterConfig, BaseRunConfig, + CotrainConfig, + CotrainVariant, PretrainConfig, RunConfig, ) @@ -205,6 +207,155 @@ def test_defaults(self): assert cfg.val_every == 1 +# --------------------------------------------------------------------------- +# CotrainConfig +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCotrainConfig: + def test_valid_three_variants(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[ + CotrainVariant(name="small", variant="small"), + CotrainVariant(name="base", variant="base"), + CotrainVariant(name="large", variant="large"), + ], + ) + assert cfg.run_type == "cotrain" + assert len(cfg.variants) == 3 + + def test_single_variant(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[CotrainVariant(name="only", variant="toy")], + ) + assert len(cfg.variants) == 1 + + def test_custom_architecture_overrides(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[ + CotrainVariant( + name="custom", variant="base", + d_model=384, n_layers=6, n_heads=6, d_ff=1536, + ), + ], + ) + v = cfg.variants[0] + assert v.d_model == 384 + assert v.n_layers == 6 + + def test_empty_variants_rejected(self): + with pytest.raises(ValidationError) as exc_info: + CotrainConfig(local_checkpoints=True, variants=[]) + assert "at least one" in str(exc_info.value).lower() + + def test_duplicate_names_rejected(self): + with pytest.raises(ValidationError) as exc_info: + CotrainConfig( + local_checkpoints=True, + variants=[ + CotrainVariant(name="dup", variant="small"), + CotrainVariant(name="dup", variant="base"), + ], + ) + assert "unique" in str(exc_info.value).lower() + + def test_shm_without_hf_rejected(self): + with pytest.raises(ValidationError) as exc_info: + CotrainConfig( + local_checkpoints=True, + shm_checkpoints=True, + variants=[CotrainVariant(name="x", variant="toy")], + ) + assert "hf-repo" in str(exc_info.value).lower() or "hf_repo" in str(exc_info.value).lower() + + def test_shm_with_hf_accepted(self): + cfg = CotrainConfig( + hf_repo="user/repo", + shm_checkpoints=True, + variants=[CotrainVariant(name="x", variant="toy")], + ) + assert cfg.shm_checkpoints is True + + def test_top_level_resume_rejected(self): + with pytest.raises(ValidationError) as exc_info: + CotrainConfig( + local_checkpoints=True, + resume="/some/path", + variants=[CotrainVariant(name="x", variant="toy")], + ) + assert "per" in str(exc_info.value).lower() or "variant" in str(exc_info.value).lower() + + def test_default_val_games(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[CotrainVariant(name="x", variant="toy")], + ) + assert cfg.val_games == 512 + + def test_default_checkpoint_interval(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[CotrainVariant(name="x", variant="toy")], + ) + assert cfg.checkpoint_interval == 5000 + + def test_variant_resume_path(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[ + CotrainVariant(name="a", variant="toy", resume="/tmp/ckpt_a"), + CotrainVariant(name="b", variant="toy", resume="/tmp/ckpt_b"), + ], + ) + assert cfg.variants[0].resume == "/tmp/ckpt_a" + assert cfg.variants[1].resume == "/tmp/ckpt_b" + + def test_serialization_roundtrip(self): + cfg = CotrainConfig( + local_checkpoints=True, + total_steps=1000, + batch_size=64, + variants=[ + CotrainVariant(name="small", variant="small"), + CotrainVariant(name="base", variant="base", d_model=384), + ], + ) + data = cfg.model_dump() + cfg2 = CotrainConfig(**data) + assert cfg == cfg2 + + def test_json_schema_generates(self): + schema = CotrainConfig.model_json_schema() + assert isinstance(schema, dict) + assert "properties" in schema + assert "variants" in schema["properties"] + assert "run_type" in schema["properties"] + + +@pytest.mark.unit +class TestCotrainVariant: + def test_defaults(self): + v = CotrainVariant(name="test") + assert v.variant == "base" + assert v.d_model is None + assert v.max_seq_len == 512 + assert v.legacy_vocab is False + assert v.resume is None + + def test_all_variant_presets(self): + for preset in ["toy", "small", "base", "large"]: + v = CotrainVariant(name=preset, variant=preset) + assert v.variant == preset + + def test_invalid_variant_rejected(self): + with pytest.raises(ValidationError): + CotrainVariant(name="x", variant="enormous") + + # --------------------------------------------------------------------------- # Base fields shared across configs # --------------------------------------------------------------------------- @@ -282,6 +433,17 @@ def test_union_dispatches_adapter(self): assert cfg.run_type == "adapter" assert cfg.strategy == "lora" + def test_union_dispatches_cotrain(self): + adapter = TypeAdapter(RunConfig) + cfg = adapter.validate_python({ + "run_type": "cotrain", + "local_checkpoints": True, + "variants": [{"name": "s", "variant": "small"}], + }) + assert isinstance(cfg, CotrainConfig) + assert cfg.run_type == "cotrain" + assert len(cfg.variants) == 1 + def test_union_missing_run_type(self): adapter = TypeAdapter(RunConfig) with pytest.raises(ValidationError): @@ -332,6 +494,29 @@ def test_pretrain_json_string_roundtrip(self): cfg2 = PretrainConfig(**d) assert cfg == cfg2 + def test_cotrain_json_roundtrip(self): + cfg = CotrainConfig( + local_checkpoints=True, + total_steps=500, + variants=[ + CotrainVariant(name="a", variant="toy"), + CotrainVariant(name="b", variant="small", d_model=128), + ], + ) + data = cfg.model_dump() + cfg2 = CotrainConfig(**data) + assert cfg == cfg2 + + def test_cotrain_json_string_roundtrip(self): + cfg = CotrainConfig( + local_checkpoints=True, + variants=[CotrainVariant(name="x", variant="toy")], + ) + s = cfg.model_dump_json() + d = json.loads(s) + cfg2 = CotrainConfig(**d) + assert cfg == cfg2 + def test_pretrain_json_schema_generates(self): schema = PretrainConfig.model_json_schema() assert isinstance(schema, dict) diff --git a/tests/lab/test_monitor.py b/tests/lab/test_monitor.py index 81bfa30..57665a0 100644 --- a/tests/lab/test_monitor.py +++ b/tests/lab/test_monitor.py @@ -250,6 +250,156 @@ def test_step_time_zero_uses_elapsed(self, tmp_path): assert trial.steps_per_sec == pytest.approx(2.0) +# ===================================================================== +# read_metrics — cotrain +# ===================================================================== + + +class TestReadMetricsCotrain: + def _make_cotrain_trial(self, trial_id: int = 0) -> Trial: + return Trial( + trial_id=trial_id, + strategy="cotrain:small+base", + params={}, + cli_command=[], + config={"run_type": "cotrain"}, + ) + + def test_discovers_multiple_variant_dirs(self, tmp_path): + """Cotrain read_metrics discovers per-variant JSONL files.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + # Create two variant run dirs with metrics + for name in ("small", "base"): + metrics_path = trial_dir / f"run_20260410_120000_{name}_calm-crane" / "metrics.jsonl" + _write_metrics_file(metrics_path, [ + {"type": "config", "total_steps": 100, "param_count": 1000}, + {"type": "train", "step": 10, "train/loss": 2.5, "step_time": 0.1}, + {"type": "val", "step": 10, "val/loss": 2.0, "val/accuracy": 0.4}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.run_dir == str(trial_dir) + assert trial.variants is not None + assert "small" in trial.variants + assert "base" in trial.variants + assert trial.variants["small"]["current_step"] == 10 + assert trial.variants["base"]["current_step"] == 10 + + def test_aggregates_current_step_as_min(self, tmp_path): + """Trial.current_step is min across variants.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + small_path = trial_dir / "run_20260410_120000_small_calm-crane" / "metrics.jsonl" + _write_metrics_file(small_path, [ + {"type": "train", "step": 50, "train/loss": 2.0, "step_time": 0.1}, + ]) + base_path = trial_dir / "run_20260410_120000_base_calm-crane" / "metrics.jsonl" + _write_metrics_file(base_path, [ + {"type": "train", "step": 30, "train/loss": 2.5, "step_time": 0.1}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.current_step == 30 # min of 50, 30 + + def test_aggregates_best_val_loss_as_min(self, tmp_path): + """Trial.best_val_loss is min across variants.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + small_path = trial_dir / "run_20260410_120000_small_calm-crane" / "metrics.jsonl" + _write_metrics_file(small_path, [ + {"type": "val", "step": 10, "val/loss": 3.0}, + ]) + base_path = trial_dir / "run_20260410_120000_base_calm-crane" / "metrics.jsonl" + _write_metrics_file(base_path, [ + {"type": "val", "step": 10, "val/loss": 2.0}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.best_val_loss == pytest.approx(2.0) + + def test_per_variant_offsets_incremental(self, tmp_path): + """Cotrain uses (trial_id, variant_name) offset keys for incremental reads.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + small_path = trial_dir / "run_20260410_120000_small_calm-crane" / "metrics.jsonl" + _write_metrics_file(small_path, [ + {"type": "train", "step": 10, "train/loss": 2.5, "step_time": 0.1}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.variants["small"]["current_step"] == 10 + + # Append more data + with open(small_path, "a") as f: + f.write(json.dumps({"type": "train", "step": 20, "train/loss": 2.0, "step_time": 0.1}) + "\n") + + read_metrics(trial, log_dir, offsets) + assert trial.variants["small"]["current_step"] == 20 + + def test_underscore_in_variant_name(self, tmp_path): + """Variant names containing underscores are parsed correctly.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + metrics_path = trial_dir / "run_20260410_120000_my_model_calm-crane" / "metrics.jsonl" + _write_metrics_file(metrics_path, [ + {"type": "train", "step": 5, "train/loss": 3.0, "step_time": 0.2}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.variants is not None + assert "my_model" in trial.variants + assert trial.variants["my_model"]["current_step"] == 5 + + def test_empty_trial_dir_noops(self, tmp_path): + log_dir = tmp_path / "logs" + (log_dir / "trial_0000").mkdir(parents=True) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.variants is None + assert trial.run_dir is None + + def test_variant_run_dir_tracked(self, tmp_path): + """Each variant's run_dir is stored independently.""" + log_dir = tmp_path / "logs" + trial_dir = log_dir / "trial_0000" + + small_dir = trial_dir / "run_20260410_120000_small_calm-crane" + metrics_path = small_dir / "metrics.jsonl" + _write_metrics_file(metrics_path, [ + {"type": "train", "step": 5, "train/loss": 3.0, "step_time": 0.2}, + ]) + + trial = self._make_cotrain_trial(0) + offsets: dict = {} + read_metrics(trial, log_dir, offsets) + + assert trial.variants["small"]["run_dir"] == str(small_dir) + + # ===================================================================== # check_health # =====================================================================