Skip to content

Commit e6739ab

Browse files
JKSenthil and facebook-github-bot
authored and committed
fix ddp train mode (#794)
Summary: Pull Request resolved: #794. Since D53659696, train mode was only applied to the inner module of DDP. This diff fixes this.
Reviewed By: diego-urgell
Differential Revision: D56424247
fbshipit-source-id: 179e14180cdb8bbc08fde595220bb76f75a37c02
1 parent f02d654 commit e6739ab

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

torchtnt/framework/_loop_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,19 @@ def _set_module_training_mode(
6464
prior_module_train_states = {}
6565
for name, module in modules.items():
6666
prior_module_train_states[name] = module.training
67-
if isinstance(module, DistributedDataParallel):
68-
module = module.module
69-
if torch.ao.quantization.pt2e.export_utils.model_is_exported(module):
67+
is_ddp = isinstance(module, DistributedDataParallel)
68+
69+
if torch.ao.quantization.pt2e.export_utils.model_is_exported(
70+
module.module if is_ddp else module
71+
):
7072
if mode:
71-
module = torch.ao.quantization.move_exported_model_to_train(module)
73+
module = torch.ao.quantization.move_exported_model_to_train(
74+
module.module if is_ddp else module
75+
)
7276
else:
73-
module = torch.ao.quantization.move_exported_model_to_eval(module)
77+
module = torch.ao.quantization.move_exported_model_to_eval(
78+
module.module if is_ddp else module
79+
)
7480
else:
7581
module.train(mode)
7682

0 commit comments

Comments
 (0)