diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9b69fd679..434e1fae7 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -466,16 +466,9 @@ def forward_backward_step(
 
         return loss
 
-    def train_step(
-        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
-    ):
-        self.optimizers.zero_grad()
-        # Save the current step learning rate for logging
-        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
+    def gradient_computation(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
 
-        # Keep these variables local to shorten the code as these are
-        # the major variables that are used in the training loop.
-        parallel_dims = self.parallel_dims
+        self.optimizers.zero_grad()
 
         accumulated_losses = []
         # If data runs out during gradient accumulation, that
@@ -490,32 +483,33 @@ def train_step(
             self.job_config.training.max_norm,
             foreach=True,
             pp_mesh=(
-                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
+                self.parallel_dims.world_mesh["pp"] if self.parallel_dims.pp_enabled else None
             ),
-            ep_enabled=parallel_dims.ep_enabled,
+            ep_enabled=self.parallel_dims.ep_enabled,
         )
-        self.checkpointer.maybe_wait_for_staging()
-        self.optimizers.step()
-        self.lr_schedulers.step()
 
         # Reduce the data collected over gradient accumulation steps.
         loss = torch.sum(torch.stack(accumulated_losses))
 
-        # log metrics
-        if not self.metrics_processor.should_log(self.step):
-            return
+        return loss, grad_norm
+
+    def run_optimizer_step(self) -> None:
+        self.checkpointer.maybe_wait_for_staging()
+        self.optimizers.step()
+        self.lr_schedulers.step()
 
-        if parallel_dims.dp_cp_enabled:
+    def compute_global_loss(self, loss: torch.Tensor) -> tuple[float, float, int]:
+        if self.parallel_dims.dp_cp_enabled:
             loss = loss.detach()
             ft_pg = self.ft_manager.loss_sync_pg
             global_avg_loss, global_max_loss, global_ntokens_seen = (
-                dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_mean(loss, self.parallel_dims.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_max(loss, self.parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_sum(
                     torch.tensor(
                         self.ntokens_seen, dtype=torch.int64, device=self.device
                     ),
-                    parallel_dims.world_mesh["dp_cp"],
+                    self.parallel_dims.world_mesh["dp_cp"],
                     ft_pg,
                 ),
             )
@@ -523,15 +517,36 @@ def train_step(
             global_avg_loss = global_max_loss = loss.detach().item()
             global_ntokens_seen = self.ntokens_seen
 
+        return global_avg_loss, global_max_loss, global_ntokens_seen
+
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+
+        # Run the gradient computation
+        loss, grad_norm = self.gradient_computation(data_iterator=data_iterator)
+
+        # Save the current step learning rate for logging
+        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
+
+        # Run the optimizer step
+        self.run_optimizer_step()
+
+        # log metrics
+        if not self.metrics_processor.should_log(self.step):
+            return
+
+        global_avg_loss, global_max_loss, global_ntokens_seen = self.compute_global_loss(loss=loss)
+
         extra_metrics = {
             "n_tokens_seen": global_ntokens_seen,
             "lr": lr,
         }
         self.metrics_processor.log(
-            self.step,
-            global_avg_loss,
-            global_max_loss,
-            grad_norm.item(),
+            step=self.step,
+            global_avg_loss=global_avg_loss,
+            global_max_loss=global_max_loss,
+            grad_norm=grad_norm.item(),
             extra_metrics=extra_metrics,
         )
 
@@ -578,7 +593,7 @@ def train(self):
             ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
-            while self.step < job_config.training.steps:
+            while self.should_continue_training():
                 self.step += 1
                 self.gc_handler.run(self.step)
                 try:
@@ -620,6 +635,9 @@ def train(self):
 
         logger.info("Training completed")
 
+    def should_continue_training(self) -> bool:
+        return self.step < self.job_config.training.steps
+
     def state_dict(self) -> dict[str, Any]:
         return {"step": self.step, "ntokens_seen": self.ntokens_seen}
 