Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tinker_cookbook/recipes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ Our examples support the following CLI arguments to log the results.
- `{log_path}/metrics.jsonl` saves training metrics.
- `{log_path}/checkpoints.jsonl` records all the checkpoints saved during training. You can share these checkpoints for model release, offline evaluation, etc.
- Resuming: When using an existing `log_path`, you can either overwrite the previous run or resume training. This is particularly useful for recovering from runtime interruptions.
3. **Resuming W&B Runs**: To resume logging to an existing W&B run (e.g., after a training interruption), set the following environment variables before running your training script:
```bash
export WANDB_RUN_ID=<run_id>
export WANDB_RESUME=must
```
Replace `<run_id>` with the ID of the W&B run you want to resume. W&B will automatically pick up these variables and continue logging to the same run.
41 changes: 40 additions & 1 deletion tinker_cookbook/utils/ml_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,43 @@ def __init__(
name=wandb_name,
)

# Offset added to every logged step so that a resumed run continues
# its W&B history instead of overwriting earlier steps.
self.step_offset = 0

# W&B resume protocol: the user exports WANDB_RUN_ID (and WANDB_RESUME)
# before launch; if set, we must be attached to that exact run.
resume_run_id = os.environ.get("WANDB_RUN_ID")
if resume_run_id is not None:
    # Raise (not assert) so the check survives `python -O`.
    if self.run.id != resume_run_id:
        raise ValueError("WANDB_RUN_ID does not match run ID")
    self.step_offset = self._get_last_step_from_run()
    # Use the module logger for consistency with the rest of this file.
    logger.info("Resumed WandB run %s, step offset: %d", self.run.id, self.step_offset)

def _get_last_step_from_run(self) -> int:
    """Return the step offset to use when resuming a W&B run.

    Queries the W&B public API for the last step previously logged to the
    current run and returns the next free step (last + 1), so metrics
    logged after a restart continue the existing history.

    Returns:
        The first unused step of the resumed run, or 0 when there is no
        active run to resume.

    Raises:
        ValueError: If the last logged step cannot be determined.
    """
    assert wandb is not None
    if self.run is None:
        # Nothing to resume; a zero offset leaves steps untouched.
        return 0

    try:
        api = wandb.Api()
        run = api.run(f"{self.run.entity}/{self.run.project}/{self.run.id}")

        # Preferred source: lastHistoryStep, the highest step recorded
        # in the run's history (guarded with hasattr for older clients).
        if hasattr(run, "lastHistoryStep") and run.lastHistoryStep is not None:
            logger.info("Detected last step %s from lastHistoryStep", run.lastHistoryStep)
            return run.lastHistoryStep + 1

        # Fallback: the run summary keeps the most recent `_step` value.
        if run.summary and "_step" in run.summary:
            step_value = run.summary["_step"]
            if isinstance(step_value, (int, float)):
                logger.info("Detected last step %d from summary", int(step_value))
                return int(step_value) + 1

    except Exception as e:
        # Best-effort lookup failed; warn via the module logger (the file's
        # logging convention) and fall through to the hard failure below.
        logger.warning("Could not determine last step from WandB run: %s", e)

    raise ValueError("Couldn't determine step offset to resume WandB run")

def log_hparams(self, config: Any) -> None:
"""Log hyperparameters to wandb."""
if self.run and wandb is not None:
Expand All @@ -234,7 +271,9 @@ def log_hparams(self, config: Any) -> None:
def log_metrics(self, metrics: Dict[str, Any], step: int | None = None) -> None:
    """Log metrics to wandb."""
    # Guard clause: skip silently when wandb is unavailable or no run is open.
    if not self.run or wandb is None:
        return
    # Shift the caller's step by the resume offset so a resumed run
    # appends to its previous history rather than overwriting it.
    effective_step = None if step is None else step + self.step_offset
    wandb.log(metrics, step=effective_step)
    logger.info("Logging to: %s", self.run.url)

def close(self) -> None:
Expand Down