From 611a500ea76e1704425463a9d6dd30bcefdd7c58 Mon Sep 17 00:00:00 2001 From: Yifan Luo Date: Thu, 3 Oct 2024 19:16:03 +1000 Subject: [PATCH] Fix comment typos in auto_unit.py The old comments have "PredictUnit" in `get_next_train_batch`, `get_next_eval_batch`, which could be a copy-pasta typo from `get_next_predict_batch`. The new comments change from "PredictUnit" to "TrainUnit" in `get_next_train_batch` and "EvalUnit" in `get_next_eval_batch` --- torchtnt/framework/auto_unit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 38a50157c6..46b72f686b 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -876,7 +876,7 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]: def get_next_train_batch( self, state: State, data_iter: Iterator[TData] ) -> Union[Iterator[TData], TData]: - # Override the default behavior from PredictUnit in order to enable prefetching if possible. + # Override the default behavior from TrainUnit in order to enable prefetching if possible. pass_data_iter_to_step = _step_requires_iterator(self.train_step) if pass_data_iter_to_step: return data_iter @@ -885,7 +885,7 @@ def get_next_train_batch( def get_next_eval_batch( self, state: State, data_iter: Iterator[TData] ) -> Union[Iterator[TData], TData]: - # Override the default behavior from PredictUnit in order to enable prefetching if possible. + # Override the default behavior from EvalUnit in order to enable prefetching if possible. pass_data_iter_to_step = _step_requires_iterator(self.eval_step) if pass_data_iter_to_step: return data_iter