
Commit a0fdaa3

All-reduce ntokens_seen before logging (#1509)
Currently, `ntokens_seen` is only logged locally. I think it is almost always desirable to track only the global quantity (the only use case I can see for per-device tracking is debugging). Therefore, I propose to all-reduce `ntokens_seen` before logging.
1 parent 48d8dcd commit a0fdaa3

File tree

2 files changed: +20 −2 lines changed


torchtitan/distributed/utils.py

Lines changed: 10 additions & 0 deletions
@@ -62,6 +62,16 @@ def dist_max(
     )


+def dist_sum(
+    x: torch.Tensor,
+    mesh: DeviceMesh,
+    extra_pg: dist.ProcessGroup | None = None,
+) -> float:
+    return _dist_reduce(
+        x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg
+    )
+
+
 def dist_mean(
     x: torch.Tensor,
     mesh: DeviceMesh,
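
For context, `dist_sum` follows the same pattern as the existing `dist_max` and `dist_mean` helpers by delegating to `_dist_reduce`, which is not part of this diff. A rough, hypothetical sketch of that pattern (all-reduce over the mesh, optionally over an extra fault-tolerance process group, then read the scalar back) might look like the following; the name `_dist_reduce_sketch` and its body are illustrative only, not torchtitan's actual implementation.

# Hypothetical sketch only -- not the actual torchtitan `_dist_reduce`;
# details such as DTensor handling may differ in the real helper.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def _dist_reduce_sketch(
    x: torch.Tensor,
    reduceOp: str,
    mesh: DeviceMesh,
    extra_pg: dist.ProcessGroup | None = None,
) -> float:
    # Reduce across the extra (e.g. fault-tolerance) process group first, if given.
    if extra_pg is not None:
        x = funcol.all_reduce(x, reduceOp, extra_pg)
    # Reduce across the device mesh (e.g. the flattened "dp_cp" mesh),
    # then read the result back as a Python scalar.
    x = funcol.all_reduce(x, reduceOp, mesh)
    return x.item()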

torchtitan/train.py

Lines changed: 10 additions & 2 deletions
@@ -498,15 +498,23 @@ def train_step(
         if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
             ft_pg = self.ft_manager.loss_sync_pg
-            global_avg_loss, global_max_loss = (
+            global_avg_loss, global_max_loss, global_ntokens_seen = (
                 dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_sum(
+                    torch.tensor(
+                        self.ntokens_seen, dtype=torch.int64, device=self.device
+                    ),
+                    parallel_dims.world_mesh["dp_cp"],
+                    ft_pg,
+                ),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
+            global_ntokens_seen = self.ntokens_seen

         extra_metrics = {
-            "n_tokens_seen": self.ntokens_seen,
+            "n_tokens_seen": global_ntokens_seen,
             "lr": lr,
         }
         self.metrics_processor.log(
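
Putting the two changes together: each rank converts its local `ntokens_seen` into an int64 tensor, all-reduces it with SUM across the "dp_cp" mesh (and the fault-tolerance group, if any), and logs the result under `n_tokens_seen`. A minimal standalone sketch of the same idea, assuming `torch.distributed` is initialized and using the default process group instead of torchtitan's mesh utilities:

import torch
import torch.distributed as dist


def global_ntokens_seen(local_ntokens: int, device: torch.device) -> int:
    # Sum the per-rank token count across all ranks so the logged value is
    # the global number of tokens seen, not a per-device count.
    if not (dist.is_available() and dist.is_initialized()):
        return local_ntokens  # single-process run: local count is already global
    t = torch.tensor(local_ntokens, dtype=torch.int64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return int(t.item())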
