Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
612 changes: 612 additions & 0 deletions pawn/cotrain.py

Large diffs are not rendered by default.

186 changes: 184 additions & 2 deletions pawn/lab/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,31 @@ def is_alive(pid: int) -> tuple[bool, int | None]:
def read_metrics(
    trial: Trial,
    log_dir: Path,
    offsets: dict[int | tuple[int, str], int],
) -> None:
    """Read new lines from the trial's metrics.jsonl, updating trial in-place.

    For cotrain trials, discovers multiple per-variant metrics files and
    aggregates them to the trial level while tracking per-variant state in
    ``trial.variants``.

    ``offsets`` keys are ``int`` (trial_id) for single-variant trials, or
    ``(trial_id, variant_name)`` tuples for cotrain per-variant files; values
    are byte offsets into each metrics file, updated in place as new lines
    are consumed.
    """
    # Dispatch on run_type; trial.config may be None for legacy trials,
    # which fall through to the single-variant reader.
    is_cotrain = (trial.config or {}).get("run_type") == "cotrain"

    if is_cotrain:
        _read_cotrain_metrics(trial, log_dir, offsets)
    else:
        _read_single_metrics(trial, log_dir, offsets)


def _read_single_metrics(
trial: Trial,
log_dir: Path,
offsets: dict,
) -> None:
"""Read metrics for a single-variant (pretrain/adapter) trial."""
# Find run dir if not yet discovered — pick the most recent
if trial.run_dir is None:
trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}"
Expand Down Expand Up @@ -116,6 +138,166 @@ def read_metrics(
trial.best_accuracy = acc


def _coalesce(rec: dict, *keys: str):
    """Return the first value among ``rec[key]`` that is not None.

    Unlike an ``or`` chain, this treats 0 / 0.0 as present — losses and
    accuracies can legitimately be exactly zero.
    """
    for key in keys:
        value = rec.get(key)
        if value is not None:
            return value
    return None


def _read_cotrain_metrics(
    trial: Trial,
    log_dir: Path,
    offsets: dict,
) -> None:
    """Read metrics for a cotrain trial (multiple per-variant JSONL files).

    Discovers every ``*/metrics.jsonl`` under the trial's log directory,
    tails each file from its saved byte offset, updates per-variant state
    in ``trial.variants``, then aggregates to the trial level.

    ``offsets`` is keyed by ``(trial_id, variant_name)`` tuples here.
    """
    trial_log_dir = log_dir / f"trial_{trial.trial_id:04d}"

    # Discover all per-variant metrics files under the trial dir.
    # Each variant's MetricsLogger creates a run dir with suffix=variant_name,
    # e.g. run_20260410_151230_zesty-osprey_small/metrics.jsonl
    metrics_files = list(trial_log_dir.glob("*/metrics.jsonl"))
    if not metrics_files:
        return

    # Set trial.run_dir to the parent trial dir (not a specific variant).
    if trial.run_dir is None:
        trial.run_dir = str(trial_log_dir)

    # Initialize variants dict if needed.
    if trial.variants is None:
        trial.variants = {}

    for mf in metrics_files:
        variant_name = _extract_variant_name(mf.parent.name)
        if variant_name is None:
            continue

        # First sighting of this variant: seed its state dict.
        if variant_name not in trial.variants:
            trial.variants[variant_name] = {
                "name": variant_name,
                "run_dir": str(mf.parent),
                "current_step": 0,
                "last_train_loss": None,
                "last_train_acc": None,
                "best_val_loss": None,
                "best_val_step": 0,
                "best_accuracy": None,
                "actual_param_count": None,
                "stopped": False,
                "steps_per_sec": 0.0,
            }

        vs = trial.variants[variant_name]
        offset_key = (trial.trial_id, variant_name)
        offset = offsets.get(offset_key, 0)

        try:
            with open(mf) as f:
                f.seek(offset)
                new_lines = f.readlines()
                offsets[offset_key] = f.tell()
        except OSError:
            # File may be mid-rotation or briefly unreadable; retry next poll.
            continue

        for line in new_lines:
            try:
                rec = json.loads(line)
            except (json.JSONDecodeError, ValueError):
                # Partial/garbled line (e.g. writer mid-flush) — skip it.
                continue

            rtype = rec.get("type")
            if rtype == "config":
                # total_steps may live at top level or under "training".
                # 0/None both mean "unknown", so truthiness is fine here.
                ts = rec.get("total_steps") or (rec.get("training") or {}).get("total_steps")
                if ts:
                    trial.total_steps = ts
                pc = rec.get("param_count")
                if pc is not None:
                    vs["actual_param_count"] = pc

            elif rtype == "train":
                vs["current_step"] = rec.get("step", vs["current_step"])
                # BUGFIX: None-aware coalescing instead of `or` chains so a
                # legitimate 0.0 loss/accuracy is not silently dropped.
                loss = _coalesce(rec, "train/loss", "train_loss")
                if loss is not None:
                    vs["last_train_loss"] = loss
                train_acc = _coalesce(rec, "train/accuracy", "train_top1")
                if train_acc is not None:
                    vs["last_train_acc"] = train_acc
                st = rec.get("step_time")
                if st and st > 0:
                    vs["steps_per_sec"] = 1.0 / st
                elif rec.get("elapsed") and vs["current_step"] > 0:
                    vs["steps_per_sec"] = vs["current_step"] / rec["elapsed"]

            elif rtype == "val":
                vl = _coalesce(rec, "val/loss", "val_loss", "loss")
                if vl is not None and (vs["best_val_loss"] is None or vl < vs["best_val_loss"]):
                    vs["best_val_loss"] = vl
                    vs["best_val_step"] = rec.get("step", vs.get("best_val_step", 0))
                acc = _coalesce(rec, "val/accuracy", "val_top1", "accuracy")
                if acc is not None:
                    vs["best_accuracy"] = acc

    # Roll per-variant numbers up to the trial-level fields.
    _aggregate_cotrain_metrics(trial)


def _extract_variant_name(run_dir_name: str) -> str | None:
"""Extract variant name from a run directory name.

The MetricsLogger creates dirs like ``run_YYYYMMDD_HHMMSS_variantname_slug``.
The layout is: ``run`` _ ``date`` _ ``time`` _ ``variant`` _ ``slug``.
The variant name may itself contain underscores, but the slug (final segment)
never does (it's two hyphenated words like ``calm-crane``). So we rejoin
everything between parts[3] and parts[-1].
"""
# Expected: run_YYYYMMDD_HHMMSS_variant_slug (at least 5 parts)
parts = run_dir_name.split("_")
if len(parts) < 5 or parts[0] != "run":
return None
# parts[1]=date, parts[2]=time, parts[-1]=slug, parts[3:-1]=variant
return "_".join(parts[3:-1])


def _aggregate_cotrain_metrics(trial: Trial) -> None:
"""Aggregate per-variant metrics to the trial level."""
if not trial.variants:
return

variants = list(trial.variants.values())

# current_step = min across variants (honest ETA — slowest determines progress)
steps = [v["current_step"] for v in variants if v["current_step"] > 0]
if steps:
trial.current_step = min(steps)

# best_val_loss = min across variants
val_losses = [v["best_val_loss"] for v in variants if v["best_val_loss"] is not None]
if val_losses:
trial.best_val_loss = min(val_losses)

# best_accuracy = max across variants
accs = [v["best_accuracy"] for v in variants if v["best_accuracy"] is not None]
if accs:
trial.best_accuracy = max(accs)

# last_train_loss = mean across active variants
losses = [v["last_train_loss"] for v in variants
if v["last_train_loss"] is not None and not v.get("stopped")]
if losses:
trial.last_train_loss = sum(losses) / len(losses)

# last_train_acc = mean across active variants
accs_train = [v["last_train_acc"] for v in variants
if v.get("last_train_acc") is not None and not v.get("stopped")]
if accs_train:
trial.last_train_acc = sum(accs_train) / len(accs_train)

# steps_per_sec from any variant (they share the same step timing)
for v in variants:
if v.get("steps_per_sec", 0) > 0:
trial.steps_per_sec = v["steps_per_sec"]
break


def read_pretrain_val_summary(trial: Trial) -> dict[str, Any] | None:
"""Scan the trial's metrics.jsonl for the latest pretraining val record
and compute a log-linear fit on forfeit rate over the most recent half
Expand Down
96 changes: 77 additions & 19 deletions pawn/lab/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
"""
from pydantic import TypeAdapter

from pawn.run_config import AdapterConfig, PretrainConfig
from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig

run_type = config.get("run_type")
if run_type not in ("pretrain", "adapter"):
raise ValueError(f"run_type must be 'pretrain' or 'adapter', got {run_type!r}")
ta = TypeAdapter(
PretrainConfig if run_type == "pretrain" else AdapterConfig
)
config_cls = {
"pretrain": PretrainConfig,
"adapter": AdapterConfig,
"cotrain": CotrainConfig,
}.get(run_type) # type: ignore[arg-type]
if config_cls is None:
raise ValueError(
f"run_type must be 'pretrain', 'adapter', or 'cotrain', got {run_type!r}"
)
ta = TypeAdapter(config_cls)
return ta.validate_python(config).model_dump()


Expand Down Expand Up @@ -266,7 +271,11 @@ async def launch(
validated = _validate_config(config)
cmd = self._build_command(validated, trial_id)

strategy_display = validated.get("strategy") or validated.get("variant", "pretrain")
if validated.get("run_type") == "cotrain":
variant_names = [v["name"] for v in validated.get("variants", [])]
strategy_display = "cotrain:" + "+".join(variant_names)
else:
strategy_display = validated.get("strategy") or validated.get("variant", "pretrain")
trial = Trial(
trial_id=trial_id,
strategy=strategy_display,
Expand Down Expand Up @@ -296,34 +305,83 @@ async def resume_trial(
total_steps: int | None = None,
pause_after_steps: int | None = None,
) -> int:
"""Resume a completed/failed trial from its best checkpoint."""
"""Resume a completed/failed trial from its best checkpoint.

For cotrain trials, discovers per-variant checkpoints and sets
the resume path on each variant in the new config.
"""
old = self.trials.get(trial_id)
if not old:
raise RuntimeError(f"Trial {trial_id} not found")
if not old.run_dir:
raise RuntimeError(f"Trial {trial_id} has no run directory")

ckpt_base = Path(old.run_dir) / "checkpoints"
new_config = dict(old.config)
new_config.pop("pause_after_steps", None)

if (old.config or {}).get("run_type") == "cotrain":
self._resolve_cotrain_resume(old, new_config)
else:
ckpt_dir = self._find_latest_checkpoint(Path(old.run_dir))
new_config["resume"] = str(ckpt_dir)

if total_steps is not None:
new_config["total_steps"] = total_steps
if pause_after_steps is not None:
new_config["pause_after_steps"] = pause_after_steps

return await self.launch(new_config, tags=old.tags)

@staticmethod
def _find_latest_checkpoint(run_dir: Path) -> Path:
"""Find the latest checkpoint under a run directory.

Checks for ``best/`` and ``final/`` symlinks first (adapter runs),
then falls back to the highest-numbered ``step_*`` directory
(pretrain/cotrain runs, which don't create best/final symlinks).
"""
ckpt_base = run_dir / "checkpoints"
ckpt_dir = ckpt_base / "best"
if not ckpt_dir.exists():
ckpt_dir = ckpt_base / "final"
if not ckpt_dir.exists():
# Pretraining uses step_XXXXXXXX naming — pick the highest step
step_dirs = sorted(ckpt_base.glob("step_*"))
if step_dirs:
ckpt_dir = step_dirs[-1]
if not ckpt_dir.exists():
raise RuntimeError(f"No checkpoint found for trial {trial_id}")
raise RuntimeError(f"No checkpoint found under {run_dir}")
return ckpt_dir

new_config = dict(old.config)
new_config.pop("pause_after_steps", None)
new_config["resume"] = str(ckpt_dir)
if total_steps is not None:
new_config["total_steps"] = total_steps
if pause_after_steps is not None:
new_config["pause_after_steps"] = pause_after_steps
def _resolve_cotrain_resume(
self, old: "Trial", new_config: dict[str, Any],
) -> None:
"""Set per-variant resume paths for a cotrain trial."""
if not old.variants:
raise RuntimeError(
f"Trial {old.trial_id} is cotrain but has no variant state. "
"Cannot determine per-variant checkpoints."
)

return await self.launch(new_config, tags=old.tags)
# Deep-copy variants list so we can mutate
import copy
variants = copy.deepcopy(new_config.get("variants", []))

for v_cfg in variants:
name = v_cfg.get("name")
if name not in old.variants:
raise RuntimeError(
f"Variant '{name}' not found in trial {old.trial_id} state"
)
vs = old.variants[name]
v_run_dir = vs.get("run_dir")
if not v_run_dir:
raise RuntimeError(
f"Variant '{name}' in trial {old.trial_id} has no run directory"
)
ckpt_dir = self._find_latest_checkpoint(Path(v_run_dir))
v_cfg["resume"] = str(ckpt_dir)

new_config["variants"] = variants

def _build_command(
self, config: dict[str, Any], trial_id: int,
Expand Down
7 changes: 4 additions & 3 deletions pawn/lab/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def lab_status(ctx: Context) -> dict[str, Any]:

@mcp.tool
async def lab_launch(config: dict[str, Any], ctx: Context, tags: list[str] | None = None) -> dict[str, Any]:
"""Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain' or 'adapter'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"])."""
"""Launch a trial from a RunConfig dict. Use lab_schema to discover all fields. The config must include run_type ('pretrain', 'adapter', or 'cotrain'). Optionally pass tags for grouping (e.g. ["phase1", "mate-boost"])."""
try:
tid = await _runner(ctx).launch(config, tags=tags)
return _runner(ctx).trials[tid].to_dict()
Expand Down Expand Up @@ -105,10 +105,11 @@ async def lab_set_cost(cost_per_hour: float, ctx: Context) -> dict[str, Any]:

@mcp.tool
async def lab_schema(ctx: Context) -> dict[str, Any]:
    """Return the JSON Schema for RunConfig (PretrainConfig, AdapterConfig, CotrainConfig). Use this to discover all available parameters before calling lab_launch."""
    # Local import mirrors the module's existing pattern of deferring
    # pawn.run_config until tool invocation.
    from pawn.run_config import AdapterConfig, CotrainConfig, PretrainConfig

    return {
        "pretrain": PretrainConfig.model_json_schema(),
        "adapter": AdapterConfig.model_json_schema(),
        "cotrain": CotrainConfig.model_json_schema(),
    }
2 changes: 2 additions & 0 deletions pawn/lab/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class Trial:
# Agent annotations
notes: str = ""
tags: list[str] = field(default_factory=list)
# Co-training: per-variant state (None for non-cotrain trials)
variants: dict[str, dict[str, Any]] | None = None

def eta_seconds(self) -> float | None:
if self.steps_per_sec > 0 and self.total_steps > self.current_step:
Expand Down
Loading
Loading