Commit 24e6af6

JKSenthil authored and facebook-github-bot committed

QAT support in core loop (#892)

Summary: Pull Request resolved: #892

Reviewed By: ywwwer, diego-urgell

Differential Revision: D61935530

fbshipit-source-id: 6c85ffdccf3a1014441e4bdfc1b527769a44fae9

1 parent 1545b34 commit 24e6af6
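
Background (not part of the commit itself): a module produced by torch.export rejects the usual .train()/.eval() toggles, which is why the core loop needs a dedicated code path for QAT/exported models. A minimal sketch of the constraint, using the same toy model as the new unit test below:

import torch

# Toy floating-point model, mirroring the one in the new unit test.
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

# Exporting yields a graph module that no longer supports .train()/.eval().
module = torch.export.export(M(), (torch.rand(4, 4),)).module()

# module.eval()  # would raise: train/eval is not supported on exported models
# The supported toggles are the quantization move_* helpers used in this commit:
torch.ao.quantization.move_exported_model_to_eval(module)
torch.ao.quantization.move_exported_model_to_train(module)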

File tree

2 files changed: +67, -9 lines

tests/framework/test_loop_utils.py
torchtnt/framework/_loop_utils.py

tests/framework/test_loop_utils.py

Lines changed: 42 additions & 0 deletions

@@ -12,6 +12,7 @@
 
 import torch
 from torch import distributed as dist, nn
+from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.distributed import launcher
 
 from torch.utils.data import DataLoader

@@ -88,6 +89,47 @@ def test_set_module_training_mode(self) -> None:
         self.assertFalse(prior_module_train_states["module"])
         self.assertFalse(prior_module_train_states["loss_fn"])
 
+    def test_set_module_training_mode_qat(self) -> None:
+        """
+        Test _set_module_training_mode
+        """
+
+        # define a floating point model
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                x = self.fc(x)
+                return x
+
+        loss_fn = nn.CrossEntropyLoss()
+        module = torch.export.export(M(), (torch.rand(4, 4),)).module()
+
+        tracked_modules: Dict[str, torch.nn.Module] = {
+            "module": module,
+            "loss_fn": loss_fn,
+        }
+
+        self.assertTrue(model_is_exported(module))
+        prior_module_train_states = _set_module_training_mode(tracked_modules, False)
+
+        self.assertFalse(module.training)
+        self.assertFalse(loss_fn.training)
+
+        self.assertTrue(prior_module_train_states["module"])
+        self.assertTrue(prior_module_train_states["loss_fn"])
+
+        # set back to True
+        prior_module_train_states = _set_module_training_mode(tracked_modules, True)
+
+        self.assertTrue(module.training)
+        self.assertTrue(loss_fn.training)
+
+        self.assertFalse(prior_module_train_states["module"])
+        self.assertFalse(prior_module_train_states["loss_fn"])
+
     def test_reset_module_training_mode(self) -> None:
         """
         Test _reset_module_training_mode
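
For context, an exported module like the one built in this test is what the PT2E quantization-aware-training workflow produces. A hedged sketch of that surrounding flow, assuming the torch.ao.quantization.quantize_pt2e and XNNPACKQuantizer APIs (none of this is part of the commit):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):  # same toy model as the test above
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

example_inputs = (torch.rand(4, 4),)
exported = torch.export.export(M(), example_inputs).module()

# Insert fake-quantize observers for QAT.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(exported, quantizer)

# ... fine-tune `prepared` in the training loop; this commit lets the loop
# flip it between train and eval via the move_exported_model_to_* helpers ...

quantized = convert_pt2e(prepared)  # lower to the final quantized model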

torchtnt/framework/_loop_utils.py

Lines changed: 25 additions & 9 deletions

@@ -96,14 +96,15 @@ def _set_module_training_mode(
         if _EXPORT_UTILS_AVAIL and model_is_exported(
             module.module if is_ddp else module
         ):
-            if mode:
-                module = torch.ao.quantization.move_exported_model_to_train(
-                    module.module if is_ddp else module
-                )
-            else:
-                module = torch.ao.quantization.move_exported_model_to_eval(
-                    module.module if is_ddp else module
-                )
+            move_fn = (
+                torch.ao.quantization.move_exported_model_to_train
+                if mode
+                else torch.ao.quantization.move_exported_model_to_eval
+            )
+            move_fn(module.module if is_ddp else module)
+            module.training = mode
+            if is_ddp:
+                module.module.training = mode
         else:
             module.train(mode)
 

@@ -118,7 +119,22 @@ def _reset_module_training_mode(
     # returning back to the user
     for name, module in modules.items():
         if name in prior_modes:
-            module.train(prior_modes[name])
+            is_ddp = isinstance(module, DistributedDataParallel)
+
+            if _EXPORT_UTILS_AVAIL and model_is_exported(
+                module.module if is_ddp else module
+            ):
+                move_fn = (
+                    torch.ao.quantization.move_exported_model_to_train
+                    if prior_modes[name]
+                    else torch.ao.quantization.move_exported_model_to_eval
+                )
+                move_fn(module.module if is_ddp else module)
+                module.training = prior_modes[name]
+                if is_ddp:
+                    module.module.training = prior_modes[name]
+            else:
+                module.train(prior_modes[name])
 
 
 def _log_api_usage(entry_point: str) -> None:
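
Taken together, the two helpers now round-trip exported models symmetrically. A hedged usage sketch of how the pair is meant to be called around an eval pass (tracked_modules mirrors the dict from the unit test; both underscore-prefixed functions are internal to torchtnt.framework._loop_utils):

import torch
from torch import nn
from torchtnt.framework._loop_utils import (
    _reset_module_training_mode,
    _set_module_training_mode,
)

# Same shape as the unit test: one exported module, one ordinary module.
module = torch.export.export(nn.Linear(4, 4), (torch.rand(4, 4),)).module()
loss_fn = nn.CrossEntropyLoss()
tracked_modules = {"module": module, "loss_fn": loss_fn}

# Before an eval pass: the exported module goes through
# move_exported_model_to_eval, the plain one through .train(False);
# each module's prior mode is recorded and returned.
prior_modes = _set_module_training_mode(tracked_modules, False)

# ... run evaluation ...

# Afterwards every module is restored to its recorded mode, now via the
# same exported-model-aware branch instead of a bare module.train(...).
_reset_module_training_mode(tracked_modules, prior_modes)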
