Commit 9197908

fix(dataloader): Prevent RuntimeError from DataloaderStopIteration (#1627)
The DataloaderStopIteration exception inherited from StopIteration. According to PEP 479, raising a StopIteration subclass from inside a generator causes a RuntimeError in Python 3.7+. This change switches the base class to `Exception` so that user code can catch the exception without triggering that behavior. Fixes #1626
1 parent: 2025abb
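
For context, a minimal standalone sketch (not part of this commit) of the PEP 479 behavior the commit message describes: a StopIteration subclass raised inside a generator body is replaced by a RuntimeError at the call site, so an `except DataloaderStopIteration` clause never sees the original exception.

# Minimal PEP 479 repro (Python 3.7+), independent of the torchtitan code below.
class DataloaderStopIteration(StopIteration):
    """The old base class: a StopIteration subclass."""

def batches():
    yield "batch-0"
    # PEP 479: any StopIteration (or subclass) escaping a generator body is
    # converted into `RuntimeError: generator raised StopIteration`.
    raise DataloaderStopIteration()

gen = batches()
print(next(gen))  # batch-0
try:
    next(gen)
except DataloaderStopIteration:
    print("never reached; the original exception was replaced")
except RuntimeError as err:
    print(f"caught {err!r}")  # caught RuntimeError('generator raised StopIteration')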

3 files changed: +14 additions, -8 deletions

torchtitan/components/dataloader.py
Lines changed: 8 additions & 2 deletions

@@ -16,8 +16,14 @@
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.tools.logging import logger
 
-
-class DataloaderStopIteration(StopIteration):
+# NOTE: This class deliberately inherits from `Exception` and not `StopIteration`.
+# According to PEP 479, raising a `StopIteration` or its subclass from within a
+# generator will wrap it in a `RuntimeError`. Since this exception is designed
+# to be raised from a generator-based dataloader and caught by the training loop,
+# inheriting from `StopIteration` would make it uncatchable and would crash the
+# program.
+# See: https://peps.python.org/pep-0479/
+class DataloaderExhaustedError(Exception):
     """An exception that indicates dataloader exhaustion."""
 
     pass

torchtitan/experiments/forge/example_train.py
Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
-from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.metrics import build_metrics_processor
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator

@@ -135,7 +135,7 @@ def batch_generator(
             except StopIteration as ex:
                 # If data runs out during gradient accumulation, that
                 # entire step will not be executed.
-                raise DataloaderStopIteration() from ex
+                raise DataloaderExhaustedError() from ex
             data_load_start = time.perf_counter()
             input_dict, labels = batch
             self.metrics_processor.ntokens_since_last_log += labels.numel()

@@ -292,7 +292,7 @@ def train(self):
             self.gc_handler.run(self.step)
             try:
                 self.train_step(data_iterator)
-            except DataloaderStopIteration:
+            except DataloaderExhaustedError:
                 logger.warning("Ran out of data; last step was canceled.")
                 break

torchtitan/train.py
Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
-from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
 from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.components.metrics import (

@@ -386,7 +386,7 @@ def batch_generator(
             except StopIteration as ex:
                 # If data runs out during gradient accumulation, that
                 # entire step will not be executed.
-                raise DataloaderStopIteration() from ex
+                raise DataloaderExhaustedError() from ex
             input_dict, labels = batch
             ntokens_batch = labels.numel()
             self.ntokens_seen += ntokens_batch

@@ -583,7 +583,7 @@ def train(self):
             self.gc_handler.run(self.step)
             try:
                 self.train_step(data_iterator)
-            except DataloaderStopIteration:
+            except DataloaderExhaustedError:
                 logger.warning("Ran out of data; last step was canceled.")
                 break

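To round out the diff above, a simplified end-to-end sketch (the generator and loop below are stand-ins, not the actual torchtitan code) of why the `Exception` base class fixes the crash: the error now escapes the generator unchanged, so the training loop's `except DataloaderExhaustedError` clause catches it.

class DataloaderExhaustedError(Exception):
    """An exception that indicates dataloader exhaustion."""

def batch_generator(data_iterable):
    # Simplified stand-in for the batch_generator methods changed in this commit.
    data_iterator = iter(data_iterable)
    while True:
        try:
            batch = next(data_iterator)
        except StopIteration as ex:
            # Safe under PEP 479: the raised type is not a StopIteration subclass,
            # so it propagates out of the generator unchanged.
            raise DataloaderExhaustedError() from ex
        yield batch

data_iterator = batch_generator(["batch-0", "batch-1"])
while True:
    try:
        batch = next(data_iterator)
        print("train step on", batch)
    except DataloaderExhaustedError:
        print("Ran out of data; last step was canceled.")
        break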