Skip to content

Commit f0955d0

Browse files
authored
[Typing] Fix pyrefly-ignore in train.py (#2282)
The main goal is to fix train.py. Since train.py uses many components, this PR also fixes some of those other components. Trainer.model_parts should be `list[ModelProtocol]`, but that change would affect many files, so it will be done in a follow-up PR. Some ignores were already unnecessary; it is unclear why they were not removed previously using `pyrefly check --remove-unused-ignores`. Total pyrefly ignores removed: 19.
1 parent a33d0e3 commit f0955d0

File tree

8 files changed

+35
-41
lines changed

8 files changed

+35
-41
lines changed

torchtitan/components/dataloader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import inspect
1010
import pickle
1111
from abc import ABC, abstractmethod
12-
from typing import Any
12+
from typing import Any, Iterator
13+
14+
import torch
1315

1416
from torch.distributed.checkpoint.stateful import Stateful
1517
from torch.utils.data import IterableDataset
@@ -38,11 +40,10 @@ class BaseDataLoader(Stateful, ABC):
3840
"""
3941

4042
@abstractmethod
41-
def __iter__(self):
43+
def __iter__(self) -> Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]]:
4244
...
4345

4446

45-
# pyrefly: ignore [inconsistent-inheritance]
4647
class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
4748
"""Dataloader that is aware of distributed data parallelism.
4849

torchtitan/components/ft/manager.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77
import importlib.util
88
from contextlib import nullcontext
99
from datetime import timedelta
10-
from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union
10+
from typing import Callable, cast, ContextManager, Optional, TYPE_CHECKING, Union
1111

1212
import torch
1313
import torch.distributed as dist
1414

1515
import torch.nn as nn
1616
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
1717
from torch.distributed.distributed_c10d import ReduceOp
18-
from torchtitan.components.ft.config import FaultTolerance as FTConfig
18+
from torchtitan.components.ft.config import FaultTolerance as ExtendedFTConfig
19+
from torchtitan.config import FaultTolerance as FTConfig
1920
from torchtitan.tools.logging import logger
2021

2122
if importlib.util.find_spec("torchft") is not None:
@@ -119,8 +120,9 @@ def maybe_semi_sync_training(
119120
"""
120121
If TorchFT is enabled and the config is set, use semi_sync_method
121122
"""
122-
semi_sync_method = ft_config.semi_sync_method
123-
if ft_config.enable and semi_sync_method is not None:
123+
extend_ft_config = cast(ExtendedFTConfig, ft_config)
124+
semi_sync_method = extend_ft_config.semi_sync_method
125+
if extend_ft_config.enable and semi_sync_method is not None:
124126
from torchft import local_sgd
125127

126128
assert (
@@ -131,7 +133,7 @@ def maybe_semi_sync_training(
131133
)
132134
if semi_sync_method.lower() == "diloco":
133135
if fragment_fn:
134-
model_parts = fragment_fn(model, ft_config, n_layers)
136+
model_parts = fragment_fn(model, extend_ft_config, n_layers)
135137
else:
136138
model_parts = [model]
137139

@@ -149,17 +151,17 @@ def maybe_semi_sync_training(
149151
model_fragments=model_parts,
150152
inner_optimizer=optimizer,
151153
outer_optimizer=outer_optimizers,
152-
sync_every=ft_config.sync_steps,
153-
should_quantize=ft_config.should_quantize,
154-
fragment_sync_delay=ft_config.fragment_sync_delay,
155-
fragment_update_alpha=ft_config.fragment_update_alpha,
154+
sync_every=extend_ft_config.sync_steps,
155+
should_quantize=extend_ft_config.should_quantize,
156+
fragment_sync_delay=extend_ft_config.fragment_sync_delay,
157+
fragment_update_alpha=extend_ft_config.fragment_update_alpha,
156158
)
157159
elif semi_sync_method.lower() == "local_sgd":
158160
return local_sgd.LocalSGD(
159161
manager=ft_manager._manager,
160162
model=model,
161163
optimizer=optimizer,
162-
sync_every=ft_config.sync_steps,
164+
sync_every=extend_ft_config.sync_steps,
163165
)
164166
else:
165167
raise ValueError(

torchtitan/components/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torchtitan.distributed import ParallelDims
1919
from torchtitan.tools import utils
2020
from torchtitan.tools.logging import logger
21-
from torchtitan.tools.utils import Color, device_module, device_type
21+
from torchtitan.tools.utils import Color, device_module, device_type, NoColor
2222

2323
if TYPE_CHECKING:
2424
from torchtitan.protocols import BaseModelArgs
@@ -195,7 +195,7 @@ def close(self) -> None:
195195

196196

197197
def ensure_pp_loss_visible(
198-
parallel_dims: ParallelDims, job_config: JobConfig, color: Color
198+
parallel_dims: ParallelDims, job_config: JobConfig, color: Color | NoColor
199199
) -> None:
200200
"""
201201
Ensures that the loss is visible on the console for pipeline-parallel training.

torchtitan/components/validate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class BaseValidator:
2929
def __init__(self, job_config: JobConfig):
3030
self.job_config = job_config
3131

32-
def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
32+
def validate(self, model_parts: list[nn.Module], step: int) -> None:
3333
raise NotImplementedError("validate method not implemented")
3434

3535
def should_validate(self, step: int) -> bool:
@@ -154,7 +154,6 @@ def post_dataloading_process(
154154
return inputs, labels, extra_inputs, extra_kwargs
155155

156156
@torch.no_grad()
157-
# pyrefly: ignore [bad-override]
158157
def validate(
159158
self,
160159
model_parts: list[nn.Module],
@@ -170,7 +169,6 @@ def validate(
170169
device_type = utils.device_type
171170
num_steps = 0
172171

173-
# pyrefly: ignore [not-iterable]
174172
for input_dict, labels in self.validation_dataloader:
175173
if (
176174
self.job_config.validation.steps != -1
@@ -190,7 +188,6 @@ def validate(
190188

191189
# Count valid tokens for this batch
192190
local_valid_tokens = torch.tensor(0, dtype=torch.int64, device=device_type)
193-
# pyrefly: ignore [missing-attribute]
194191
local_valid_tokens += (labels != IGNORE_INDEX).sum()
195192

196193
# All-reduce token count across DP ranks to get global token count

torchtitan/models/attention.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,10 @@ def forward(
6868
max_q = attention_masks.max_q
6969
max_k = attention_masks.max_k
7070

71-
# pyrefly: ignore [no-matching-overload]
7271
xq_packed = xq.transpose(1, 2).flatten(0, 1) # (bs * seqlen, n_heads, head_dim)
73-
# pyrefly: ignore [no-matching-overload]
7472
xk_packed = xk.transpose(1, 2).flatten(
7573
0, 1
7674
) # (bs * seqlen, n_kv_heads, head_dim)
77-
# pyrefly: ignore [no-matching-overload]
7875
xv_packed = xv.transpose(1, 2).flatten(
7976
0, 1
8077
) # (bs * seqlen, n_kv_heads, head_dim)

torchtitan/models/flux/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def forward_backward_step(
138138
# pyrefly: ignore [bad-assignment]
139139
global_valid_tokens = dist_utils.dist_sum(local_valid_tokens, batch_mesh)
140140
else:
141-
# pyrefly: ignore [bad-assignment]
142141
global_valid_tokens = local_valid_tokens.float()
143142

144143
# Keep these variables local to shorten the code as these are

torchtitan/models/flux/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def validate(
127127
device_type = dist_utils.device_type
128128
num_steps = 0
129129

130-
# pyrefly: ignore [not-iterable]
131130
for input_dict, labels in self.validation_dataloader:
132131
if (
133132
self.job_config.validation.steps != -1
@@ -139,6 +138,7 @@ def validate(
139138
if not isinstance(prompt, list):
140139
prompt = [prompt]
141140
for p in prompt:
141+
assert isinstance(p, str), f"prompt must be a string, got {type(p)}"
142142
if save_img_count != -1 and save_img_count <= 0:
143143
break
144144
image = generate_image(

torchtitan/train.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import os
1111
import time
1212
from datetime import timedelta
13-
from typing import Any, Iterable
13+
from typing import Any, cast, Iterable, Iterator
1414

1515
import torch
1616
import torch.distributed.checkpoint.stateful
@@ -28,6 +28,7 @@
2828
from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
2929
from torchtitan.distributed import ParallelDims, utils as dist_utils
3030
from torchtitan.distributed.context_parallel import prepare_context_parallel_input
31+
from torchtitan.protocols import ModelProtocol
3132
from torchtitan.protocols.model_converter import build_model_converters
3233
from torchtitan.tools import utils
3334
from torchtitan.tools.logging import init_logger, logger
@@ -46,6 +47,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
4647
# swappable training components in TrainSpec
4748
tokenizer: train_spec_module.BaseTokenizer | None
4849
dataloader: train_spec_module.BaseDataLoader
50+
# TODO: we should make this list[ModelProtocol] but this will affect many components.
51+
# will do this in a separate PR
4952
model_parts: list[torch.nn.Module]
5053
loss_fn: train_spec_module.LossFunction
5154
optimizers: train_spec_module.OptimizersContainer
@@ -97,7 +100,6 @@ def __init__(self, job_config: JobConfig):
97100
else:
98101
batch_degree, batch_rank = 1, 0
99102

100-
# pyrefly: ignore [bad-argument-type]
101103
self.ft_manager = FTManager(job_config.fault_tolerance)
102104
batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank)
103105

@@ -173,12 +175,13 @@ def __init__(self, job_config: JobConfig):
173175
)
174176

175177
# move sharded model to CPU/GPU and initialize weights via DTensor
178+
buffer_device: torch.device | None
176179
if job_config.checkpoint.create_seed_checkpoint:
177180
init_device = "cpu"
178181
buffer_device = None
179182
elif job_config.training.enable_cpu_offload:
180183
init_device = "cpu"
181-
buffer_device = device_type
184+
buffer_device = torch.device(device_type)
182185
else:
183186
init_device = device_type
184187
buffer_device = None
@@ -239,21 +242,18 @@ def __init__(self, job_config: JobConfig):
239242
for m in self.model_parts:
240243
m.to_empty(device=init_device)
241244
with torch.no_grad():
242-
# pyrefly: ignore [not-callable]
243-
m.init_weights(buffer_device=buffer_device)
245+
cast(ModelProtocol, m).init_weights(buffer_device=buffer_device)
244246
m.train()
245247

246248
# confirm that user will be able to view loss metrics on the console
247-
# pyrefly: ignore [bad-argument-type]
248249
ensure_pp_loss_visible(parallel_dims, job_config, color)
249250
else:
250251
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
251252
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
252253

253254
model.to_empty(device=init_device)
254255
with torch.no_grad():
255-
# pyrefly: ignore [not-callable]
256-
model.init_weights(buffer_device=buffer_device)
256+
cast(ModelProtocol, model).init_weights(buffer_device=buffer_device)
257257
model.train()
258258

259259
self.model_parts = [model]
@@ -384,7 +384,7 @@ def init_distributed(self) -> ParallelDims:
384384

385385
def batch_generator(
386386
self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
387-
) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
387+
) -> Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]]:
388388
"""Returns an iterator that processes batches from the data iterator.
389389
390390
Note: Tensors are yielded on CPU. The caller is responsible for moving
@@ -457,8 +457,11 @@ def post_dataloading_process(
457457

458458
attn_type = getattr(self.model_args, "attn_type", "sdpa")
459459
if attn_type in ["flex", "varlen"]:
460-
# pyrefly: ignore [not-callable]
461-
extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
460+
assert (
461+
self.tokenizer is not None
462+
), "tokenizer is required for flex/varlen attention"
463+
model = cast(ModelProtocol, self.model_parts[0])
464+
extra_kwargs["attention_masks"] = model.get_attention_masks(
462465
input_batch=inputs,
463466
tokenizer=self.tokenizer,
464467
extra_inputs=extra_inputs,
@@ -543,7 +546,7 @@ def forward_backward_step(
543546
return loss
544547

545548
def train_step(
546-
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
549+
self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]]
547550
):
548551
self.optimizers.zero_grad()
549552
# Save the current step learning rate for logging
@@ -557,9 +560,7 @@ def train_step(
557560
microbatches = []
558561
local_valid_tokens = torch.tensor(0, dtype=torch.int64)
559562
for _microbatch in range(self.gradient_accumulation_steps):
560-
# pyrefly: ignore [no-matching-overload]
561563
input_dict, labels = next(data_iterator)
562-
# pyrefly: ignore [missing-attribute]
563564
local_valid_tokens += (labels != IGNORE_INDEX).sum()
564565
microbatches.append((input_dict, labels))
565566

@@ -668,7 +669,6 @@ def train(self):
668669
leaf_folder=leaf_folder,
669670
) as memory_profiler,
670671
maybe_semi_sync_training(
671-
# pyrefly: ignore [bad-argument-type]
672672
job_config.fault_tolerance,
673673
ft_manager=self.ft_manager,
674674
model=self.model_parts[0],
@@ -685,7 +685,6 @@ def train(self):
685685
),
686686
),
687687
):
688-
# pyrefly: ignore [bad-argument-type]
689688
data_iterator = self.batch_generator(self.dataloader)
690689
while self.should_continue_training():
691690
self.step += 1
@@ -705,7 +704,6 @@ def train(self):
705704
self.job_config.validation.enable
706705
and self.validator.should_validate(self.step)
707706
):
708-
# pyrefly: ignore [bad-argument-count]
709707
self.validator.validate(self.model_parts, self.step)
710708

711709
# signal the profiler that the next profiling step has started

0 commit comments

Comments (0)