3 files changed: +14 −8 lines changed

File 1 of 3 (the module the other files import as torchtitan.components.dataloader):
@@ -16,8 +16,14 @@
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.tools.logging import logger
 
-
-class DataloaderStopIteration(StopIteration):
+# NOTE: This class deliberately inherits from `Exception` and not `StopIteration`.
+# According to PEP 479, raising a `StopIteration` or its subclass from within a
+# generator will wrap it in a `RuntimeError`. Since this exception is designed
+# to be raised from a generator-based dataloader and caught by the training loop,
+# inheriting from `StopIteration` would make it uncatchable and would crash the
+# program.
+# See: https://peps.python.org/pep-0479/
+class DataloaderExhaustedError(Exception):
     """An exception that indicates dataloader exhaustion."""
 
     pass
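For reference, here is a minimal standalone sketch of the PEP 479 failure mode the NOTE describes; the generator and handler below are illustrative, not part of torchtitan:

class DataloaderStopIteration(StopIteration):
    """The old design: a StopIteration subclass."""

def exhausted():
    # Raising a StopIteration subclass inside a generator body...
    raise DataloaderStopIteration()
    yield  # the yield makes this function a generator

try:
    next(exhausted())
except DataloaderStopIteration:
    print("caught")  # never reached
except RuntimeError as e:
    # ...is converted to a RuntimeError at the generator boundary (PEP 479),
    # so the intended handler above can never fire.
    print(f"RuntimeError: {e}")  # "generator raised StopIteration"

Switching the base class to Exception, as this change does, lets the error propagate out of the generator unchanged.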
File 2 of 3:

@@ -13,7 +13,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
-from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.metrics import build_metrics_processor
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
@@ -135,7 +135,7 @@ def batch_generator(
             except StopIteration as ex:
                 # If data runs out during gradient accumulation, that
                 # entire step will not be executed.
-                raise DataloaderStopIteration() from ex
+                raise DataloaderExhaustedError() from ex
             data_load_start = time.perf_counter()
             input_dict, labels = batch
             self.metrics_processor.ntokens_since_last_log += labels.numel()
@@ -292,7 +292,7 @@ def train(self):
             self.gc_handler.run(self.step)
             try:
                 self.train_step(data_iterator)
-            except DataloaderStopIteration:
+            except DataloaderExhaustedError:
                 logger.warning("Ran out of data; last step was canceled.")
                 break
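The control flow of these two hunks can be summarized in a self-contained sketch; a plain list stands in for the real dataloader, and next() stands in for the gradient-accumulation work inside train_step:

class DataloaderExhaustedError(Exception):
    """An exception that indicates dataloader exhaustion."""

def batch_generator(data_iterable):
    data_iterator = iter(data_iterable)
    while True:
        try:
            batch = next(data_iterator)
        except StopIteration as ex:
            # Translate raw exhaustion into a catchable, Exception-based
            # error before it crosses the generator boundary.
            raise DataloaderExhaustedError() from ex
        yield batch

def train(batches):
    data_iterator = batch_generator(batches)
    step = 0
    while True:
        step += 1
        try:
            batch = next(data_iterator)  # stands in for self.train_step(data_iterator)
        except DataloaderExhaustedError:
            # Because the error is not a StopIteration subclass, it survives
            # the trip out of the generator and the step is cleanly canceled.
            print("Ran out of data; last step was canceled.")
            break
        print(f"step {step}: consumed {batch!r}")

train(["batch_0", "batch_1"])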
File 3 of 3:

@@ -15,7 +15,7 @@
 
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
-from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
 from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.components.metrics import (
@@ -386,7 +386,7 @@ def batch_generator(
             except StopIteration as ex:
                 # If data runs out during gradient accumulation, that
                 # entire step will not be executed.
-                raise DataloaderStopIteration() from ex
+                raise DataloaderExhaustedError() from ex
             input_dict, labels = batch
             ntokens_batch = labels.numel()
             self.ntokens_seen += ntokens_batch
@@ -583,7 +583,7 @@ def train(self):
             self.gc_handler.run(self.step)
             try:
                 self.train_step(data_iterator)
-            except DataloaderStopIteration:
+            except DataloaderExhaustedError:
                 logger.warning("Ran out of data; last step was canceled.")
                 break