10
10
11
11
import torch
12
12
13
- # TODO: torch/optim/swa_utils.pyi needs to be updated
14
- # pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
15
13
from torch .optim .swa_utils import (
16
14
AveragedModel as PyTorchAveragedModel ,
17
15
get_ema_multi_avg_fn ,
@@ -56,12 +54,8 @@ def __init__(
56
54
if ema_decay < 0.0 or ema_decay > 1.0 :
57
55
raise ValueError (f"Decay must be between 0 and 1, got { ema_decay } " )
58
56
59
- # TODO: torch/optim/swa_utils.pyi needs to be updated
60
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
61
57
multi_avg_fn = get_ema_multi_avg_fn (ema_decay )
62
58
elif averaging_method == "swa" :
63
- # TODO: torch/optim/swa_utils.pyi needs to be updated
64
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
65
59
multi_avg_fn = get_swa_multi_avg_fn ()
66
60
67
61
if use_lit :
@@ -89,8 +83,6 @@ def __init__(
89
83
else :
90
84
# use default init implementation
91
85
92
- # TODO: torch/optim/swa_utils.pyi needs to be updated
93
- # pyre-ignore Unexpected keyword [28]
94
86
super ().__init__ (
95
87
model ,
96
88
device = device ,
@@ -105,7 +97,5 @@ def update_parameters(self, model: torch.nn.Module) -> None:
105
97
self ._ema_decay , (1 + self ._num_updates ) / (10 + self ._num_updates )
106
98
)
107
99
108
- # TODO: torch/optim/swa_utils.pyi needs to be updated
109
- # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
110
100
self .multi_avg_fn = get_ema_multi_avg_fn (decay )
111
101
super ().update_parameters (model )
0 commit comments