Commit 7d744b2
add model_parts ref to MetricsProcessor (#1578)
Adds a `ModelProtocol.get_extra_metrics` method for more flexible custom metric reporting, as discussed in #1576.

Probably this should be an abstract method, but I was wary of making it a breaking change for users who inherit from this protocol. The current signature is `get_extra_metrics(self, parallel_dims: ParallelDims) -> None | dict`. I also considered passing in some subset of `JobConfig`, `TrainSpec`, and `pp_has_{first,last}_stage`; not sure what else might be useful.

Tested via running the debugmodel with print statements.

CC @rakkit @wwwjn
1 parent 255a6ab commit 7d744b2

File tree

2 files changed: +3 −0 lines changed


torchtitan/components/metrics.py

Lines changed: 2 additions & 0 deletions

@@ -319,6 +319,7 @@ class MetricsProcessor:
     num_flops_per_token: int
     optimizers: OptimizersContainer | None
     lr_schedulers: LRSchedulersContainer | None
+    model_parts: list[torch.nn.Module] | None

     def __init__(
         self,
@@ -349,6 +350,7 @@ def __init__(
         self.num_flops_per_token = -1
         self.optimizers = None
         self.lr_schedulers = None
+        self.model_parts = None

     def should_log(self, step: int) -> bool:
         return step == 1 or step % self.job_config.metrics.log_freq == 0
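The `should_log` gating visible in the diff context is simple enough to restate standalone. A minimal sketch of the same logic, with `log_freq` passed directly instead of read from `self.job_config.metrics`:

```python
def should_log(step: int, log_freq: int) -> bool:
    # Mirrors MetricsProcessor.should_log: always log the first step,
    # then log every log_freq steps thereafter.
    return step == 1 or step % log_freq == 0
```

So with `log_freq=10`, logging happens at steps 1, 10, 20, 30, and so on.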

torchtitan/train.py

Lines changed: 1 addition & 0 deletions

@@ -294,6 +294,7 @@ def __init__(self, job_config: JobConfig):
             )
         )
         self.metrics_processor.optimizers = self.optimizers
+        self.metrics_processor.model_parts = self.model_parts

         # Initialize trainer states that will be saved in checkpoint.
         # These attributes must be initialized before checkpoint loading.
