1010import os
1111import time
1212from datetime import timedelta
13- from typing import Any , Iterable
13+ from typing import Any , cast , Iterable , Iterator
1414
1515import torch
1616import torch .distributed .checkpoint .stateful
2828from torchtitan .config import ConfigManager , JobConfig , TORCH_DTYPE_MAP
2929from torchtitan .distributed import ParallelDims , utils as dist_utils
3030from torchtitan .distributed .context_parallel import prepare_context_parallel_input
31+ from torchtitan .protocols import ModelProtocol
3132from torchtitan .protocols .model_converter import build_model_converters
3233from torchtitan .tools import utils
3334from torchtitan .tools .logging import init_logger , logger
@@ -46,6 +47,8 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
4647 # swappable training components in TrainSpec
4748 tokenizer : train_spec_module .BaseTokenizer | None
4849 dataloader : train_spec_module .BaseDataLoader
50+ # TODO: this should be typed as list[ModelProtocol], but that change affects
51+ # many components, so it will be done in a separate PR.
4952 model_parts : list [torch .nn .Module ]
5053 loss_fn : train_spec_module .LossFunction
5154 optimizers : train_spec_module .OptimizersContainer
@@ -97,7 +100,6 @@ def __init__(self, job_config: JobConfig):
97100 else :
98101 batch_degree , batch_rank = 1 , 0
99102
100- # pyrefly: ignore [bad-argument-type]
101103 self .ft_manager = FTManager (job_config .fault_tolerance )
102104 batch_degree , batch_rank = self .ft_manager .get_dp_info (batch_degree , batch_rank )
103105
@@ -173,12 +175,13 @@ def __init__(self, job_config: JobConfig):
173175 )
174176
175177 # move sharded model to CPU/GPU and initialize weights via DTensor
178+ buffer_device : torch .device | None
176179 if job_config .checkpoint .create_seed_checkpoint :
177180 init_device = "cpu"
178181 buffer_device = None
179182 elif job_config .training .enable_cpu_offload :
180183 init_device = "cpu"
181- buffer_device = device_type
184+ buffer_device = torch . device ( device_type )
182185 else :
183186 init_device = device_type
184187 buffer_device = None
@@ -239,21 +242,18 @@ def __init__(self, job_config: JobConfig):
239242 for m in self .model_parts :
240243 m .to_empty (device = init_device )
241244 with torch .no_grad ():
242- # pyrefly: ignore [not-callable]
243- m .init_weights (buffer_device = buffer_device )
245+ cast (ModelProtocol , m ).init_weights (buffer_device = buffer_device )
244246 m .train ()
245247
246248 # confirm that user will be able to view loss metrics on the console
247- # pyrefly: ignore [bad-argument-type]
248249 ensure_pp_loss_visible (parallel_dims , job_config , color )
249250 else :
250251 # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
251252 model = self .train_spec .parallelize_fn (model , parallel_dims , job_config )
252253
253254 model .to_empty (device = init_device )
254255 with torch .no_grad ():
255- # pyrefly: ignore [not-callable]
256- model .init_weights (buffer_device = buffer_device )
256+ cast (ModelProtocol , model ).init_weights (buffer_device = buffer_device )
257257 model .train ()
258258
259259 self .model_parts = [model ]
@@ -384,7 +384,7 @@ def init_distributed(self) -> ParallelDims:
384384
385385 def batch_generator (
386386 self , data_iterable : Iterable [tuple [dict [str , torch .Tensor ], torch .Tensor ]]
387- ) -> Iterable [tuple [dict [str , torch .Tensor ], torch .Tensor ]]:
387+ ) -> Iterator [tuple [dict [str , torch .Tensor ], torch .Tensor ]]:
388388 """Returns an iterator that processes batches from the data iterator.
389389
390390 Note: Tensors are yielded on CPU. The caller is responsible for moving
@@ -457,8 +457,11 @@ def post_dataloading_process(
457457
458458 attn_type = getattr (self .model_args , "attn_type" , "sdpa" )
459459 if attn_type in ["flex" , "varlen" ]:
460- # pyrefly: ignore [not-callable]
461- extra_kwargs ["attention_masks" ] = self .model_parts [0 ].get_attention_masks (
460+ assert (
461+ self .tokenizer is not None
462+ ), "tokenizer is required for flex/varlen attention"
463+ model = cast (ModelProtocol , self .model_parts [0 ])
464+ extra_kwargs ["attention_masks" ] = model .get_attention_masks (
462465 input_batch = inputs ,
463466 tokenizer = self .tokenizer ,
464467 extra_inputs = extra_inputs ,
@@ -543,7 +546,7 @@ def forward_backward_step(
543546 return loss
544547
545548 def train_step (
546- self , data_iterator : Iterable [tuple [dict [str , torch .Tensor ], torch .Tensor ]]
549+ self , data_iterator : Iterator [tuple [dict [str , torch .Tensor ], torch .Tensor ]]
547550 ):
548551 self .optimizers .zero_grad ()
549552 # Save the current step learning rate for logging
@@ -557,9 +560,7 @@ def train_step(
557560 microbatches = []
558561 local_valid_tokens = torch .tensor (0 , dtype = torch .int64 )
559562 for _microbatch in range (self .gradient_accumulation_steps ):
560- # pyrefly: ignore [no-matching-overload]
561563 input_dict , labels = next (data_iterator )
562- # pyrefly: ignore [missing-attribute]
563564 local_valid_tokens += (labels != IGNORE_INDEX ).sum ()
564565 microbatches .append ((input_dict , labels ))
565566
@@ -668,7 +669,6 @@ def train(self):
668669 leaf_folder = leaf_folder ,
669670 ) as memory_profiler ,
670671 maybe_semi_sync_training (
671- # pyrefly: ignore [bad-argument-type]
672672 job_config .fault_tolerance ,
673673 ft_manager = self .ft_manager ,
674674 model = self .model_parts [0 ],
@@ -685,7 +685,6 @@ def train(self):
685685 ),
686686 ),
687687 ):
688- # pyrefly: ignore [bad-argument-type]
689688 data_iterator = self .batch_generator (self .dataloader )
690689 while self .should_continue_training ():
691690 self .step += 1
@@ -705,7 +704,6 @@ def train(self):
705704 self .job_config .validation .enable
706705 and self .validator .should_validate (self .step )
707706 ):
708- # pyrefly: ignore [bad-argument-count]
709707 self .validator .validate (self .model_parts , self .step )
710708
711709 # signal the profiler that the next profiling step has started
0 commit comments