Commit 3e43021

alanhdu authored and facebook-github-bot committed
Move torchtnt to arc pyre (#1017)
Summary:
Pull Request resolved: #1017

Turn `python.set_typing(True)` on and remove the old pyre configuration. This opts into per-target type checking (e.g. `arc pyre check-changed-targets`). See https://www.internalfb.com/wiki/Python/Type-annotations-in-python/How_To%3A_Migrate_to_Pyre_Fast_By_Default_Per_Target_Type_Checking/

Reviewed By: diego-urgell

Differential Revision: D78682161

fbshipit-source-id: 5d740b4d82481d06c089d7d9e5c7ae05683ac436
1 parent 763e572 commit 3e43021
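Most of the hunks below delete `# pyre-fixme` suppression comments that per-target checking no longer requires. For readers unfamiliar with the mechanism, a minimal sketch (the file and functions here are hypothetical, not from this repo): a `# pyre-fixme[<code>]` comment on the preceding line suppresses exactly one Pyre diagnostic, so a stale comment can simply be deleted once the underlying error is gone.

# sketch.py -- hypothetical example, not part of this commit
def bad_return() -> int:
    # pyre-fixme[7]: Expected `int` but got `str`.
    return "oops"  # error 7 (incompatible return type) is silenced on this line


def no_suppression() -> int:
    return "oops"  # without the comment, Pyre reports error 7 here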

File tree

13 files changed: +12, -44 lines

tests/framework/callbacks/test_csv_writer.py

Lines changed: 0 additions & 2 deletions

@@ -26,7 +26,6 @@ def get_step_output_rows(
         self,
         state: State,
         unit: PredictUnit[TPredictData],
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]:
         return [["1"], ["2"]]
@@ -37,7 +36,6 @@ def get_step_output_rows(
         self,
         state: State,
         unit: PredictUnit[TPredictData],
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]:
         return ["1"]

tests/framework/test_auto_unit.py

Lines changed: 0 additions & 2 deletions

@@ -15,7 +15,6 @@
 
 from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 from torch import nn
-
 from torch.distributed import GradBucket
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
@@ -456,7 +455,6 @@ def custom_noop_hook(
         ) -> torch.futures.Future[torch.Tensor]:
             nonlocal custom_noop_hook_called
 
-            # pyre-fixme[29]: `Type[torch.futures.Future]` is not a function.
             fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
             fut.set_result(bucket.buffer())
             custom_noop_hook_called = True
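The second hunk drops a stale `pyre-fixme[29]` around calling the `torch.futures.Future` constructor directly. A standalone sketch of the same pattern (the helper name is ours):

import torch


def make_ready_future(t: torch.Tensor) -> torch.futures.Future:
    # build an empty Future and complete it immediately, as the test's
    # no-op DDP communication hook does with bucket.buffer()
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(t)
    return fut


print(make_ready_future(torch.ones(2)).wait())  # tensor([1., 1.])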

tests/utils/test_memory.py

Lines changed: 4 additions & 3 deletions

@@ -220,16 +220,17 @@ def __init__(self) -> None:
             len(tensor_map), 2 * len(inputs.metric_list[0].window_buffer.buffers) + 6
         )
         for metric in inputs.metric_list:
-            self.assertTrue(metric.x in tensor_map)
+            metric = cast(RandomModule, metric)
+            self.assertIn(metric.x, tensor_map)
             self.assertEqual(
                 tensor_map[metric.x], metric.x.size().numel() * metric.x.element_size()
             )
-            self.assertTrue(metric.y[0] in tensor_map)
+            self.assertIn(metric.y[0], tensor_map)
             self.assertEqual(
                 tensor_map[metric.y[0]],
                 metric.y[0].size().numel() * metric.y[0].element_size(),
             )
-            self.assertTrue(metric.y[1] in tensor_map)
+            self.assertIn(metric.y[1], tensor_map)
             self.assertEqual(
                 tensor_map[metric.y[1]],
                 metric.y[1].size().numel() * metric.y[1].element_size(),
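Two things happen in this hunk: the `cast` to `RandomModule` tells Pyre the loop variable's concrete type, and `assertIn(x, y)` replaces `assertTrue(x in y)` so a failure reports the member and container rather than a bare `False`. A minimal sketch of both idioms (class and test names are ours):

import unittest
from typing import cast, Dict, List


class Base:
    pass


class Concrete(Base):
    x: int = 42


class ExampleTest(unittest.TestCase):
    def test_membership(self) -> None:
        items: List[Base] = [Concrete()]
        tensor_map: Dict[int, str] = {42: "present"}
        for item in items:
            item = cast(Concrete, item)  # narrow the static type for Pyre
            # on failure assertIn prints e.g. "42 not found in {...}"
            self.assertIn(item.x, tensor_map)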

torchtnt/framework/auto_unit.py

Lines changed: 0 additions & 9 deletions

@@ -131,7 +131,6 @@ class TrainStepResults:
 
     loss: torch.Tensor
     total_grad_norm: Optional[torch.Tensor]
-    # pyre-fixme[4]: Attribute `outputs` of class `TrainStepResults` must have a type other than `Any`.
    outputs: Any
 
 
@@ -371,7 +370,6 @@ def __init__(
             global_mesh=global_mesh,
         )
 
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
     def predict_step(self, state: State, data: TPredictData) -> Any:
         # if detect_anomaly is true, run forward pass under detect_anomaly context
         detect_anomaly = self.detect_anomaly
@@ -394,7 +392,6 @@ def on_predict_step_end(
         state: State,
         data: TPredictData,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         outputs: Any,
     ) -> None:
         """
@@ -645,7 +642,6 @@ def configure_optimizers_and_lr_scheduler(
         ...
 
     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def compute_loss(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         """
         The user should implement this method with their loss computation. This will be called every ``train_step``/``eval_step``.
@@ -662,7 +658,6 @@ def compute_loss(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         """
         ...
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         should_update_weights = (
             self.train_progress.num_steps_completed_in_epoch + 1
@@ -779,7 +774,6 @@ def on_train_epoch_end(self, state: State) -> None:
 
         self._is_last_batch = False
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def eval_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         with self.maybe_autocast_precision:
             # users must override this
@@ -800,7 +794,6 @@ def on_eval_step_end(
         data: TData,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         outputs: Any,
     ) -> None:
         """
@@ -816,7 +809,6 @@ def on_eval_step_end(
         """
         pass
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def predict_step(self, state: State, data: TData) -> Any:
         with self.maybe_autocast_precision:
             with get_timing_context(state, f"{self.__class__.__name__}.forward"):
@@ -835,7 +827,6 @@ def on_predict_step_end(
         state: State,
         data: TData,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         outputs: Any,
     ) -> None:
         """

torchtnt/framework/callbacks/base_csv_writer.py

Lines changed: 0 additions & 2 deletions

@@ -12,7 +12,6 @@
 from typing import Any, List, TextIO, Union
 
 from pyre_extensions import none_throws
-
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
@@ -61,7 +60,6 @@ def get_step_output_rows(
         self,
         state: State,
         unit: TPredictUnit,
-        # pyre-fixme: Missing parameter annotation [2]
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]: ...
 
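`get_step_output_rows` is the abstract hook whose `step_output: Any` parameter was previously suppressed; subclasses turn each prediction step's output into CSV rows. A rough sketch of an implementation (the subclass and its row logic are ours):

from typing import Any, List, Union

from torchtnt.framework.callbacks.base_csv_writer import BaseCSVWriter
from torchtnt.framework.state import State
from torchtnt.framework.unit import TPredictUnit


class PredictionCSVWriter(BaseCSVWriter):
    def get_step_output_rows(
        self,
        state: State,
        unit: TPredictUnit,
        step_output: Any,  # now accepted as `Any` without a suppression
    ) -> Union[List[str], List[List[str]]]:
        # one CSV row per prediction in this step's output
        return [[str(pred)] for pred in step_output]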

torchtnt/framework/callbacks/module_summary.py

Lines changed: 0 additions & 1 deletion

@@ -45,7 +45,6 @@ def __init__(
         process_fn: Callable[
             [List[ModuleSummaryObj]], None
         ] = _log_module_summary_tables,
-        # pyre-fixme
         module_inputs: Optional[
             MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
         ] = None,

torchtnt/framework/train.py

Lines changed: 4 additions & 0 deletions

@@ -247,6 +247,8 @@ def _train_epoch_impl(
         ):
             _evaluate_impl(
                 state,
+                # pyre-fixme[6]: For 2nd argument expected `EvalUnit[Any]` but
+                #  got `TrainUnit[Any]`.
                 train_unit,
                 callback_handler,
             )
@@ -293,6 +295,8 @@ def _train_epoch_impl(
         ):
             _evaluate_impl(
                 state,
+                # pyre-fixme[6]: For 2nd argument expected `EvalUnit[Any]` but got
+                #  `TrainUnit[Any]`.
                 train_unit,
                 callback_handler,
             )
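These are the only suppressions the commit adds rather than removes: during fit, `_train_epoch_impl` hands its `TrainUnit` to `_evaluate_impl`, which is annotated to take an `EvalUnit`. At runtime this is sound because a unit used for fit subclasses both interfaces, roughly like this (the unit name and step bodies are ours):

from typing import Any

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit, TrainUnit

Batch = torch.Tensor


class FitUnit(TrainUnit[Batch], EvalUnit[Batch]):
    # satisfies both the TrainUnit and EvalUnit interfaces, so passing it
    # where either is expected is safe even if Pyre cannot prove it here
    def train_step(self, state: State, data: Batch) -> Any:
        return data.sum()

    def eval_step(self, state: State, data: Batch) -> Any:
        return data.mean()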

torchtnt/framework/unit.py

Lines changed: 0 additions & 3 deletions

@@ -362,7 +362,6 @@ def on_train_epoch_start(self, state: State) -> None:
         pass
 
     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
     def train_step(self, state: State, data: TTrainData) -> Any:
         """Core required method for user to implement. This method will be called at each iteration of the
         train dataloader, and can return any data the user wishes.
@@ -476,7 +475,6 @@ def on_eval_epoch_start(self, state: State) -> None:
         pass
 
     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
     def eval_step(self, state: State, data: TEvalData) -> Any:
         """
         Core required method for user to implement. This method will be called at each iteration of the
@@ -597,7 +595,6 @@ def on_predict_epoch_start(self, state: State) -> None:
         pass
 
     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
     def predict_step(self, state: State, data: TPredictData) -> Any:
         """
         Core required method for user to implement. This method will be called at each iteration of the

torchtnt/utils/data/iterators.py

Lines changed: 0 additions & 1 deletion

@@ -32,7 +32,6 @@
     from python.migrations.py310 import StrEnum310
 
 try:
-    # pyre-ignore[21]: Could not find name `StrEnum` in `enum`
     from enum import StrEnum
 except ImportError:
 
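The surrounding try/except exists because `enum.StrEnum` is only available on Python 3.11+; older interpreters fall back to the `StrEnum310` shim. A self-contained version of the same fallback pattern:

try:
    from enum import StrEnum  # Python 3.11+
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):  # minimal stand-in with str-like members
        pass


class Color(StrEnum):
    RED = "red"


print(Color.RED == "red")  # True under both implementations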

torchtnt/utils/flops.py

Lines changed: 4 additions & 9 deletions

@@ -147,7 +147,6 @@ def _conv_backward_flop_jit(
     return flop_count
 
 
-# pyre-fixme [5]
 flop_mapping: Dict[Callable[..., Any], Callable[[Tuple[Any], Tuple[Any]], Number]] = {
     aten.mm: _matmul_flop_jit,
     aten.matmul: _matmul_flop_jit,
@@ -224,8 +223,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
     def __torch_dispatch__(
         self,
-        func: Callable[..., Any],  # pyre-fixme [2] func can be any func
-        types: Tuple[Any],  # pyre-fixme [2]
+        func: Callable[..., Any],
+        types: Tuple[Any],
         args=(),  # pyre-fixme [2]
         kwargs=None,  # pyre-fixme [2]
     ) -> PyTree:
@@ -242,7 +241,6 @@ def __torch_dispatch__(
 
         return rs
 
-    # pyre-fixme [3]
     def _create_backwards_push(self, name: str) -> Callable[..., Any]:
         class PushState(torch.autograd.Function):
             @staticmethod
@@ -265,7 +263,6 @@ def backward(ctx, *grad_outs):
             # using a function parameter.
             return PushState.apply
 
-    # pyre-fixme [3]
     def _create_backwards_pop(self, name: str) -> Callable[..., Any]:
         class PopState(torch.autograd.Function):
             @staticmethod
@@ -289,9 +286,8 @@ def backward(ctx, *grad_outs):
             # using a function parameter.
             return PopState.apply
 
-    # pyre-fixme [3] Return a callable function
     def _enter_module(self, name: str) -> Callable[..., Any]:
-        # pyre-fixme [2, 3]
+        # pyre-fixme [3]
         def f(module: torch.nn.Module, inputs: Tuple[Any]):
             parents = self._parents
             parents.append(name)
@@ -301,9 +297,8 @@ def f(module: torch.nn.Module, inputs: Tuple[Any]):
 
         return f
 
-    # pyre-fixme [3] Return a callable function
     def _exit_module(self, name: str) -> Callable[..., Any]:
-        # pyre-fixme [2, 3]
+        # pyre-fixme [3]
         def f(module: torch.nn.Module, inputs: Tuple[Any], outputs: Tuple[Any]):
             parents = self._parents
             assert parents[-1] == name
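`__torch_dispatch__` is the hook this module uses to intercept every ATen call and look up its FLOP formula in `flop_mapping`. A stripped-down sketch of the same mechanism (the counter class is ours, not the module's dispatch mode):

from typing import Dict

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpCounter(TorchDispatchMode):
    """Tallies every ATen op that executes under the mode."""

    def __init__(self) -> None:
        super().__init__()
        self.counts: Dict[str, int] = {}

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.counts[str(func)] = self.counts.get(str(func), 0) + 1
        return func(*args, **(kwargs or {}))  # re-dispatch to the real kernel


a, b = torch.randn(2, 3), torch.randn(3, 4)
with OpCounter() as counter:
    torch.mm(a, b)
print(counter.counts)  # e.g. {'aten.mm.default': 1}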
