
Add co-training support to pawn-lab MCP server#58

Merged
thomas-schweich merged 3 commits into main from cotrain-lab-support
Apr 12, 2026

Conversation

@thomas-schweich thomas-schweich commented Apr 12, 2026

Summary

  • Adds run_type="cotrain" to the lab MCP server, enabling multi-model pretraining runs (the train_all.py workflow) to be launched, monitored, killed, and resumed through lab_launch/lab_status/lab_kill/lab_resume
  • Variants are fully generic CotrainVariant specs (arbitrary architecture overrides, not just small/base/large), with per-variant resume paths for seamless lab_resume
  • Extracts ModelSlot + training loop from scripts/train_all.py into pawn/cotrain.py so both the lab and CLI share one implementation; train_all.py becomes a thin argparse shim

Key changes

File                  What
pawn/run_config.py    CotrainVariant, CotrainConfig, updated RunConfig union
pawn/cotrain.py       New: ModelSlot + run_cotrain() with resume & pause support
scripts/train_all.py  Thin CLI shim: CotrainConfig → run_cotrain()
scripts/train.py      cotrain dispatch branch
pawn/lab/runner.py    _validate_config + resume_trial cotrain support
pawn/lab/monitor.py   Multi-file metrics discovery, per-variant offsets, aggregation
pawn/lab/state.py     Trial.variants field
pawn/lab/server.py    lab_schema + docstring updates

Test plan

  • 19 new unit tests for CotrainConfig validation, serialization, discriminated union dispatch, and cotrain metrics monitoring
  • Full test suite passes (1374/1374)
  • Pyright clean on all modified files (pre-existing issues in cotrain.py inherited from train_all.py)
  • Smoke test: lab_launch({"run_type": "cotrain", "variants": [{"name": "a", "variant": "toy"}, {"name": "b", "variant": "toy"}], "total_steps": 50, "batch_size": 16, "local_checkpoints": true}) on a GPU pod
  • Resume smoke test: lab_resume(trial_id, total_steps=100) after completion
  • CLI backward compat: scripts/train_all.py --local-checkpoints --total-steps 50

lab_launch now accepts run_type="cotrain" with a list of variant specs,
enabling multi-model pretraining runs (the equivalent of train_all.py)
to be launched, monitored, killed, and resumed through the lab tools.

- Add CotrainVariant and CotrainConfig to pawn/run_config.py with
  validators for unique names, non-empty variants, and shm/hf coupling
- Extract ModelSlot + training loop from train_all.py into pawn/cotrain.py
  with resume support (per-variant checkpoint loading) and pause_after_steps
- Convert scripts/train_all.py to thin CLI shim over run_cotrain()
- Add cotrain dispatch branch in scripts/train.py
- Update lab runner: _validate_config accepts cotrain, resume_trial
  discovers per-variant checkpoints and sets per-variant resume paths
- Update lab monitor: multi-file metrics discovery for cotrain trials
  with per-variant offset tracking and aggregation to trial level
- Update lab server: lab_schema exposes cotrain, updated docstrings
- Add Trial.variants field for per-variant state tracking
- 19 new tests (config validation, serialization, monitor aggregation)
@thomas-schweich

@claude
claude bot commented Apr 12, 2026

Claude finished @thomas-schweich's task in 3m 56s


PR #58 Code Review

Overall this is a clean, well-structured extraction. The ModelSlot abstraction is good, the discriminated union / Pydantic validation layer is solid, and the incremental offset-based metrics reading for cotrain is the right approach. Test coverage is thorough. A few issues worth addressing:


Bug: configure_gpu() ignores user-provided flags in run_cotrain

pawn/cotrain.py:373-377

gpu_cfg = configure_gpu()
import pawn.model as model_module
if gpu_cfg.get("sdpa_backend"):
    model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"]

configure_gpu() is called with no arguments, so config.sdpa_math, config.no_compile, and config.no_amp are all silently ignored. The no_compile flag is handled separately later (lines 445-451), but sdpa_math is never forwarded. On ROCm the auto-detection in configure_gpu() should still select MATH SDPA, but an explicit --sdpa-math override has no effect.

Compare with the adapter path in scripts/train.py:296-302 which correctly passes all three flags. Also, configure_gpu returns use_amp/amp_dtype that are ignored here — train_cfg.use_amp is set independently from config.amp_dtype in _build_variant_configs:346, which is correct, but the import/apply pattern is inconsistent.



Bug: _extract_variant_name breaks on underscores in variant names

pawn/lab/monitor.py:239-251

def _extract_variant_name(run_dir_name: str) -> str | None:
    parts = run_dir_name.split("_")
    if len(parts) < 5 or parts[0] != "run":
        return None
    return parts[-1]   # assumes variant name has no underscores

CotrainVariant.name is an unrestricted str. A name like "my_model" produces a dir run_20260410_151230_slug_my_model, and parts[-1] returns "model" instead of "my_model". The variant state would be keyed under "model" while _resolve_cotrain_resume would look for "my_model", failing to find it and raising on resume.

Fix options: (a) restrict CotrainVariant.name to match ^[a-zA-Z0-9-]+$ (no underscores) with a Field(pattern=...), or (b) extract the name by removing the known-length prefix run_YYYYMMDD_HHMMSS_slug_ (i.e., skip the first 4 underscore-joined tokens).



Misleading: _find_best_checkpoint never actually finds a "best" checkpoint for cotrain/pretrain

pawn/lab/runner.py:335-348

ckpt_dir = ckpt_base / "best"
if not ckpt_dir.exists():
    ckpt_dir = ckpt_base / "final"
if not ckpt_dir.exists():
    step_dirs = sorted(ckpt_base.glob("step_*"))
    if step_dirs:
        ckpt_dir = step_dirs[-1]

ModelSlot.save_checkpoint only ever writes step_{n:08d}/ directories — there is no best/ or final/ symlink created for pretrain or cotrain. This method always falls through to step_dirs[-1] (the lexicographically last, which is the highest-numbered, checkpoint). The name and docstring imply val-loss-optimal selection, but it's actually "latest checkpoint". This matters when there are trailing checkpoints after the best val epoch (e.g., the model was saved at steps 5000, 10000 …, best val was at 7500, but there's only a step_10000 checkpoint). This is a pre-existing design for single-model pretrain runs, but now it's also the cotrain resume path.

The function is new in this PR; it would be clearer to rename it _find_latest_checkpoint and document the behavior.


Minor: last_train_acc never populated for cotrain trials

pawn/lab/monitor.py:213-222 (_read_cotrain_metrics)

The cotrain train-record handler reads train/loss and step_time but not train/accuracy. The _aggregate_cotrain_metrics function also has no last_train_acc aggregation. As a result, trial.last_train_acc is always None for cotrain trials, which shows up as train_acc: None in lab_status. Not a correctness issue, but adds noise to the monitoring output.


Minor: DataLoader in run_cotrain doesn't set multiprocessing_context='spawn'

pawn/cotrain.py:453-460

The CLAUDE.md explicitly documents that DataLoader workers must use multiprocessing_context='spawn' to prevent rayon deadlocks. The run_cotrain DataLoader omits this, relying on mp.set_start_method("spawn") being called by the entry-point scripts. This works when called via scripts/train.py (which calls set_start_method("spawn", force=True)), but is fragile if run_cotrain is ever called in a context that doesn't pre-configure the start method. Should pass multiprocessing_context="spawn" explicitly.


Nit: Misleading comment on _hf_push_future.result()

pawn/cotrain.py:181-182

if self._hf_push_future is not None:
    self._hf_push_future.result()  # raises if previous push failed

The _push closure catches all exceptions and logs them as warnings — it never lets an exception propagate out of the future. So result() will never actually raise from a push failure. The comment is incorrect.


Nit: Top-level resume field in CotrainConfig is silently ignored

CotrainConfig inherits resume: str | None = None from BaseRunConfig, but run_cotrain reads per-variant variant_spec.resume only. A user passing --resume /path/to/ckpt to train.py --run-type cotrain (or setting resume in the JSON config) gets no error and no effect. A validator in CotrainConfig that raises on self.resume is not None with a helpful message ("use per-variant resume fields instead") would prevent silent misconfiguration.


Positive highlights

  • CotrainConfig._check_cotrain validator — catching empty variants, duplicate names, and the shm/hf mutual requirement all in one place is clean.
  • Per-variant offset keys (trial_id, variant_name) in the monitor — the incremental read design is exactly right for multiple concurrent JSONL files.
  • active_slots early stopping — per-model early stopping with the active slots list being pruned is elegant.
  • _build_variant_configs architecture override — the coerce-from-preset then apply-overrides pattern matches run_pretrain() closely, making the two paths consistent.
  • The test suite is comprehensive and well-structured; the _extract_variant_name test cases implicitly test the happy path, but there's no test for underscore-containing variant names.

…_val

The explicit `epoch: int | None` parameter alongside `**metrics: object`
caused pyright to reject callers that spread a `dict[str, float]` (since
a key named "epoch" would be float, not int). The parameter was redundant
— epoch flows through **metrics like every other field. Removing it fixes
all 11 pyright errors in pawn/cotrain.py at the source.
Bug fixes:
- Pass sdpa_math/no_compile/no_amp flags to configure_gpu() in
  run_cotrain so --sdpa-math actually takes effect
- Fix _extract_variant_name to handle underscores in variant names by
  joining parts[3:-1] (variant is between timestamp and slug)
- Reject top-level 'resume' field on CotrainConfig with a helpful error
  directing users to per-variant resume fields

Improvements:
- Rename _find_best_checkpoint → _find_latest_checkpoint with accurate
  docstring (pretrain/cotrain only write step_* dirs, no best/ symlink)
- Populate last_train_acc for cotrain trials in monitor aggregation
- Add explicit multiprocessing_context="spawn" to DataLoader in
  run_cotrain to prevent rayon deadlocks independent of caller context
- Fix misleading comment on _hf_push_future.result() — the closure
  catches all exceptions, so result() blocks but never raises

Tests:
- Fix test dir names to match actual MetricsLogger format
  (run_DATE_TIME_VARIANT_SLUG not run_DATE_TIME_SLUG_VARIANT)
- Add test for underscore-containing variant names
- Add test for top-level resume rejection
@thomas-schweich

@claude please follow up review


claude bot commented Apr 12, 2026

Claude finished @thomas-schweich's task in 3m 34s


PR #58 Follow-up Review

All 7 issues from the original review have been addressed in commit d756dce. Here's a verification of each fix:


Bug: configure_gpu() ignoring sdpa_math/no_amp flags — Fixed ✓

pawn/cotrain.py:379-384 now correctly passes all three flags:

gpu_cfg = configure_gpu(
    device,
    no_compile=config.no_compile,
    no_amp=(config.amp_dtype == "none"),
    sdpa_math=config.sdpa_math,
)

Bug: _extract_variant_name breaking on underscores — Fixed ✓

pawn/lab/monitor.py:252-257 now uses "_".join(parts[3:-1]) which correctly handles any number of underscores in a variant name. The fix exploits the invariant that random_slug() produces hyphenated slugs (e.g. calm-crane) with no underscores, so parts[-1] is always the slug and parts[3:-1] are always the variant name parts. A new test test_underscore_in_variant_name covers this case explicitly.


Misleading: _find_best_checkpoint never found a "best" checkpoint — Fixed ✓

pawn/lab/runner.py:336 is now _find_latest_checkpoint with a docstring that accurately describes the "latest step_* directory" fallback for pretrain/cotrain runs.


Minor: last_train_acc never populated for cotrain — Fixed ✓

pawn/lab/monitor.py:220-222 now reads train/accuracy in the per-variant train handler, and pawn/lab/monitor.py:289-292 aggregates it as a mean across active variants — consistent with how last_train_loss is handled.


Minor: DataLoader missing multiprocessing_context='spawn' — Fixed ✓

pawn/cotrain.py:470:

multiprocessing_context="spawn" if config.num_workers > 0 else None,

Correctly conditions on num_workers > 0 (the None path avoids a warning when workers=0).


Nit: Misleading comment on _hf_push_future.result() — Fixed ✓

pawn/cotrain.py:182 comment now reads # blocks until previous push completes, which accurately describes the behavior.


Nit: Top-level resume in CotrainConfig silently ignored — Fixed ✓

pawn/run_config.py:184-188 raises a ValueError with a clear message directing users to per-variant resume fields.


One remaining observation

The _extract_variant_name docstring at pawn/lab/monitor.py:244-250 documents the dir format as run_YYYYMMDD_HHMMSS_variantname_slug, but cross-checking with MetricsLogger.__init__ (pawn/logging.py:159-164):

ts = datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g. "20260410_151230"
parts = [run_prefix, ts]
if suffix:
    parts.append(suffix)
parts.append(self.slug)
dir_name = "_".join(parts)

ts itself contains an underscore, so the actual split is ["run", "YYYYMMDD", "HHMMSS", *name_parts, "slug"] — 5+ parts with date and time as separate elements. The code (parts[3:-1]) is correct, but the docstring should say parts[1]=YYYYMMDD, parts[2]=HHMMSS rather than treating them as a single "date_time" token. Minor docstring imprecision, but the logic is sound and the test covers it.
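A quick standalone check makes that token layout concrete (the variant and slug values here are illustrative examples):

```python
from datetime import datetime

# The strftime pattern from MetricsLogger produces a timestamp that
# itself contains an underscore, so splitting the dir name on "_"
# yields the date and time as separate tokens.
ts = datetime(2026, 4, 10, 15, 12, 30).strftime("%Y%m%d_%H%M%S")
dir_name = "_".join(["run", ts, "my_model", "calm-crane"])
parts = dir_name.split("_")
# parts[1] is the date, parts[2] the time, parts[-1] the slug,
# and parts[3:-1] re-joins to the variant name.
variant = "_".join(parts[3:-1])
```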


Summary

All 7 review items are fixed correctly. The _extract_variant_name fix is the most consequential — it's solid, is backed by the test_underscore_in_variant_name test, and correctly uses the slug-as-anchor invariant. No new issues identified. The PR is in good shape for merging after the smoke tests listed in the test plan pass.

@thomas-schweich thomas-schweich merged commit c4f3744 into main Apr 12, 2026
1 check passed
@thomas-schweich thomas-schweich deleted the cotrain-lab-support branch April 12, 2026 03:46