@@ -3,20 +3,21 @@
 import warnings
 from functools import reduce
 from types import ModuleType
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
 from captum._utils.models.linear_model.model import LinearModel
 from torch.utils.data import DataLoader
 
 
-# pyre-fixme[2]: Parameter must be annotated.
-def l2_loss(x1, x2, weights=None) -> torch.Tensor:
+def l2_loss(
+    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     if weights is None:
-        return torch.mean((x1 - x2) ** 2) / 2.0
+        return torch.mean(torch.pow(x1 - x2, 2)) / 2.0
     else:
-        return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
+        return torch.sum((weights / weights.norm(p=1)) * torch.pow(x1 - x2, 2)) / 2.0
 
 
 class ConvergenceTracker:
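The weighted branch divides `weights` by its L1 norm, so the result is a convex combination of per-element squared errors; the unweighted branch is half the mean squared error. A standalone sanity check of the annotated signature (the tensors below are illustrative, not from the test suite):

```python
from typing import Optional

import torch


def l2_loss(
    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if weights is None:
        return torch.mean(torch.pow(x1 - x2, 2)) / 2.0
    return torch.sum((weights / weights.norm(p=1)) * torch.pow(x1 - x2, 2)) / 2.0


x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1.0, 0.0, 0.0])
print(l2_loss(x1, x2))  # mean([0, 4, 9]) / 2 = 13/6 ≈ 2.167
print(l2_loss(x1, x2, torch.tensor([0.0, 1.0, 1.0])))  # (0.5*4 + 0.5*9) / 2 = 3.25
```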
@@ -60,20 +61,19 @@ def average(self) -> torch.Tensor:
 
 
 def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
-    assert model.linear is not None
+    linear_layer = model.linear
+    assert linear_layer is not None
     if init_scheme is not None:
         assert init_scheme in ["xavier", "zeros"]
 
         with torch.no_grad():
             if init_scheme == "xavier":
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                torch.nn.init.xavier_uniform_(model.linear.weight)
+                torch.nn.init.xavier_uniform_(linear_layer.weight)
             else:
-                model.linear.weight.zero_()
+                linear_layer.weight.zero_()
 
-            # pyre-fixme[16]: `Optional` has no attribute `bias`.
-            if model.linear.bias is not None:
-                model.linear.bias.zero_()
+            if linear_layer.bias is not None:
+                linear_layer.bias.zero_()
 
 
 def _get_point(
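Hoisting `model.linear` into a local is what lets the two `pyre-fixme[16]` suppressions go: the type checker narrows the local after the `assert`, whereas each fresh read of `model.linear` is `Optional` again. A minimal sketch of the pattern with a stand-in model (not captum's `LinearModel`):

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear: Optional[nn.Linear] = nn.Linear(4, 1)


def init_zeros(model: ToyModel) -> None:
    # The local binding is narrowed to nn.Linear by the assert; re-reading
    # model.linear on every use would stay Optional and re-trigger the error.
    linear_layer = model.linear
    assert linear_layer is not None
    with torch.no_grad():
        linear_layer.weight.zero_()
        if linear_layer.bias is not None:
            linear_layer.bias.zero_()


init_zeros(ToyModel())
```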
@@ -103,8 +103,9 @@ def sgd_train_linear_model(
     reduce_lr: bool = True,
     initial_lr: float = 0.01,
     alpha: float = 1.0,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    loss_fn: Callable = l2_loss,
+    loss_fn: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor
+    ] = l2_loss,
     reg_term: Optional[int] = 1,
     patience: int = 10,
     threshold: float = 1e-4,
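The explicit `Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]` pins down the contract a custom `loss_fn` must satisfy. For instance, a hypothetical L1 loss would type-check as long as it takes the same three arguments:

```python
from typing import Callable, Optional

import torch

LossFn = Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]


def l1_loss(
    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Illustrative only: mean absolute error, ignoring weights for brevity.
    return torch.mean(torch.abs(x1 - x2))


loss_fn: LossFn = l1_loss  # satisfies the new annotation
print(loss_fn(torch.ones(3), torch.zeros(3), None))  # tensor(1.)
```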
@@ -224,8 +225,8 @@ def sgd_train_linear_model(
 
             loss = loss_fn(y, out, w)
             if reg_term is not None:
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                reg = torch.norm(model.linear.weight, p=reg_term)  # type: ignore
+                assert model.linear is not None
+                reg = torch.norm(model.linear.weight, p=reg_term)
                 loss += reg.sum() * alpha
 
             loss_window.append(loss.clone().detach())
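Here `reg_term` is the order `p` of the norm taken over the linear layer's weights, and `alpha` scales the penalty; `p=1` gives lasso-style, `p=2` ridge-style regularization. A self-contained sketch of the penalty arithmetic (the data loss value is a placeholder):

```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 1)
data_loss = torch.tensor(0.5)  # placeholder for loss_fn(y, out, w)
alpha, reg_term = 1.0, 1

# Same arithmetic as the hunk above: a p-norm over the weights, scaled by alpha.
reg = torch.norm(linear.weight, p=reg_term)
loss = data_loss + reg.sum() * alpha
print(loss)
```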
@@ -269,18 +270,19 @@ def sgd_train_linear_model(
 
 
 class NormLayer(nn.Module):
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None:
+    def __init__(
+        self,
+        mean: torch.Tensor,
+        std: torch.Tensor,
+        n: Optional[int] = None,
+        eps: float = 1e-8,
+    ) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.mean = mean
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.std = std
+        self.mean: torch.Tensor = mean
+        self.std: torch.Tensor = std
         self.eps = eps
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (x - self.mean) / (self.std + self.eps)
 
 
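`NormLayer` standardizes its input with precomputed statistics, with `eps` guarding against division by zero when a feature has zero standard deviation. A quick usage sketch with a local copy of the class (the unused `n` parameter is dropped here for brevity):

```python
import torch
import torch.nn as nn


class NormLayer(nn.Module):
    # Local copy of the annotated class above, minus the unused `n` parameter.
    def __init__(
        self, mean: torch.Tensor, std: torch.Tensor, eps: float = 1e-8
    ) -> None:
        super().__init__()
        self.mean: torch.Tensor = mean
        self.std: torch.Tensor = std
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / (self.std + self.eps)


x = torch.randn(100, 5)
layer = NormLayer(x.mean(0), x.std(0))
z = layer(x)
print(z.mean(0), z.std(0))  # per-feature means ≈ 0, stds ≈ 1
```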
@@ -371,16 +373,23 @@ def sklearn_train_linear_model(
     else:
         w = None
 
+    mean, std = None, None
     if norm_input:
         mean, std = x.mean(0), x.std(0)
         x -= mean
         x /= std
 
     t1 = time.time()
-    # pyre-fixme[29]: `str` is not a function.
-    sklearn_model = reduce(  # type: ignore
-        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")  # type: ignore # noqa: E501
-    )(**construct_kwargs)
+    # Start with the sklearn module and navigate through the attribute path
+    sklearn_cls = cast(
+        Type[Any],
+        reduce(
+            lambda obj, attr: getattr(obj, attr),
+            sklearn_trainer.split("."),
+            sklearn,
+        ),
+    )
+    sklearn_model = sklearn_cls(**construct_kwargs)
     try:
         sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
     except TypeError:
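The rewritten lookup uses `reduce` with an initializer: the fold starts at the `sklearn` module and walks one attribute per dotted path segment, which is what the old `[sklearn] + sklearn_trainer.split(".")` concatenation emulated, and the `cast` tells pyre the result is a class. A standalone sketch of the same resolution, assuming scikit-learn is installed:

```python
from functools import reduce
from typing import Any, Type, cast

import sklearn
import sklearn.linear_model  # submodules are not imported automatically

# Walk "linear_model.Lasso" one attribute at a time, seeded with the module.
path = "linear_model.Lasso"
sklearn_cls = cast(Type[Any], reduce(getattr, path.split("."), sklearn))
model = sklearn_cls(alpha=0.1)
print(type(model).__name__)  # Lasso
```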
@@ -417,8 +426,7 @@ def sklearn_train_linear_model(
     )
 
     if norm_input:
-        # pyre-fixme[61]: `mean` is undefined, or not always defined.
-        # pyre-fixme[61]: `std` is undefined, or not always defined.
+        assert mean is not None and std is not None
         model.norm = NormLayer(mean, std)
 
     return {"train_time": t2 - t1}
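The `assert` is safe because `mean` and `std` are now pre-seeded with `None` (see the earlier hunk) and only consumed under the same `norm_input` guard that assigns them; the assert narrows the `Optional` away where `pyre-fixme[61]` used to be suppressed. A minimal sketch of the pattern:

```python
from typing import Optional

import torch

x = torch.randn(8, 2)
norm_input = True

# Pre-seeding keeps both names bound on every path (no more pyre-fixme[61]).
mean: Optional[torch.Tensor] = None
std: Optional[torch.Tensor] = None
if norm_input:
    mean, std = x.mean(0), x.std(0)

if norm_input:
    assert mean is not None and std is not None  # same guard as the assignment
    print(((x - mean) / (std + 1e-8)).std(0))  # ≈ ones
```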