Skip to content

Commit c4f3744

Browse files
Add co-training support to pawn-lab MCP server (#58)
* Add co-training support to pawn-lab MCP server lab_launch now accepts run_type="cotrain" with a list of variant specs, enabling multi-model pretraining runs (the equivalent of train_all.py) to be launched, monitored, killed, and resumed through the lab tools. - Add CotrainVariant and CotrainConfig to pawn/run_config.py with validators for unique names, non-empty variants, and shm/hf coupling - Extract ModelSlot + training loop from train_all.py into pawn/cotrain.py with resume support (per-variant checkpoint loading) and pause_after_steps - Convert scripts/train_all.py to thin CLI shim over run_cotrain() - Add cotrain dispatch branch in scripts/train.py - Update lab runner: _validate_config accepts cotrain, resume_trial discovers per-variant checkpoints and sets per-variant resume paths - Update lab monitor: multi-file metrics discovery for cotrain trials with per-variant offset tracking and aggregation to trial level - Update lab server: lab_schema exposes cotrain, updated docstrings - Add Trial.variants field for per-variant state tracking - 19 new tests (config validation, serialization, monitor aggregation) * Fix pyright errors: drop redundant epoch parameter from log_train/log_val The explicit `epoch: int | None` parameter alongside `**metrics: object` caused pyright to reject callers that spread a `dict[str, float]` (since a key named "epoch" would be float, not int). The parameter was redundant — epoch flows through **metrics like every other field. Removing it fixes all 11 pyright errors in pawn/cotrain.py at the source. 
* Address PR review feedback Bug fixes: - Pass sdpa_math/no_compile/no_amp flags to configure_gpu() in run_cotrain so --sdpa-math actually takes effect - Fix _extract_variant_name to handle underscores in variant names by joining parts[3:-1] (variant is between timestamp and slug) - Reject top-level 'resume' field on CotrainConfig with a helpful error directing users to per-variant resume fields Improvements: - Rename _find_best_checkpoint → _find_latest_checkpoint with accurate docstring (pretrain/cotrain only write step_* dirs, no best/ symlink) - Populate last_train_acc for cotrain trials in monitor aggregation - Add explicit multiprocessing_context="spawn" to DataLoader in run_cotrain to prevent rayon deadlocks independent of caller context - Fix misleading comment on _hf_push_future.result() — the closure catches all exceptions, so result() blocks but never raises Tests: - Fix test dir names to match actual MetricsLogger format (run_DATE_TIME_VARIANT_SLUG not run_DATE_TIME_SLUG_VARIANT) - Add test for underscore-containing variant names - Add test for top-level resume rejection
1 parent 9bf71c9 commit c4f3744

File tree

11 files changed

+1312
-531
lines changed

11 files changed

+1312
-531
lines changed

pawn/cotrain.py

Lines changed: 612 additions & 0 deletions
Large diffs are not rendered by default.

pawn/lab/monitor.py

Lines changed: 184 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,31 @@ def is_alive(pid: int) -> tuple[bool, int | None]:
4141
def read_metrics(
    trial: Trial,
    log_dir: Path,
    offsets: dict,
) -> None:
    """Incrementally read the trial's metrics.jsonl files, updating ``trial`` in-place.

    Cotrain trials own several per-variant metrics files; those are discovered,
    tracked per-variant in ``trial.variants``, and aggregated to trial level.
    Single-variant (pretrain/adapter) trials read one metrics.jsonl.

    ``offsets`` keys are ``int`` (trial_id) for single-variant trials, or
    ``(trial_id, variant_name)`` tuples for cotrain per-variant files.
    """
    run_type = (trial.config or {}).get("run_type")
    reader = _read_cotrain_metrics if run_type == "cotrain" else _read_single_metrics
    reader(trial, log_dir, offsets)
61+
62+
63+
def _read_single_metrics(
64+
trial: Trial,
65+
log_dir: Path,
66+
offsets: dict,
67+
) -> None:
68+
"""Read metrics for a single-variant (pretrain/adapter) trial."""
4769
# Find run dir if not yet discovered — pick the most recent
4870
if trial.run_dir is None:
4971
trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}"
@@ -116,6 +138,166 @@ def read_metrics(
116138
trial.best_accuracy = acc
117139

118140

141+
def _read_cotrain_metrics(
    trial: Trial,
    log_dir: Path,
    offsets: dict,
) -> None:
    """Read metrics for a cotrain trial (multiple per-variant JSONL files).

    Each variant's MetricsLogger creates its own run dir under the trial dir,
    e.g. ``run_20260410_151230_zesty-osprey_small/metrics.jsonl``.  New lines
    are read incrementally using ``offsets`` keyed by
    ``(trial_id, variant_name)``, folded into per-variant state dicts in
    ``trial.variants``, then aggregated up to the trial level.
    """

    def first_present(rec: dict, *keys: str):
        # Return the value of the first key that is present and not None.
        # Chaining ``rec.get(a) or rec.get(b)`` would wrongly treat a
        # legitimate 0.0 loss/accuracy as missing and fall through.
        for k in keys:
            v = rec.get(k)
            if v is not None:
                return v
        return None

    trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}"

    # Discover all per-variant metrics files under the trial dir.
    metrics_files = list(trial_log_dir.glob("*/metrics.jsonl"))
    if not metrics_files:
        return

    # trial.run_dir points at the parent trial dir, not a specific variant.
    if trial.run_dir is None:
        trial.run_dir = str(trial_log_dir)

    if trial.variants is None:
        trial.variants = {}

    for mf in metrics_files:
        # Variant name is embedded in the run dir suffix:
        # run_YYYYMMDD_HHMMSS_<variant>_<slug>/metrics.jsonl
        variant_name = _extract_variant_name(mf.parent.name)
        if variant_name is None:
            continue

        if variant_name not in trial.variants:
            trial.variants[variant_name] = {
                "name": variant_name,
                "run_dir": str(mf.parent),
                "current_step": 0,
                "last_train_loss": None,
                "last_train_acc": None,
                "best_val_loss": None,
                "best_val_step": 0,
                "best_accuracy": None,
                "actual_param_count": None,
                "stopped": False,
                "steps_per_sec": 0.0,
            }

        vs = trial.variants[variant_name]
        offset_key = (trial.trial_id, variant_name)
        offset = offsets.get(offset_key, 0)

        try:
            with open(mf) as f:
                f.seek(offset)
                new_lines = f.readlines()
                offsets[offset_key] = f.tell()
        except OSError:
            # File may be mid-rotation or briefly unreadable; retry next poll.
            continue

        for line in new_lines:
            try:
                rec = json.loads(line)
            except (json.JSONDecodeError, ValueError):
                continue

            rtype = rec.get("type")
            if rtype == "config":
                # total_steps may live at top level or nested under "training".
                ts = rec.get("total_steps") or (rec.get("training") or {}).get("total_steps")
                if ts:
                    trial.total_steps = ts
                pc = rec.get("param_count")
                if pc is not None:
                    vs["actual_param_count"] = pc

            elif rtype == "train":
                vs["current_step"] = rec.get("step", vs["current_step"])
                loss = first_present(rec, "train/loss", "train_loss")
                if loss is not None:
                    vs["last_train_loss"] = loss
                train_acc = first_present(rec, "train/accuracy", "train_top1")
                if train_acc is not None:
                    vs["last_train_acc"] = train_acc
                st = rec.get("step_time")
                if st and st > 0:
                    vs["steps_per_sec"] = 1.0 / st
                elif rec.get("elapsed") and vs["current_step"] > 0:
                    # Fallback: average rate over the whole run so far.
                    vs["steps_per_sec"] = vs["current_step"] / rec["elapsed"]

            elif rtype == "val":
                vl = first_present(rec, "val/loss", "val_loss", "loss")
                if vl is not None and (vs["best_val_loss"] is None or vl < vs["best_val_loss"]):
                    vs["best_val_loss"] = vl
                    vs["best_val_step"] = rec.get("step", vs.get("best_val_step", 0))
                acc = first_present(rec, "val/accuracy", "val_top1", "accuracy")
                if acc is not None:
                    # NOTE(review): stores the *latest* val accuracy despite the
                    # "best" name — confirm this matches the single-variant path.
                    vs["best_accuracy"] = acc

    # Fold per-variant state up into trial-level fields.
    _aggregate_cotrain_metrics(trial)
241+
242+
243+
def _extract_variant_name(run_dir_name: str) -> str | None:
244+
"""Extract variant name from a run directory name.
245+
246+
The MetricsLogger creates dirs like ``run_YYYYMMDD_HHMMSS_variantname_slug``.
247+
The layout is: ``run`` _ ``date`` _ ``time`` _ ``variant`` _ ``slug``.
248+
The variant name may itself contain underscores, but the slug (final segment)
249+
never does (it's two hyphenated words like ``calm-crane``). So we rejoin
250+
everything between parts[3] and parts[-1].
251+
"""
252+
# Expected: run_YYYYMMDD_HHMMSS_variant_slug (at least 5 parts)
253+
parts = run_dir_name.split("_")
254+
if len(parts) < 5 or parts[0] != "run":
255+
return None
256+
# parts[1]=date, parts[2]=time, parts[-1]=slug, parts[3:-1]=variant
257+
return "_".join(parts[3:-1])
258+
259+
260+
def _aggregate_cotrain_metrics(trial: Trial) -> None:
261+
"""Aggregate per-variant metrics to the trial level."""
262+
if not trial.variants:
263+
return
264+
265+
variants = list(trial.variants.values())
266+
267+
# current_step = min across variants (honest ETA — slowest determines progress)
268+
steps = [v["current_step"] for v in variants if v["current_step"] > 0]
269+
if steps:
270+
trial.current_step = min(steps)
271+
272+
# best_val_loss = min across variants
273+
val_losses = [v["best_val_loss"] for v in variants if v["best_val_loss"] is not None]
274+
if val_losses:
275+
trial.best_val_loss = min(val_losses)
276+
277+
# best_accuracy = max across variants
278+
accs = [v["best_accuracy"] for v in variants if v["best_accuracy"] is not None]
279+
if accs:
280+
trial.best_accuracy = max(accs)
281+
282+
# last_train_loss = mean across active variants
283+
losses = [v["last_train_loss"] for v in variants
284+
if v["last_train_loss"] is not None and not v.get("stopped")]
285+
if losses:
286+
trial.last_train_loss = sum(losses) / len(losses)
287+
288+
# last_train_acc = mean across active variants
289+
accs_train = [v["last_train_acc"] for v in variants
290+
if v.get("last_train_acc") is not None and not v.get("stopped")]
291+
if accs_train:
292+
trial.last_train_acc = sum(accs_train) / len(accs_train)
293+
294+
# steps_per_sec from any variant (they share the same step timing)
295+
for v in variants:
296+
if v.get("steps_per_sec", 0) > 0:
297+
trial.steps_per_sec = v["steps_per_sec"]
298+
break
299+
300+
119301
def read_pretrain_val_summary(trial: Trial) -> dict[str, Any] | None:
120302
"""Scan the trial's metrics.jsonl for the latest pretraining val record
121303
and compute a log-linear fit on forfeit rate over the most recent half

pawn/lab/runner.py

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,19 @@ def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
3535
"""
3636
from pydantic import TypeAdapter
3737

38-
from pawn.run_config import AdapterConfig, PretrainConfig
38+
from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig
3939

4040
run_type = config.get("run_type")
41-
if run_type not in ("pretrain", "adapter"):
42-
raise ValueError(f"run_type must be 'pretrain' or 'adapter', got {run_type!r}")
43-
ta = TypeAdapter(
44-
PretrainConfig if run_type == "pretrain" else AdapterConfig
45-
)
41+
config_cls = {
42+
"pretrain": PretrainConfig,
43+
"adapter": AdapterConfig,
44+
"cotrain": CotrainConfig,
45+
}.get(run_type) # type: ignore[arg-type]
46+
if config_cls is None:
47+
raise ValueError(
48+
f"run_type must be 'pretrain', 'adapter', or 'cotrain', got {run_type!r}"
49+
)
50+
ta = TypeAdapter(config_cls)
4651
return ta.validate_python(config).model_dump()
4752

4853

@@ -266,7 +271,11 @@ async def launch(
266271
validated = _validate_config(config)
267272
cmd = self._build_command(validated, trial_id)
268273

269-
strategy_display = validated.get("strategy") or validated.get("variant", "pretrain")
274+
if validated.get("run_type") == "cotrain":
275+
variant_names = [v["name"] for v in validated.get("variants", [])]
276+
strategy_display = "cotrain:" + "+".join(variant_names)
277+
else:
278+
strategy_display = validated.get("strategy") or validated.get("variant", "pretrain")
270279
trial = Trial(
271280
trial_id=trial_id,
272281
strategy=strategy_display,
@@ -296,34 +305,83 @@ async def resume_trial(
296305
total_steps: int | None = None,
297306
pause_after_steps: int | None = None,
298307
) -> int:
299-
"""Resume a completed/failed trial from its best checkpoint."""
308+
"""Resume a completed/failed trial from its best checkpoint.
309+
310+
For cotrain trials, discovers per-variant checkpoints and sets
311+
the resume path on each variant in the new config.
312+
"""
300313
old = self.trials.get(trial_id)
301314
if not old:
302315
raise RuntimeError(f"Trial {trial_id} not found")
303316
if not old.run_dir:
304317
raise RuntimeError(f"Trial {trial_id} has no run directory")
305318

306-
ckpt_base = Path(old.run_dir) / "checkpoints"
319+
new_config = dict(old.config)
320+
new_config.pop("pause_after_steps", None)
321+
322+
if (old.config or {}).get("run_type") == "cotrain":
323+
self._resolve_cotrain_resume(old, new_config)
324+
else:
325+
ckpt_dir = self._find_latest_checkpoint(Path(old.run_dir))
326+
new_config["resume"] = str(ckpt_dir)
327+
328+
if total_steps is not None:
329+
new_config["total_steps"] = total_steps
330+
if pause_after_steps is not None:
331+
new_config["pause_after_steps"] = pause_after_steps
332+
333+
return await self.launch(new_config, tags=old.tags)
334+
335+
@staticmethod
336+
def _find_latest_checkpoint(run_dir: Path) -> Path:
337+
"""Find the latest checkpoint under a run directory.
338+
339+
Checks for ``best/`` and ``final/`` symlinks first (adapter runs),
340+
then falls back to the highest-numbered ``step_*`` directory
341+
(pretrain/cotrain runs, which don't create best/final symlinks).
342+
"""
343+
ckpt_base = run_dir / "checkpoints"
307344
ckpt_dir = ckpt_base / "best"
308345
if not ckpt_dir.exists():
309346
ckpt_dir = ckpt_base / "final"
310347
if not ckpt_dir.exists():
311-
# Pretraining uses step_XXXXXXXX naming — pick the highest step
312348
step_dirs = sorted(ckpt_base.glob("step_*"))
313349
if step_dirs:
314350
ckpt_dir = step_dirs[-1]
315351
if not ckpt_dir.exists():
316-
raise RuntimeError(f"No checkpoint found for trial {trial_id}")
352+
raise RuntimeError(f"No checkpoint found under {run_dir}")
353+
return ckpt_dir
317354

318-
new_config = dict(old.config)
319-
new_config.pop("pause_after_steps", None)
320-
new_config["resume"] = str(ckpt_dir)
321-
if total_steps is not None:
322-
new_config["total_steps"] = total_steps
323-
if pause_after_steps is not None:
324-
new_config["pause_after_steps"] = pause_after_steps
355+
def _resolve_cotrain_resume(
356+
self, old: "Trial", new_config: dict[str, Any],
357+
) -> None:
358+
"""Set per-variant resume paths for a cotrain trial."""
359+
if not old.variants:
360+
raise RuntimeError(
361+
f"Trial {old.trial_id} is cotrain but has no variant state. "
362+
"Cannot determine per-variant checkpoints."
363+
)
325364

326-
return await self.launch(new_config, tags=old.tags)
365+
# Deep-copy variants list so we can mutate
366+
import copy
367+
variants = copy.deepcopy(new_config.get("variants", []))
368+
369+
for v_cfg in variants:
370+
name = v_cfg.get("name")
371+
if name not in old.variants:
372+
raise RuntimeError(
373+
f"Variant '{name}' not found in trial {old.trial_id} state"
374+
)
375+
vs = old.variants[name]
376+
v_run_dir = vs.get("run_dir")
377+
if not v_run_dir:
378+
raise RuntimeError(
379+
f"Variant '{name}' in trial {old.trial_id} has no run directory"
380+
)
381+
ckpt_dir = self._find_latest_checkpoint(Path(v_run_dir))
382+
v_cfg["resume"] = str(ckpt_dir)
383+
384+
new_config["variants"] = variants
327385

328386
def _build_command(
329387
self, config: dict[str, Any], trial_id: int,

pawn/lab/server.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def lab_status(ctx: Context) -> dict[str, Any]:
4444

4545
@mcp.tool
4646
async def lab_launch(config: dict[str, Any], ctx: Context, tags: list[str] | None = None) -> dict[str, Any]:
47-
"""Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain' or 'adapter'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"])."""
47+
"""Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain', 'adapter', or 'cotrain'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"])."""
4848
try:
4949
tid = await _runner(ctx).launch(config, tags=tags)
5050
return _runner(ctx).trials[tid].to_dict()
@@ -105,10 +105,11 @@ async def lab_set_cost(cost_per_hour: float, ctx: Context) -> dict[str, Any]:
105105

106106
@mcp.tool
107107
async def lab_schema(ctx: Context) -> dict[str, Any]:
108-
"""Return the JSON Schema for RunConfig (PretrainConfig and AdapterConfig). Use this to discover all available parameters before calling lab_launch."""
109-
from pawn.run_config import AdapterConfig, PretrainConfig
108+
"""Return the JSON Schema for RunConfig (PretrainConfig, AdapterConfig, CotrainConfig). Use this to discover all available parameters before calling lab_launch."""
109+
from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig
110110

111111
return {
112112
"pretrain": PretrainConfig.model_json_schema(),
113113
"adapter": AdapterConfig.model_json_schema(),
114+
"cotrain": CotrainConfig.model_json_schema(),
114115
}

pawn/lab/state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class Trial:
5252
# Agent annotations
5353
notes: str = ""
5454
tags: list[str] = field(default_factory=list)
55+
# Co-training: per-variant state (None for non-cotrain trials)
56+
variants: dict[str, dict[str, Any]] | None = None
5557

5658
def eta_seconds(self) -> float | None:
5759
if self.steps_per_sec > 0 and self.total_steps > self.current_step:

0 commit comments

Comments
 (0)