Commit b5b0b03

diego-urgell authored and facebook-github-bot committed
Detect duplicated references to module instance in AutoUnit (#893)
Summary:
Pull Request resolved: #893

Reviewed By: JKSenthil

Differential Revision: D62053414

fbshipit-source-id: 93dd9d73807d12707561ed00318adf9a4cdf90af
1 parent 24e6af6 commit b5b0b03
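
For context, here is a minimal sketch of the anti-pattern this commit starts detecting, assuming torchtnt and torch are installed. BadUnit, its trivial loss, and its optimizer setup are illustrative only and are not part of the commit:

from typing import Any, Tuple

import torch
from torch import nn
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State


class BadUnit(AutoUnit):
    def __init__(self, module: nn.Module) -> None:
        super().__init__(module=module)
        # Anti-pattern: a second attribute pointing at the managed model.
        # With this commit, AutoUnit.__setattr__ intercepts this assignment
        # and logs an error naming "_module".
        self._module = self.module

    def compute_loss(self, state: State, data: Any) -> Tuple[torch.Tensor, Any]:
        outputs = self.module(data)
        return outputs.sum(), outputs

    def configure_optimizers_and_lr_scheduler(self, module: nn.Module):
        optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
        return optimizer, torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)


BadUnit(module=nn.Linear(2, 2))  # emits the duplicate-reference error log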

File tree

2 files changed: +131 -1 lines


tests/framework/test_auto_unit.py

Lines changed: 91 additions & 1 deletion
@@ -14,6 +14,7 @@
 import torch
 
 from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
+from torch import nn
 
 from torch.distributed import GradBucket
 from torchtnt.framework._test_utils import (
@@ -39,7 +40,7 @@
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import DDPStrategy
+from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
 from torchtnt.utils.progress import Progress
 from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
 from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -294,6 +295,95 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
         )
         self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)
 
+    @skip_if_not_distributed
+    def test_module_attr_duplicate_reference_validation(self) -> None:
+        spawn_multi_process(
+            2,
+            "gloo",
+            self._test_module_attr_duplicate_reference_validation,
+        )
+
+    @staticmethod
+    def _test_module_attr_duplicate_reference_validation() -> None:
+        error_msg = (
+            "Attribute '{name}' of the custom TNT Unit stores a reference to the model managed "
+            "by AutoUnit. This is known to cause errors on checkpointing and model training "
+            "mode. Please remove this attribute and access the existing `self.module` instead."
+        )
+
+        # Unit that stores the unwrapped module
+        class ChildUnit(AutoUnit):
+            def __init__(self, module, strategy):
+                super().__init__(module=module, strategy=strategy)
+                self._module = self.module.module if strategy else self.module
+
+            def compute_loss(
+                self, state: State, data: Batch
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                return torch.Tensor([1]), torch.Tensor([1])
+
+            def configure_optimizers_and_lr_scheduler(
+                self, module: torch.nn.Module
+            ) -> Tuple[torch.optim.Optimizer, TLRScheduler]:
+                return MagicMock(), MagicMock()
+
+        # Test with two levels of inheritance
+        class GrandchildUnit(DummyAutoUnit):
+            def __init__(self, module, strategy):
+                super().__init__(module=module, strategy=strategy)
+                self._module = module
+
+        # Test duplicated references to the module
+        test_cases = [
+            (DummyAutoUnit, None, False),
+            (ChildUnit, None, True),
+            (ChildUnit, FSDPStrategy(), True),
+            (ChildUnit, DDPStrategy(), True),
+            (GrandchildUnit, None, True),
+        ]
+        for unit_type, strategy, expect_error in test_cases:
+            module = nn.Linear(2, 2)
+            error_container = []
+            with patch(
+                "torchtnt.framework.auto_unit.logging.Logger.error",
+                side_effect=error_container.append,
+            ):
+                unit = unit_type(module=module, strategy=strategy)
+
+            tc = unittest.TestCase()
+            expected_errors = [error_msg.format(name="_module")] if expect_error else []
+            tc.assertEqual(error_container, expected_errors)
+            tc.assertIs(module, unit.module.module if strategy else unit.module)
+
+    def test_module_attr_reassignment_validation(self) -> None:
+        # Test reassignment of the module attribute
+        class ReassigningUnit1(DummyAutoUnit):
+            def __init__(self, module):
+                super().__init__(module=module)
+                self.module = module
+
+        class ReassigningUnit2(DummyAutoUnit):
+            def __init__(self, module):
+                super().__init__(module=module)
+                self.configure_model()
+
+            def configure_model(self):
+                self.module = torch.nn.Linear(3, 3)
+
+        for unit_type in (ReassigningUnit1, ReassigningUnit2):
+            module = nn.Linear(2, 2)
+            warning_container = []
+            with patch(
+                "torchtnt.framework.auto_unit.logging.Logger.warning",
+                side_effect=warning_container.append,
+            ):
+                unit_type(module=module)
+
+            expected_warnings = [
+                "The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
+            ]
+            self.assertEqual(warning_container, expected_warnings)
+
     @skip_if_not_distributed
     def test_auto_unit_ddp(self) -> None:
         """

torchtnt/framework/auto_unit.py

Lines changed: 40 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 
 import contextlib
+import logging
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy
 from dataclasses import dataclass
@@ -52,6 +53,8 @@
 from torchtnt.utils.swa import AveragedModel
 from typing_extensions import Literal
 
+_logger: logging.Logger = logging.getLogger(__name__)
+
 
 TData = TypeVar("TData")
 
@@ -550,6 +553,43 @@ def __init__(
         self.lr_scheduler: Optional[TLRScheduler] = None
         self.swa_scheduler: Optional[SWALR] = None
 
+    def __setattr__(self, name: str, value: object) -> None:
+        if isinstance(value, torch.nn.Module):
+            self._validate_module_attr(name, value)
+
+        super().__setattr__(name, value)
+
+    def _validate_module_attr(self, name: str, module: torch.nn.Module) -> None:
+        """
+        The AutoUnit is designed to manage the input model using the `self.module` attribute,
+        which should not be reassigned. Additionally, if a subclass saves another attribute
+        referencing the same model instance (wrapped or unwrapped), then the same instance will
+        appear twice in the tracked_modules. This is problematic for checkpointing and handling
+        of evaluation/training mode.
+        """
+        # The first time the module attribute is set is during the AutoUnit's initialization
+        if not hasattr(self, "module"):
+            return
+
+        # The value of self.module should not be changed after initialization
+        if name == "module":
+            _logger.warning(
+                "The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
+            )
+            return
+
+        # Otherwise, double-check that this is not a duplicate reference to the self.module instance
+        managed_modules = [self.module]
+        if isinstance(self.module, (DDP, FSDP)):
+            managed_modules.append(self.module.module)
+
+        if any(module is managed_module for managed_module in managed_modules):
+            _logger.error(
+                f"Attribute '{name}' of the custom TNT Unit stores a reference to the model managed "
+                "by AutoUnit. This is known to cause errors on checkpointing and model training "
+                "mode. Please remove this attribute and access the existing `self.module` instead."
+            )
+
     @abstractmethod
     def configure_optimizers_and_lr_scheduler(
         self, module: torch.nn.Module