
Commit cf30b29

Add logging for learning rates in MetricsProcessor (#1413)
This PR adds learning rate logging. There was a previous attempt to implement this in an [earlier PR](#937), but that one was ultimately **closed**. This version ensures that LR logging works properly; I verified it using the WSD scheduler that was recently added in [another PR](#938).

![Screenshot of the logged learning rate curve](https://github.com/user-attachments/assets/8f23674a-d689-4cc2-9d9b-30bff4e63f3b)

One design consideration here is that torchtitan supports multiple optimizers and learning rate schedules, each potentially having its own LR. In practice, however, I believe that 99.9999% of use cases will use a single LR. Given that, the logging works as follows:

- If there is only one learning rate, it gets logged directly under the main charts as `lr`.
- If there are multiple learning rates, they are logged under a separate section, each with its corresponding label.

Alternatively, we could have ignored the multi-LR case and always logged a single LR, but I prefer this approach since it handles both scenarios robustly with minimal extra code. Happy to adjust if others have a strong preference for simplicity over robustness.
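As a rough sketch of the single- vs multi-LR behavior described above (this is not the actual MetricsProcessor code; the helper name and the multi-LR labels are illustrative assumptions), the branching could look like this:

```python
# Hypothetical helper illustrating the logging rules above. The function name,
# the "lr/..." labels, and the `lr_schedulers.schedulers` attribute mirror the
# description and the diff below, but are not torchtitan's actual implementation.
def collect_lr_metrics(lr_schedulers) -> dict[str, float]:
    last_lrs = [s.get_last_lr()[0] for s in lr_schedulers.schedulers]
    if len(last_lrs) == 1:
        # Single LR: logged directly under the main charts as "lr".
        return {"lr": last_lrs[0]}
    # Multiple LRs: logged under a separate section, one label per scheduler.
    return {f"lr/scheduler_{i}": last_lr for i, last_lr in enumerate(last_lrs)}
```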
1 parent b1dc330 commit cf30b29

File tree

1 file changed: +7 -1 lines changed


torchtitan/train.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -456,6 +456,8 @@ def train_step(
         self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
     ):
         self.optimizers.zero_grad()
+        # Save the current step learning rate for logging
+        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
 
         # Keep these variables local to shorten the code as these are
         # the major variables that are used in the training loop.
@@ -503,12 +505,16 @@ def train_step(
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
 
+        extra_metrics = {
+            "n_tokens_seen": self.ntokens_seen,
+            "lr": lr,
+        }
         self.metrics_processor.log(
             self.step,
             global_avg_loss,
             global_max_loss,
             grad_norm.item(),
-            extra_metrics={"ntokens_seen": self.ntokens_seen},
+            extra_metrics=extra_metrics,
         )
 
     @record
```
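For context on the added line above: PyTorch's `LRScheduler.get_last_lr()` returns a list with one entry per optimizer param group, which is why the diff indexes `[0]`. A minimal standalone example (the model and hyperparameters below are placeholders, not torchtitan's configuration):

```python
import torch

# Toy setup; the module and LR values are placeholders.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=10)

for step in range(3):
    optimizer.step()
    scheduler.step()
    # get_last_lr() returns one LR per param group; a single group here.
    print(step, scheduler.get_last_lr()[0])
```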

0 commit comments
