Skip to content

Commit aff7603

Browse files
yangaaapaulfacebook-github-bot
authored andcommitted
PYRE_FIXME in classification_models.py (meta-pytorch#1636)
Summary: Pull Request resolved: meta-pytorch#1636 Fixed all pyre-fixmes in the file Reviewed By: jjuncho Differential Revision: D80271936 fbshipit-source-id: d7044cc3f76c810cf3dbb5b0dd257e7e22f81219
1 parent 19971fb commit aff7603

File tree

1 file changed

+10
-32
lines changed

1 file changed

+10
-32
lines changed

captum/testing/helpers/classification_models.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,17 @@ class SigmoidModel(nn.Module):
1212
-pytorch-and-make-your-life-simpler-ec5367895199
1313
"""
1414

15-
# pyre-fixme[2]: Parameter must be annotated.
16-
def __init__(self, num_in, num_hidden, num_out) -> None:
15+
def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
1716
super().__init__()
18-
# pyre-fixme[4]: Attribute must be annotated.
1917
self.num_in = num_in
20-
# pyre-fixme[4]: Attribute must be annotated.
2118
self.num_hidden = num_hidden
22-
# pyre-fixme[4]: Attribute must be annotated.
2319
self.num_out = num_out
2420
self.lin1 = nn.Linear(num_in, num_hidden)
2521
self.lin2 = nn.Linear(num_hidden, num_out)
2622
self.relu1 = nn.ReLU()
2723
self.sigmoid = nn.Sigmoid()
2824

29-
# pyre-fixme[3]: Return type must be annotated.
30-
# pyre-fixme[2]: Parameter must be annotated.
31-
def forward(self, input):
25+
def forward(self, input: torch.Tensor) -> torch.Tensor:
3226
lin1 = self.lin1(input)
3327
lin2 = self.lin2(self.relu1(lin1))
3428
return self.sigmoid(lin2)
@@ -40,14 +34,12 @@ class SoftmaxModel(nn.Module):
4034
https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
4135
"""
4236

43-
# pyre-fixme[2]: Parameter must be annotated.
44-
def __init__(self, num_in, num_hidden, num_out, inplace: bool = False) -> None:
37+
def __init__(
38+
self, num_in: int, num_hidden: int, num_out: int, inplace: bool = False
39+
) -> None:
4540
super().__init__()
46-
# pyre-fixme[4]: Attribute must be annotated.
4741
self.num_in = num_in
48-
# pyre-fixme[4]: Attribute must be annotated.
4942
self.num_hidden = num_hidden
50-
# pyre-fixme[4]: Attribute must be annotated.
5143
self.num_out = num_out
5244
self.lin1 = nn.Linear(num_in, num_hidden)
5345
self.lin2 = nn.Linear(num_hidden, num_hidden)
@@ -56,9 +48,7 @@ def __init__(self, num_in, num_hidden, num_out, inplace: bool = False) -> None:
5648
self.relu2 = nn.ReLU(inplace=inplace)
5749
self.softmax = nn.Softmax(dim=1)
5850

59-
# pyre-fixme[3]: Return type must be annotated.
60-
# pyre-fixme[2]: Parameter must be annotated.
61-
def forward(self, input):
51+
def forward(self, input: torch.Tensor) -> torch.Tensor:
6252
lin1 = self.relu1(self.lin1(input))
6353
lin2 = self.relu2(self.lin2(lin1))
6454
lin3 = self.lin3(lin2)
@@ -72,14 +62,10 @@ class SigmoidDeepLiftModel(nn.Module):
7262
-pytorch-and-make-your-life-simpler-ec5367895199
7363
"""
7464

75-
# pyre-fixme[2]: Parameter must be annotated.
76-
def __init__(self, num_in, num_hidden, num_out) -> None:
65+
def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
7766
super().__init__()
78-
# pyre-fixme[4]: Attribute must be annotated.
7967
self.num_in = num_in
80-
# pyre-fixme[4]: Attribute must be annotated.
8168
self.num_hidden = num_hidden
82-
# pyre-fixme[4]: Attribute must be annotated.
8369
self.num_out = num_out
8470
self.lin1 = nn.Linear(num_in, num_hidden, bias=False)
8571
self.lin2 = nn.Linear(num_hidden, num_out, bias=False)
@@ -88,9 +74,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
8874
self.relu1 = nn.ReLU()
8975
self.sigmoid = nn.Sigmoid()
9076

91-
# pyre-fixme[3]: Return type must be annotated.
92-
# pyre-fixme[2]: Parameter must be annotated.
93-
def forward(self, input):
77+
def forward(self, input: torch.Tensor) -> torch.Tensor:
9478
lin1 = self.lin1(input)
9579
lin2 = self.lin2(self.relu1(lin1))
9680
return self.sigmoid(lin2)
@@ -102,14 +86,10 @@ class SoftmaxDeepLiftModel(nn.Module):
10286
https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
10387
"""
10488

105-
# pyre-fixme[2]: Parameter must be annotated.
106-
def __init__(self, num_in, num_hidden, num_out) -> None:
89+
def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
10790
super().__init__()
108-
# pyre-fixme[4]: Attribute must be annotated.
10991
self.num_in = num_in
110-
# pyre-fixme[4]: Attribute must be annotated.
11192
self.num_hidden = num_hidden
112-
# pyre-fixme[4]: Attribute must be annotated.
11393
self.num_out = num_out
11494
self.lin1 = nn.Linear(num_in, num_hidden)
11595
self.lin2 = nn.Linear(num_hidden, num_hidden)
@@ -121,9 +101,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
121101
self.relu2 = nn.ReLU()
122102
self.softmax = nn.Softmax(dim=1)
123103

124-
# pyre-fixme[3]: Return type must be annotated.
125-
# pyre-fixme[2]: Parameter must be annotated.
126-
def forward(self, input):
104+
def forward(self, input: torch.Tensor) -> torch.Tensor:
127105
lin1 = self.relu1(self.lin1(input))
128106
lin2 = self.relu2(self.lin2(lin1))
129107
lin3 = self.lin3(lin2)

0 commit comments

Comments
 (0)