Skip to content

Commit b285385

Browse files
supergeorge23facebook-github-bot
authored andcommitted
Create Callback to Identify Empty Batches (#1020)
Summary: Pull Request resolved: #1020 # This Diff: This diff implements Step 1 of T232701473 by creating EmptyDataloaderDetectorCallback, a TNT callback that detects consecutive empty training epochs and implements a fail-fast strategy to surface dataloader issues early. # Callback Feature: The callback helps identify cases where dataloaders return empty batches, which can cause confusing downstream issues that manifest as red herrings (e.g., apparent checkpointing errors that are actually rapid step progression due to empty data). # Next Diff: Add to Mitra's default callbacks (Step 2 of T232701473), and will enable e2e test with Mitra Reviewed By: diego-urgell Differential Revision: D79212756 fbshipit-source-id: ee40e4b8a60225e0b50a3e5bfd60c493774ee314
1 parent 9591481 commit b285385

File tree

2 files changed

+266
-0
lines changed

2 files changed

+266
-0
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
from unittest.mock import patch
11+
12+
import torch
13+
import torch.nn as nn
14+
from torch.utils.data import DataLoader, Dataset
15+
16+
from torchtnt.framework._test_utils import Batch, DummyTrainUnit, get_dummy_train_state
17+
from torchtnt.framework.callbacks.empty_dataloader_detector import (
18+
EmptyDataloaderDetectorCallback,
19+
)
20+
from torchtnt.framework.state import State
21+
from torchtnt.framework.train import train
22+
from torchtnt.framework.unit import TrainUnit
23+
24+
25+
class MockTrainUnit(DummyTrainUnit):
26+
"""Mock train unit for testing that extends DummyTrainUnit with step control functionality."""
27+
28+
def __init__(self) -> None:
29+
super().__init__(input_dim=2) # Use a default input dimension
30+
self._steps_completed_in_prev_epoch = 0
31+
32+
def set_steps_completed_in_prev_epoch(self, steps: int) -> None:
33+
"""Set the number of steps completed in the previous epoch."""
34+
self._steps_completed_in_prev_epoch = steps
35+
self.train_progress._num_steps_completed_in_prev_epoch = steps
36+
37+
38+
class EmptyDataloaderDetectorCallbackTest(unittest.TestCase):
39+
def test_init_invalid_threshold(self) -> None:
40+
"""Test that invalid threshold values raise ValueError."""
41+
with self.assertRaisesRegex(ValueError, "threshold must be a positive integer"):
42+
EmptyDataloaderDetectorCallback(threshold=0)
43+
44+
with self.assertRaisesRegex(ValueError, "threshold must be a positive integer"):
45+
EmptyDataloaderDetectorCallback(threshold=-1)
46+
47+
def test_init_valid_threshold(self) -> None:
48+
"""Test that valid threshold values are accepted."""
49+
callback = EmptyDataloaderDetectorCallback(threshold=1)
50+
self.assertEqual(callback._threshold, 1)
51+
52+
callback = EmptyDataloaderDetectorCallback(threshold=5)
53+
self.assertEqual(callback._threshold, 5)
54+
55+
def test_train_empty_epoch_detection_with_exception(self) -> None:
56+
"""Test that consecutive empty train epochs trigger exception when threshold is reached."""
57+
callback = EmptyDataloaderDetectorCallback(threshold=2)
58+
state = get_dummy_train_state()
59+
unit = MockTrainUnit()
60+
61+
# First empty epoch - should not raise
62+
unit.set_steps_completed_in_prev_epoch(0)
63+
callback.on_train_epoch_end(state, unit)
64+
self.assertEqual(callback._consecutive_empty_train_epochs, 1)
65+
66+
# Second empty epoch - should raise exception
67+
unit.set_steps_completed_in_prev_epoch(0)
68+
with self.assertRaisesRegex(
69+
RuntimeError,
70+
"Detected 2 consecutive empty train epochs, which exceeds the threshold of 2",
71+
):
72+
callback.on_train_epoch_end(state, unit)
73+
74+
def test_train_reset_counter_on_non_empty_epoch(self) -> None:
75+
"""Test that consecutive empty epoch counter resets when a non-empty epoch occurs."""
76+
callback = EmptyDataloaderDetectorCallback(threshold=3)
77+
state = get_dummy_train_state()
78+
unit = MockTrainUnit()
79+
80+
# First empty epoch
81+
unit.set_steps_completed_in_prev_epoch(0)
82+
callback.on_train_epoch_end(state, unit)
83+
self.assertEqual(callback._consecutive_empty_train_epochs, 1)
84+
85+
# Second empty epoch
86+
unit.set_steps_completed_in_prev_epoch(0)
87+
callback.on_train_epoch_end(state, unit)
88+
self.assertEqual(callback._consecutive_empty_train_epochs, 2)
89+
90+
# Non-empty epoch - should reset counter
91+
unit.set_steps_completed_in_prev_epoch(5)
92+
callback.on_train_epoch_end(state, unit)
93+
self.assertEqual(callback._consecutive_empty_train_epochs, 0)
94+
95+
# Another empty epoch - counter should start from 1 again
96+
unit.set_steps_completed_in_prev_epoch(0)
97+
callback.on_train_epoch_end(state, unit)
98+
self.assertEqual(callback._consecutive_empty_train_epochs, 1)
99+
100+
def test_threshold_one(self) -> None:
101+
"""Test that threshold=1 triggers immediately on first empty epoch."""
102+
callback = EmptyDataloaderDetectorCallback(threshold=1)
103+
state = get_dummy_train_state()
104+
unit = MockTrainUnit()
105+
106+
# First empty epoch should immediately trigger exception
107+
unit.set_steps_completed_in_prev_epoch(0)
108+
with self.assertRaisesRegex(
109+
RuntimeError,
110+
"Detected 1 consecutive empty train epochs, which exceeds the threshold of 1",
111+
):
112+
callback.on_train_epoch_end(state, unit)
113+
114+
def test_high_threshold(self) -> None:
115+
"""Test that high threshold values work correctly."""
116+
callback = EmptyDataloaderDetectorCallback(threshold=5)
117+
state = get_dummy_train_state()
118+
unit = MockTrainUnit()
119+
120+
# Four empty epochs should not trigger
121+
for i in range(4):
122+
unit.set_steps_completed_in_prev_epoch(0)
123+
callback.on_train_epoch_end(state, unit)
124+
self.assertEqual(callback._consecutive_empty_train_epochs, i + 1)
125+
126+
# Fifth empty epoch should trigger exception
127+
unit.set_steps_completed_in_prev_epoch(0)
128+
with self.assertRaisesRegex(
129+
RuntimeError,
130+
"Detected 5 consecutive empty train epochs, which exceeds the threshold of 5",
131+
):
132+
callback.on_train_epoch_end(state, unit)
133+
134+
def test_warning_logged_for_each_empty_epoch(self) -> None:
135+
"""Test that a warning is logged for each empty epoch."""
136+
callback = EmptyDataloaderDetectorCallback(threshold=3)
137+
state = get_dummy_train_state()
138+
unit = MockTrainUnit()
139+
140+
with patch(
141+
"torchtnt.framework.callbacks.empty_dataloader_detector.logger"
142+
) as mock_logger:
143+
# First empty epoch
144+
unit.set_steps_completed_in_prev_epoch(0)
145+
callback.on_train_epoch_end(state, unit)
146+
147+
# Second empty epoch
148+
unit.set_steps_completed_in_prev_epoch(0)
149+
callback.on_train_epoch_end(state, unit)
150+
151+
# Verify warnings were logged for each empty epoch
152+
self.assertEqual(mock_logger.warning.call_count, 2)
153+
warning_calls = mock_logger.warning.call_args_list
154+
self.assertTrue(
155+
any("Empty train epoch detected" in str(call) for call in warning_calls)
156+
)
157+
158+
def test_non_empty_epochs_do_not_trigger_warnings(self) -> None:
159+
"""Test that non-empty epochs do not trigger any warnings or exceptions."""
160+
callback = EmptyDataloaderDetectorCallback(threshold=2)
161+
state = get_dummy_train_state()
162+
unit = MockTrainUnit()
163+
164+
with patch(
165+
"torchtnt.framework.callbacks.empty_dataloader_detector.logger"
166+
) as mock_logger:
167+
# Multiple non-empty epochs
168+
for steps in [1, 5, 10, 100]:
169+
unit.set_steps_completed_in_prev_epoch(steps)
170+
callback.on_train_epoch_end(state, unit)
171+
172+
# No warnings should be logged
173+
mock_logger.warning.assert_not_called()
174+
175+
# Counter should remain at 0
176+
self.assertEqual(callback._consecutive_empty_train_epochs, 0)
177+
178+
def test_empty_dataloader_detection_with_real_training_loop(self) -> None:
179+
"""
180+
Test that simulates the real scenario from failed MAST job f762746046-pviolatingquery_cse.
181+
Tests EmptyDataloaderDetectorCallback with actual training loop and empty dataloaders.
182+
"""
183+
184+
class EmptyDataset(Dataset[Batch]):
185+
"""Dataset that returns no data to simulate empty dataloader scenario."""
186+
187+
def __len__(self) -> int:
188+
return 0
189+
190+
def __getitem__(self, idx: int) -> Batch:
191+
raise IndexError("Empty dataset")
192+
193+
callback_with_exception = EmptyDataloaderDetectorCallback(threshold=2)
194+
195+
train_unit = DummyTrainUnit(input_dim=2)
196+
empty_dataloader = DataLoader(EmptyDataset(), batch_size=1)
197+
198+
# This should raise an exception after 2 empty epochs
199+
with self.assertRaisesRegex(
200+
RuntimeError,
201+
"Detected 2 consecutive empty train epochs, which exceeds the threshold of 2",
202+
):
203+
train(
204+
train_unit,
205+
empty_dataloader,
206+
max_epochs=50, # Try to run 50 epochs but should fail at 2
207+
callbacks=[callback_with_exception],
208+
)
209+
210+
self.assertEqual(callback_with_exception._consecutive_empty_train_epochs, 2)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
from torchtnt.framework.callback import Callback
12+
from torchtnt.framework.state import State
13+
from torchtnt.framework.unit import TTrainUnit
14+
15+
logger: logging.Logger = logging.getLogger(__name__)
16+
17+
18+
class EmptyDataloaderDetectorCallback(Callback):
19+
"""
20+
A callback that detects consecutive empty epochs and raises an error when a threshold is reached.
21+
22+
This callback helps identify issues where dataloaders return empty batches, which can cause confusing
23+
downstream problems that are hard to debug. It implements a fail-fast strategy to surface these issues early.
24+
"""
25+
26+
def __init__(
27+
self,
28+
threshold: int = 2,
29+
) -> None:
30+
if threshold <= 0:
31+
raise ValueError("threshold must be a positive integer")
32+
33+
self._threshold = threshold
34+
self._consecutive_empty_train_epochs = 0
35+
36+
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
37+
num_steps = unit.train_progress.num_steps_completed_in_prev_epoch
38+
epoch_num = unit.train_progress.num_epochs_completed
39+
40+
if num_steps == 0:
41+
self._consecutive_empty_train_epochs += 1
42+
logger.warning(
43+
f"Empty train epoch detected! Epoch {epoch_num} completed 0 steps. "
44+
f"Consecutive empty train epochs: {self._consecutive_empty_train_epochs}"
45+
)
46+
47+
if self._consecutive_empty_train_epochs >= self._threshold:
48+
error_msg = (
49+
f"Detected {self._consecutive_empty_train_epochs} consecutive empty train epochs, "
50+
f"which exceeds the threshold of {self._threshold}. This indicates that the "
51+
f"dataloader is returning empty batches, which could be due to an empty "
52+
f"training table or infrastructure issues with the dataloader."
53+
)
54+
raise RuntimeError(error_msg)
55+
else:
56+
self._consecutive_empty_train_epochs = 0

0 commit comments

Comments
 (0)