diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 38a50157c6..46b72f686b 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -876,7 +876,7 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]: def get_next_train_batch( self, state: State, data_iter: Iterator[TData] ) -> Union[Iterator[TData], TData]: - # Override the default behavior from PredictUnit in order to enable prefetching if possible. + # Override the default behavior from TrainUnit in order to enable prefetching if possible. pass_data_iter_to_step = _step_requires_iterator(self.train_step) if pass_data_iter_to_step: return data_iter @@ -885,7 +885,7 @@ def get_next_train_batch( def get_next_eval_batch( self, state: State, data_iter: Iterator[TData] ) -> Union[Iterator[TData], TData]: - # Override the default behavior from PredictUnit in order to enable prefetching if possible. + # Override the default behavior from EvalUnit in order to enable prefetching if possible. pass_data_iter_to_step = _step_requires_iterator(self.eval_step) if pass_data_iter_to_step: return data_iter