You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Adds a `ModelProtocol.get_extra_metrics` method for more flexible custom
metric reporting, as discussed in #1576
Probably this should be an abstract method, but I was wary of making
this a breaking change for users who inherit this commit.
The current signature is `get_extra_metrics(self, parallel_dims:
ParallelDims) -> None | dict`. I also considered adding some subset of
`JobConfig`, `TrainSpec`, and `pp_has_{first,last}_stage`; not sure what
else might be useful.
Tested via running the debugmodel with print statements.
CC @rakkit@wwwjn
0 commit comments