Commit a8692ba

yeyingxiao authored and facebook-github-bot committed
Fix pyre typing in train.py (meta-pytorch#1641)
Summary:
Pull Request resolved: meta-pytorch#1641

Fix type annotations in captum/_utils/models/linear_model/train.py and captum/testing/attr/helpers/test_config.py.

This diff addresses all `pyre-fixme` comments in train.py by adding proper Python type annotations and removing the fixme comments. Additionally, a type annotation was added in test_config.py to ensure OSS mypy checks pass.

Verified with internal Pyre checks and OSS mypy tests. The changes improve static type safety and align with Captum OSS development guidelines.

Reviewed By: craymichael

Differential Revision: D81624358

fbshipit-source-id: 613b4981474ceaea3fd84b1a73151cafc2aee501
1 parent aff7603 commit a8692ba

File tree

2 files changed: +41 additions, -33 deletions

captum/_utils/models/linear_model/train.py

Lines changed: 39 additions & 31 deletions
@@ -3,20 +3,21 @@
 import warnings
 from functools import reduce
 from types import ModuleType
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

 import torch
 import torch.nn as nn
 from captum._utils.models.linear_model.model import LinearModel
 from torch.utils.data import DataLoader


-# pyre-fixme[2]: Parameter must be annotated.
-def l2_loss(x1, x2, weights=None) -> torch.Tensor:
+def l2_loss(
+    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     if weights is None:
-        return torch.mean((x1 - x2) ** 2) / 2.0
+        return torch.mean(torch.pow(x1 - x2, 2)) / 2.0
     else:
-        return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
+        return torch.sum((weights / weights.norm(p=1)) * torch.pow(x1 - x2, 2)) / 2.0


 class ConvergenceTracker:
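
For reference, both branches of the newly annotated l2_loss can be exercised as below; the values are illustrative and this assumes captum is importable:

    import torch
    from captum._utils.models.linear_model.train import l2_loss

    pred = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.5, 2.0, 2.0])

    # Unweighted branch: mean(pow(pred - target, 2)) / 2
    print(l2_loss(pred, target))

    # Weighted branch: weights are normalized by their L1 norm before the sum
    weights = torch.tensor([1.0, 1.0, 2.0])
    print(l2_loss(pred, target, weights))
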
@@ -60,20 +61,19 @@ def average(self) -> torch.Tensor:


 def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
-    assert model.linear is not None
+    linear_layer = model.linear
+    assert linear_layer is not None
     if init_scheme is not None:
         assert init_scheme in ["xavier", "zeros"]

         with torch.no_grad():
             if init_scheme == "xavier":
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                torch.nn.init.xavier_uniform_(model.linear.weight)
+                torch.nn.init.xavier_uniform_(linear_layer.weight)
             else:
-                model.linear.weight.zero_()
+                linear_layer.weight.zero_()

-            # pyre-fixme[16]: `Optional` has no attribute `bias`.
-            if model.linear.bias is not None:
-                model.linear.bias.zero_()
+            if linear_layer.bias is not None:
+                linear_layer.bias.zero_()


 def _get_point(
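
Rebinding model.linear to a local before the assert is what lets the checker drop the Optional: pyre does not track narrowing through repeated attribute access (hence the old pyre-fixme[16] comments), but it does track a local binding. A minimal standalone sketch of the pattern, with a hypothetical Holder class standing in for LinearModel:

    from typing import Optional

    import torch.nn as nn

    class Holder:
        def __init__(self) -> None:
            self.linear: Optional[nn.Linear] = nn.Linear(4, 1)

    def init_weights(holder: Holder) -> None:
        # The local is narrowed once for the whole function body; a fresh
        # holder.linear.weight access would be typed Optional each time.
        layer = holder.linear
        assert layer is not None
        nn.init.xavier_uniform_(layer.weight)
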
@@ -103,8 +103,9 @@ def sgd_train_linear_model(
     reduce_lr: bool = True,
     initial_lr: float = 0.01,
     alpha: float = 1.0,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    loss_fn: Callable = l2_loss,
+    loss_fn: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor
+    ] = l2_loss,
     reg_term: Optional[int] = 1,
     patience: int = 10,
     threshold: float = 1e-4,
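
With Callable fully parameterized, any substitute loss must match the (x1, x2, optional weights) -> Tensor shape to type-check as a value for loss_fn. A hypothetical L1 variant, purely as an illustration:

    from typing import Optional

    import torch

    def l1_loss(
        x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Satisfies Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], so it
        # could be passed as sgd_train_linear_model(..., loss_fn=l1_loss).
        if weights is None:
            return torch.mean(torch.abs(x1 - x2))
        return torch.sum((weights / weights.norm(p=1)) * torch.abs(x1 - x2))
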
@@ -224,8 +225,8 @@ def sgd_train_linear_model(

             loss = loss_fn(y, out, w)
             if reg_term is not None:
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                reg = torch.norm(model.linear.weight, p=reg_term)  # type: ignore
+                assert model.linear is not None
+                reg = torch.norm(model.linear.weight, p=reg_term)
                 loss += reg.sum() * alpha

             loss_window.append(loss.clone().detach())
@@ -269,18 +270,19 @@ def sgd_train_linear_model(


 class NormLayer(nn.Module):
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None:
+    def __init__(
+        self,
+        mean: torch.Tensor,
+        std: torch.Tensor,
+        n: Optional[int] = None,
+        eps: float = 1e-8,
+    ) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.mean = mean
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.std = std
+        self.mean: torch.Tensor = mean
+        self.std: torch.Tensor = std
         self.eps = eps

-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (x - self.mean) / (self.std + self.eps)

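
The annotated NormLayer is used unchanged elsewhere in the file; a quick usage sketch with illustrative shapes:

    import torch
    from captum._utils.models.linear_model.train import NormLayer

    x = torch.randn(8, 4)
    norm = NormLayer(mean=x.mean(0), std=x.std(0))
    z = norm(x)  # computes (x - mean) / (std + eps)
    print(z.mean(0))  # approximately zero per feature
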
@@ -371,16 +373,23 @@ def sklearn_train_linear_model(
     else:
         w = None

+    mean, std = None, None
     if norm_input:
         mean, std = x.mean(0), x.std(0)
         x -= mean
         x /= std

     t1 = time.time()
-    # pyre-fixme[29]: `str` is not a function.
-    sklearn_model = reduce(  # type: ignore
-        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")  # type: ignore # noqa: E501
-    )(**construct_kwargs)
+    # Start with the sklearn module and navigate through the attribute path
+    sklearn_cls = cast(
+        Type[Any],
+        reduce(
+            lambda obj, attr: getattr(obj, attr),
+            sklearn_trainer.split("."),
+            sklearn,
+        ),
+    )
+    sklearn_model = sklearn_cls(**construct_kwargs)
     try:
         sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
     except TypeError:
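
The rewritten lookup seeds reduce with the sklearn module itself, so a dotted trainer string resolves to a class before the cast. The same idiom in isolation; the trainer name here is only an example:

    from functools import reduce

    import sklearn
    import sklearn.linear_model  # submodule must be imported before getattr can reach it

    sklearn_trainer = "linear_model.Lasso"
    cls = reduce(lambda obj, attr: getattr(obj, attr), sklearn_trainer.split("."), sklearn)
    model = cls(alpha=0.1)  # equivalent to sklearn.linear_model.Lasso(alpha=0.1)
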
@@ -417,8 +426,7 @@ def sklearn_train_linear_model(
     )

     if norm_input:
-        # pyre-fixme[61]: `mean` is undefined, or not always defined.
-        # pyre-fixme[61]: `std` is undefined, or not always defined.
+        assert mean is not None and std is not None
         model.norm = NormLayer(mean, std)

     return {"train_time": t2 - t1}
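
This hunk pairs with the mean, std = None, None initialization above: pre-initializing makes both names defined on every path (resolving the pyre-fixme[61] errors), and the assert narrows them back from Optional before use. The pattern in miniature, as a sketch:

    from typing import Optional

    import torch

    def normalize(x: torch.Tensor, norm_input: bool) -> torch.Tensor:
        mean: Optional[torch.Tensor] = None
        std: Optional[torch.Tensor] = None
        if norm_input:
            mean, std = x.mean(0), x.std(0)
        # ... other work ...
        if norm_input:
            # Without the pre-initialization, mean/std would be flagged as
            # possibly undefined here; the assert narrows Optional to Tensor.
            assert mean is not None and std is not None
            return (x - mean) / std
        return x
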

captum/testing/attr/helpers/test_config.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
+from typing import Any, Dict, List

 import torch
 from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
@@ -89,8 +90,7 @@
 # Set random seeds to ensure deterministic behavior
 set_all_random_seeds(1234)

-# pyre-fixme[5]: Global expression must be annotated.
-config = [
+config: List[Dict[str, Any]] = [
     # Attribution Method Configs
     # Primary Methods (Generic Configs)
     {
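
Under # pyre-strict, module-level globals must carry explicit annotations, which is why config gets List[Dict[str, Any]]. The same requirement in miniature; the keys shown are illustrative, not the real test configs:

    from typing import Any, Dict, List

    # Without the annotation, pyre-strict reports:
    # "Global expression must be annotated." (pyre-fixme[5])
    config: List[Dict[str, Any]] = [
        {"name": "example_config", "algorithms": [], "attribute_args": {}},
    ]
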
