|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +import unittest |
| 10 | +from unittest.mock import patch |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn as nn |
| 14 | +from torch.utils.data import DataLoader, Dataset |
| 15 | + |
| 16 | +from torchtnt.framework._test_utils import Batch, DummyTrainUnit, get_dummy_train_state |
| 17 | +from torchtnt.framework.callbacks.empty_dataloader_detector import ( |
| 18 | + EmptyDataloaderDetectorCallback, |
| 19 | +) |
| 20 | +from torchtnt.framework.state import State |
| 21 | +from torchtnt.framework.train import train |
| 22 | +from torchtnt.framework.unit import TrainUnit |
| 23 | + |
| 24 | + |
class MockTrainUnit(DummyTrainUnit):
    """DummyTrainUnit subclass that lets a test dictate the previous epoch's step count."""

    def __init__(self) -> None:
        # The input dimension is fixed at 2; the tests never vary it.
        super().__init__(input_dim=2)
        self._steps_completed_in_prev_epoch = 0

    def set_steps_completed_in_prev_epoch(self, steps: int) -> None:
        """Pretend the previous epoch finished after ``steps`` steps.

        The value is stored on this unit and also written to the private
        ``_num_steps_completed_in_prev_epoch`` field of ``self.train_progress``.

        Args:
            steps: step count to report for the previous epoch.
        """
        # NOTE(review): train_progress is presumably provided by the base unit
        # class; the private field is overwritten directly - confirm it is the
        # one the callback under test reads.
        self.train_progress._num_steps_completed_in_prev_epoch = steps
        self._steps_completed_in_prev_epoch = steps
| 36 | + |
| 37 | + |
class EmptyDataloaderDetectorCallbackTest(unittest.TestCase):
    """Tests for EmptyDataloaderDetectorCallback's threshold validation and empty-epoch tracking."""

    # Patch target for the logger used by the callback module.
    _LOGGER_TARGET: str = (
        "torchtnt.framework.callbacks.empty_dataloader_detector.logger"
    )

    @staticmethod
    def _threshold_error(threshold: int) -> str:
        """Build the regex the tests expect in the RuntimeError for ``threshold``."""
        return (
            f"Detected {threshold} consecutive empty train epochs, "
            f"which exceeds the threshold of {threshold}"
        )

    def test_init_invalid_threshold(self) -> None:
        """Zero and negative thresholds must be rejected with a ValueError."""
        for bad_threshold in (0, -1):
            with self.assertRaisesRegex(
                ValueError, "threshold must be a positive integer"
            ):
                EmptyDataloaderDetectorCallback(threshold=bad_threshold)

    def test_init_valid_threshold(self) -> None:
        """Positive thresholds are stored on the callback unchanged."""
        for threshold in (1, 5):
            detector = EmptyDataloaderDetectorCallback(threshold=threshold)
            self.assertEqual(detector._threshold, threshold)

    def test_train_empty_epoch_detection_with_exception(self) -> None:
        """With threshold=2, the second consecutive empty train epoch raises."""
        detector = EmptyDataloaderDetectorCallback(threshold=2)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        # Epoch 1: empty, but still below the threshold.
        mock_unit.set_steps_completed_in_prev_epoch(0)
        detector.on_train_epoch_end(train_state, mock_unit)
        self.assertEqual(detector._consecutive_empty_train_epochs, 1)

        # Epoch 2: empty again, which reaches the threshold.
        mock_unit.set_steps_completed_in_prev_epoch(0)
        with self.assertRaisesRegex(RuntimeError, self._threshold_error(2)):
            detector.on_train_epoch_end(train_state, mock_unit)

    def test_train_reset_counter_on_non_empty_epoch(self) -> None:
        """A non-empty epoch zeroes the streak; later empty epochs count from one."""
        detector = EmptyDataloaderDetectorCallback(threshold=3)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        # (steps completed in the epoch, expected streak after the epoch ends)
        scenario = (
            (0, 1),  # first empty epoch
            (0, 2),  # second empty epoch
            (5, 0),  # non-empty epoch resets the streak
            (0, 1),  # streak restarts at one
        )
        for steps, expected_streak in scenario:
            mock_unit.set_steps_completed_in_prev_epoch(steps)
            detector.on_train_epoch_end(train_state, mock_unit)
            self.assertEqual(
                detector._consecutive_empty_train_epochs, expected_streak
            )

    def test_threshold_one(self) -> None:
        """With threshold=1, the very first empty epoch raises."""
        detector = EmptyDataloaderDetectorCallback(threshold=1)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        mock_unit.set_steps_completed_in_prev_epoch(0)
        with self.assertRaisesRegex(RuntimeError, self._threshold_error(1)):
            detector.on_train_epoch_end(train_state, mock_unit)

    def test_high_threshold(self) -> None:
        """With threshold=5, four empty epochs pass and the fifth raises."""
        detector = EmptyDataloaderDetectorCallback(threshold=5)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        # Epochs 1-4 only grow the streak.
        for expected_streak in range(1, 5):
            mock_unit.set_steps_completed_in_prev_epoch(0)
            detector.on_train_epoch_end(train_state, mock_unit)
            self.assertEqual(
                detector._consecutive_empty_train_epochs, expected_streak
            )

        # Epoch 5 reaches the threshold.
        mock_unit.set_steps_completed_in_prev_epoch(0)
        with self.assertRaisesRegex(RuntimeError, self._threshold_error(5)):
            detector.on_train_epoch_end(train_state, mock_unit)

    def test_warning_logged_for_each_empty_epoch(self) -> None:
        """Two empty epochs below the threshold produce two warning calls."""
        detector = EmptyDataloaderDetectorCallback(threshold=3)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        with patch(self._LOGGER_TARGET) as fake_logger:
            # Two empty epochs in a row; threshold=3 keeps this from raising.
            for _ in range(2):
                mock_unit.set_steps_completed_in_prev_epoch(0)
                detector.on_train_epoch_end(train_state, mock_unit)

            self.assertEqual(fake_logger.warning.call_count, 2)
            # NOTE(review): any() only proves that at least one recorded call
            # carries the expected text, not every call.
            rendered_calls = [
                str(recorded) for recorded in fake_logger.warning.call_args_list
            ]
            self.assertTrue(
                any("Empty train epoch detected" in text for text in rendered_calls)
            )

    def test_non_empty_epochs_do_not_trigger_warnings(self) -> None:
        """Epochs that completed steps neither log a warning nor grow the streak."""
        detector = EmptyDataloaderDetectorCallback(threshold=2)
        train_state = get_dummy_train_state()
        mock_unit = MockTrainUnit()

        with patch(self._LOGGER_TARGET) as fake_logger:
            for steps in (1, 5, 10, 100):
                mock_unit.set_steps_completed_in_prev_epoch(steps)
                detector.on_train_epoch_end(train_state, mock_unit)

            fake_logger.warning.assert_not_called()
            self.assertEqual(detector._consecutive_empty_train_epochs, 0)

    def test_empty_dataloader_detection_with_real_training_loop(self) -> None:
        """
        Drive the callback through the real ``train`` entry point with a
        dataloader that yields nothing, mirroring the scenario from failed
        MAST job f762746046-pviolatingquery_cse.
        """

        class EmptyDataset(Dataset[Batch]):
            """Zero-length dataset, so the dataloader built on it yields no batches."""

            def __len__(self) -> int:
                return 0

            def __getitem__(self, idx: int) -> Batch:
                raise IndexError("Empty dataset")

        detector = EmptyDataloaderDetectorCallback(threshold=2)
        unit = DummyTrainUnit(input_dim=2)
        loader = DataLoader(EmptyDataset(), batch_size=1)

        # 50 epochs are requested, but the run is expected to abort as soon as
        # the second consecutive empty epoch ends.
        with self.assertRaisesRegex(RuntimeError, self._threshold_error(2)):
            train(
                unit,
                loader,
                max_epochs=50,
                callbacks=[detector],
            )

        self.assertEqual(detector._consecutive_empty_train_epochs, 2)
0 commit comments