
Commit dc5bafd

bluenote10 authored and facebook-github-bot committed
Fix type-safety of torch.nn.Module instances
Summary:
X-link: meta-recsys/generative-recommenders#129
X-link: pytorch/FBGEMM#3387
X-link: facebookresearch/FBGEMM#476
X-link: pytorch/torchrec#2562

As laid out in pytorch/pytorch#81462 (comment), the change in pytorch/pytorch#104321 was not necessary and largely destroys the type-safety of `torch.nn.Module` instances. As far as I can see, the underlying issue of pytorch/pytorch#81462 in `torch.nn.parallel.DistributedDataParallel` has been fixed in the meantime by actually typing `register_comm_hook` correctly. The proper solution to issues like pytorch/pytorch#81462 is to give the underlying field/method a proper type annotation; then there should be no need to resort to a "type system disabling `__getattr__`". (I'll probably be offline for a while, not able to react here...)

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 XilunWu rec mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse tianyu-l kiukchung lucasllc

Original PR: pytorch/pytorch#115074
Updated testing PR: pytorch/pytorch#141240

Reviewed By: malfet, aorenste, gineshidalgo99, larryliu0820

Differential Revision: D52890934

Pulled By: ezyang

fbshipit-source-id: 23af4111a80b471d810e0bf828f4d49a19b4ba80
1 parent 14ebfea commit dc5bafd
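For context, a minimal illustrative sketch (not this commit's code and not PyTorch's actual source) of the two annotation styles the summary contrasts: a `__getattr__` returning `Any` silences the type checker on every dynamic attribute access, while the narrower `Union[Tensor, Module]` return type forces call sites to narrow the result or carry a suppression, which is where the `pyre-fixme` comments in the hunks below come from.

from typing import Any, Union

import torch


class LooseModule(torch.nn.Module):
    # "Type system disabling" style: every dynamic attribute access
    # type-checks as Any, so mistakes go unnoticed.
    def __getattr__(self, name: str) -> Any:
        return super().__getattr__(name)


class StrictModule(torch.nn.Module):
    # Narrower style: call sites see Union[Tensor, Module] and must
    # narrow it (or suppress the error) before using the attribute.
    def __getattr__(self, name: str) -> Union[torch.Tensor, torch.nn.Module]:
        return super().__getattr__(name)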

File tree

4 files changed: +26 -4 lines changed


tests/utils/test_distributed.py

Lines changed: 8 additions & 2 deletions
@@ -187,12 +187,18 @@ def test_revert_sync_batchnorm(self) -> None:
         self.assertNotIsInstance(batch_norm, torch.nn.SyncBatchNorm)
         self.assertTrue(
             torch.equal(
-                batch_norm.running_mean, none_throws(original_batchnorm.running_mean)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                # `Union[Tensor, Module]`.
+                batch_norm.running_mean,
+                none_throws(original_batchnorm.running_mean),
             )
         )
         self.assertTrue(
             torch.equal(
-                batch_norm.running_var, none_throws(original_batchnorm.running_var)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                # `Union[Tensor, Module]`.
+                batch_norm.running_var,
+                none_throws(original_batchnorm.running_var),
             )
         )
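As an aside, an explicit cast is one way a call site could narrow such a buffer back to `Tensor` instead of carrying a suppression comment; a small sketch under that assumption (illustrative only, not code from this repository):

from typing import cast

import torch

# Through the base `Module` type, `running_mean` resolves via __getattr__
# and is therefore seen as Union[Tensor, Module] by the checker.
batch_norm: torch.nn.Module = torch.nn.BatchNorm1d(4)
running_mean = cast(torch.Tensor, batch_norm.running_mean)
assert torch.equal(running_mean, torch.zeros(4))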

torchtnt/framework/_loop_utils.py

Lines changed: 16 additions & 2 deletions
@@ -94,16 +94,23 @@ def _set_module_training_mode(
     is_ddp = isinstance(module, DistributedDataParallel)
 
     if _EXPORT_UTILS_AVAIL and model_is_exported(
-        module.module if is_ddp else module
+        # pyre-fixme[6]: For 1st argument expected `Module` but got
+        # `Union[Module, Tensor]`.
+        module.module
+        if is_ddp
+        else module
     ):
         move_fn = (
             torch.ao.quantization.move_exported_model_to_train
             if mode
             else torch.ao.quantization.move_exported_model_to_eval
         )
+        # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+        # `Union[Module, Tensor]`.
         move_fn(module.module if is_ddp else module)
     module.training = mode
     if is_ddp:
+        # pyre-fixme[16]: `Tensor` has no attribute `training`.
         module.module.training = mode
     else:
         module.train(mode)
@@ -122,16 +129,23 @@ def _reset_module_training_mode(
     is_ddp = isinstance(module, DistributedDataParallel)
 
     if _EXPORT_UTILS_AVAIL and model_is_exported(
-        module.module if is_ddp else module
+        # pyre-fixme[6]: For 1st argument expected `Module` but got
+        # `Union[Module, Tensor]`.
+        module.module
+        if is_ddp
+        else module
     ):
         move_fn = (
             torch.ao.quantization.move_exported_model_to_train
             if prior_modes[name]
             else torch.ao.quantization.move_exported_model_to_eval
         )
+        # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+        # `Union[Module, Tensor]`.
         move_fn(module.module if is_ddp else module)
     module.training = prior_modes[name]
     if is_ddp:
+        # pyre-fixme[16]: `Tensor` has no attribute `training`.
         module.module.training = prior_modes[name]
     else:
         module.train(prior_modes[name])
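A hypothetical alternative to the suppressions above (a sketch, not the code in this file): unwrap the DDP module behind an explicit `isinstance` check and a cast, since `.module` also resolves through `Module.__getattr__` and is seen as `Union[Tensor, Module]`.

from typing import cast

import torch
from torch.nn.parallel import DistributedDataParallel


def _unwrap_ddp(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: return the inner module for DDP-wrapped models so
    # callers that expect a plain nn.Module receive one.
    if isinstance(module, DistributedDataParallel):
        # `.module` resolves through Module.__getattr__, so the checker sees
        # Union[Tensor, Module]; cast() records the intended type explicitly.
        return cast(torch.nn.Module, module.module)
    return module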

torchtnt/framework/auto_unit.py

Lines changed: 1 addition & 0 deletions
@@ -638,6 +638,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
             and (isinstance(module, DDP) or _is_fsdp_module(module))

torchtnt/utils/distributed.py

Lines changed: 1 addition & 0 deletions
@@ -436,6 +436,7 @@ def revert_sync_batchnorm(
         module_output.running_var = module.running_var
         module_output.num_batches_tracked = module.num_batches_tracked
         if hasattr(module, "qconfig"):
+            # pyre-fixme[16]: `_BatchNormXd` has no attribute `qconfig`.
             module_output.qconfig = module.qconfig
     for name, child in module.named_children():
         module_output.add_module(name, revert_sync_batchnorm(child, device))
