Skip to content

Commit 5beb537

Browse files
galrotemfacebook-github-bot
authored andcommitted
fix pyre errors (#784)
Summary: Pull Request resolved: #784 Reviewed By: anshulverma Differential Revision: D56144363 fbshipit-source-id: 052efee08a0528416477b132990c2453450d2f35
1 parent 05d1458 commit 5beb537

File tree

1 file changed

+0
-10
lines changed

1 file changed

+0
-10
lines changed

torchtnt/utils/swa.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
import torch
1212

13-
# TODO: torch/optim/swa_utils.pyi needs to be updated
14-
# pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
1513
from torch.optim.swa_utils import (
1614
AveragedModel as PyTorchAveragedModel,
1715
get_ema_multi_avg_fn,
@@ -56,12 +54,8 @@ def __init__(
5654
if ema_decay < 0.0 or ema_decay > 1.0:
5755
raise ValueError(f"Decay must be between 0 and 1, got {ema_decay}")
5856

59-
# TODO: torch/optim/swa_utils.pyi needs to be updated
60-
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
6157
multi_avg_fn = get_ema_multi_avg_fn(ema_decay)
6258
elif averaging_method == "swa":
63-
# TODO: torch/optim/swa_utils.pyi needs to be updated
64-
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
6559
multi_avg_fn = get_swa_multi_avg_fn()
6660

6761
if use_lit:
@@ -89,8 +83,6 @@ def __init__(
8983
else:
9084
# use default init implementation
9185

92-
# TODO: torch/optim/swa_utils.pyi needs to be updated
93-
# pyre-ignore Unexpected keyword [28]
9486
super().__init__(
9587
model,
9688
device=device,
@@ -105,7 +97,5 @@ def update_parameters(self, model: torch.nn.Module) -> None:
10597
self._ema_decay, (1 + self._num_updates) / (10 + self._num_updates)
10698
)
10799

108-
# TODO: torch/optim/swa_utils.pyi needs to be updated
109-
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
110100
self.multi_avg_fn = get_ema_multi_avg_fn(decay)
111101
super().update_parameters(model)

0 commit comments

Comments
 (0)