|
14 | 14 | import torch
15 | 15 |
16 | 16 | from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
   | 17 | +from torch import nn
17 | 18 |
18 | 19 | from torch.distributed import GradBucket
19 | 20 | from torchtnt.framework._test_utils import (

39 | 40 | from torchtnt.utils.distributed import spawn_multi_process
40 | 41 | from torchtnt.utils.env import init_from_env
41 | 42 | from torchtnt.utils.lr_scheduler import TLRScheduler
42 |    | -from torchtnt.utils.prepare_module import DDPStrategy
   | 43 | +from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
43 | 44 | from torchtnt.utils.progress import Progress
44 | 45 | from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
45 | 46 | from torchtnt.utils.test_utils import skip_if_not_distributed
|
@@ -294,6 +295,95 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
294 | 295 |         )
295 | 296 |         self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)
296 | 297 |
   | 298 | +    @skip_if_not_distributed
   | 299 | +    def test_module_attr_duplicate_reference_validation(self) -> None:
   | 300 | +        spawn_multi_process(
   | 301 | +            2,
   | 302 | +            "gloo",
   | 303 | +            self._test_module_attr_duplicate_reference_validation,
   | 304 | +        )
   | 305 | +
   | 306 | +    @staticmethod
   | 307 | +    def _test_module_attr_duplicate_reference_validation() -> None:
   | 308 | +        error_msg = (
   | 309 | +            "Attribute '{name}' of the custom TNT Unit stores a reference to the model managed"
   | 310 | +            "by AutoUnit. This is known to cause errors on checkpointing and model training "
   | 311 | +            "mode. Please remove this attribute and access the existing `self.module` instead."
   | 312 | +        )
   | 313 | +
   | 314 | +        # Unit that stores unwrapped module
   | 315 | +        class ChildUnit(AutoUnit):
   | 316 | +            def __init__(self, module, strategy):
   | 317 | +                super().__init__(module=module, strategy=strategy)
   | 318 | +                self._module = self.module.module if strategy else self.module
   | 319 | +
   | 320 | +            def compute_loss(
   | 321 | +                self, state: State, data: Batch
   | 322 | +            ) -> Tuple[torch.Tensor, torch.Tensor]:
   | 323 | +                return torch.Tensor([1]), torch.Tensor([1])
   | 324 | +
   | 325 | +            def configure_optimizers_and_lr_scheduler(
   | 326 | +                self, module: torch.nn.Module
   | 327 | +            ) -> Tuple[torch.optim.Optimizer, TLRScheduler]:
   | 328 | +                return MagicMock(), MagicMock()
   | 329 | +
   | 330 | +        # Test with two levels of inheritance
   | 331 | +        class GrandchildUnit(DummyAutoUnit):
   | 332 | +            def __init__(self, module, strategy):
   | 333 | +                super().__init__(module=module, strategy=strategy)
   | 334 | +                self._module = module
   | 335 | +
   | 336 | +        # Test duplicated references to module
   | 337 | +        test_cases = [
   | 338 | +            (DummyAutoUnit, None, False),
   | 339 | +            (ChildUnit, None, True),
   | 340 | +            (ChildUnit, FSDPStrategy(), True),
   | 341 | +            (ChildUnit, DDPStrategy(), True),
   | 342 | +            (GrandchildUnit, None, True),
   | 343 | +        ]
   | 344 | +        for unit_type, strategy, expect_error in test_cases:
   | 345 | +            module = nn.Linear(2, 2)
   | 346 | +            error_container = []
   | 347 | +            with patch(
   | 348 | +                "torchtnt.framework.auto_unit.logging.Logger.error",
   | 349 | +                side_effect=error_container.append,
   | 350 | +            ):
   | 351 | +                unit = unit_type(module=module, strategy=strategy)
   | 352 | +
   | 353 | +            tc = unittest.TestCase()
   | 354 | +            expected_errors = [error_msg.format(name="_module")] if expect_error else []
   | 355 | +            tc.assertEqual(error_container, expected_errors)
   | 356 | +            tc.assertIs(module, unit.module.module if strategy else unit.module)
   | 357 | +
   | 358 | +    def test_module_attr_reassignment_validation(self) -> None:
   | 359 | +        # Test reassignment of module attribute
   | 360 | +        class ReassigningUnit1(DummyAutoUnit):
   | 361 | +            def __init__(self, module):
   | 362 | +                super().__init__(module=module)
   | 363 | +                self.module = module
   | 364 | +
   | 365 | +        class ReassigningUnit2(DummyAutoUnit):
   | 366 | +            def __init__(self, module):
   | 367 | +                super().__init__(module=module)
   | 368 | +                self.configure_model()
   | 369 | +
   | 370 | +            def configure_model(self):
   | 371 | +                self.module = torch.nn.Linear(3, 3)
   | 372 | +
   | 373 | +        for unit_type in (ReassigningUnit1, ReassigningUnit2):
   | 374 | +            module = nn.Linear(2, 2)
   | 375 | +            warning_container = []
   | 376 | +            with patch(
   | 377 | +                "torchtnt.framework.auto_unit.logging.Logger.warning",
   | 378 | +                side_effect=warning_container.append,
   | 379 | +            ):
   | 380 | +                unit_type(module=module)
   | 381 | +
   | 382 | +            expected_warnings = [
   | 383 | +                "The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
   | 384 | +            ]
   | 385 | +            self.assertEqual(warning_container, expected_warnings)
   | 386 | +
297 | 387 |     @skip_if_not_distributed
298 | 388 |     def test_auto_unit_ddp(self) -> None:
299 | 389 |         """
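For context on what the new tests exercise, here is a minimal sketch, written for this review and not taken from torchtnt's actual AutoUnit code, of how a base unit could detect both conditions in `__setattr__`: warn when a subclass reassigns the managed `module` attribute, and log an error when another attribute aliases the managed module (or the module wrapped inside a DDP/FSDP container). The class name and message wording below are illustrative; the tests above assert against AutoUnit's real log messages.

```python
import logging

from torch import nn

logger = logging.getLogger(__name__)


class _AutoUnitSketch:
    """Illustrative stand-in for an AutoUnit-style base class (assumption, not torchtnt)."""

    def __init__(self, module: nn.Module) -> None:
        # The managed (possibly DDP/FSDP-wrapped) module.
        self.module = module

    def __setattr__(self, name: str, value: object) -> None:
        managed = self.__dict__.get("module")
        if name == "module" and managed is not None:
            # Reassignment of the managed attribute: warn, as the reassignment test expects.
            logger.warning("self.module is managed by the unit and should not be reassigned.")
        elif managed is not None and (
            value is managed or value is getattr(managed, "module", None)
        ):
            # A second attribute aliasing the managed module (or the module inside a
            # DDP/FSDP wrapper): log an error, as the duplicate-reference test expects.
            logger.error(
                f"Attribute '{name}' duplicates the managed module; use self.module instead."
            )
        super().__setattr__(value=value, name=name) if False else super().__setattr__(name, value)


# Example: storing a second reference triggers the error path, reassignment the warning path.
unit = _AutoUnitSketch(nn.Linear(2, 2))
unit._module = unit.module      # logs the duplicate-reference error
unit.module = nn.Linear(3, 3)   # logs the reassignment warning
```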