From 76a73c4bff4d9594bd04a18b1e5b76af9b698bbe Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Fri, 22 Aug 2025 15:19:40 -0700 Subject: [PATCH 1/5] debug fsdp uunven sharding load checkpoint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/models/llama3/__init__.py | 2 +- torchtitan/models/llama3/train_configs/debug_model.toml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index a34b4463f..558e13108 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -29,7 +29,7 @@ llama3_configs = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index ecabf6e5d..c57c73905 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -42,7 +42,7 @@ min_lr_factor = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 30 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -55,9 +55,9 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" -interval = 10 +interval = 15 last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] From 5b665f59e73c9f3f19ae40ff8dc77f130f339412 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Fri, 22 Aug 2025 15:22:44 -0700 Subject: [PATCH 2/5] add repro command Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- repro.sh | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 repro.sh diff --git a/repro.sh b/repro.sh new file mode 100644 index 000000000..46d13d0c8 --- /dev/null +++ b/repro.sh @@ -0,0 +1,2 @@ +NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh +rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh From b71efeb77ecbe9263efd078fa1cc43e6fdad10a5 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Fri, 22 Aug 2025 17:19:26 -0700 Subject: [PATCH 3/5] same model ref_model Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/components/checkpoint.py | 3 + torchtitan/models/attention.py | 4 +- torchtitan/train.py | 126 ++++++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 11 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index fcec60185..98805fca9 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -425,6 +425,8 @@ def dcp_load( # TODO: Since we flatten the model states in state_dict, we need to # manually call load_state_dict() for the model. Need to fix this. if MODEL in self.states: + # import fbvscode + # fbvscode.set_trace() self.states[MODEL].load_state_dict(state_dict) @torch.no_grad() @@ -522,6 +524,7 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False + # torch.distributed.breakpoint() if not os.path.exists(self.folder): model_only = self.initial_load_model_only from_hf = self.initial_load_in_hf diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d..f141b3972 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -205,8 +205,8 @@ def _init_backend(cls) -> None: # Add CuDNN on B200 w/ highest priority cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, + # SDPBackend.FLASH_ATTENTION, + # SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] if has_cuda_capability(10, 0): diff --git a/torchtitan/train.py b/torchtitan/train.py index 758a5a699..a8016f256 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,9 @@ from typing import Any, Generator, Iterable, Optional import torch +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -154,8 +157,11 @@ def __init__(self, job_config: JobConfig): logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): + with torch.device("cuda"): + # import fbvscode + # fbvscode.set_trace() model = self.train_spec.model_cls(model_args) + # model = torch.nn.Linear(1024, 1024, device="cuda") # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) @@ -257,15 +263,19 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + import copy + model.init_weights() + ref_model = copy.deepcopy(model) + ref_model = self.train_spec.parallelize_fn(ref_model, parallel_dims, job_config) + ref_model.train() + self.ref_model_parts = [ref_model] - model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.train() - self.model_parts = [model] + + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation @@ -294,6 +304,19 @@ def __init__(self, job_config: JobConfig): self.model_parts ) ) + + self.ref_optimizers = self.train_spec.build_optimizers_fn( + self.ref_model_parts, job_config.optimizer, parallel_dims, self.ft_manager + ) + self.ref_lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.ref_optimizers, job_config.lr_scheduler, job_config.training.steps + ) + self.ref_optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.ref_model_parts + ) + ) + self.metrics_processor.optimizers = self.optimizers self.metrics_processor.model_parts = self.model_parts @@ -320,6 +343,24 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + self.ref_checkpointer = CheckpointManager( + dataloader=self.dataloader, + model_parts=self.ref_model_parts, + optimizers=self.ref_optimizers, + lr_schedulers=self.ref_lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=self.ft_manager, + ) + loss_parallel_enabled = ( parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel ) @@ -460,16 +501,32 @@ def forward_backward_step( with self.maybe_enable_amp: pred = model_parts[0](inputs) loss = self.loss_fn(pred, labels) + + import copy + ref_inputs = copy.deepcopy(inputs) + ref_pred = self.ref_model_parts[0](ref_inputs) + ref_loss = self.loss_fn(ref_pred, labels) + + try: + assert torch.equal(pred, ref_pred) + assert torch.equal(loss, ref_loss) + except: + import fbvscode + fbvscode.set_trace() + # need to free to before bwd to avoid peaking memory del pred + del ref_pred loss.backward() + ref_loss.backward() - return loss + return loss, ref_loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() + self.ref_optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -478,25 +535,72 @@ def train_step( parallel_dims = self.parallel_dims accumulated_losses = [] + ref_accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): + # import fbvscode + # fbvscode.set_trace() + for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + input_dict, labels = next(data_iterator) - loss = self.forward_backward_step(input_dict, labels) + loss, ref_loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + ref_accumulated_losses.append(ref_loss.detach()) + + for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + full_param_grad = param.grad.full_tensor() + ref_full_param_grad = ref_param.grad.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + try: + assert torch.equal(full_param_grad, ref_full_param_grad) + except: + import fbvscode + fbvscode.set_trace() + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, - foreach=True, + foreach=False, pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), ep_enabled=parallel_dims.ep_enabled, ) + ref_grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.ref_model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=False, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_enabled=parallel_dims.ep_enabled, + ) + try: + assert torch.equal(grad_norm, ref_grad_norm) + except: + import fbvscode + fbvscode.set_trace() self.checkpointer.maybe_wait_for_staging() + self.ref_checkpointer.maybe_wait_for_staging() self.optimizers.step() self.lr_schedulers.step() + self.ref_optimizers.step() + self.ref_lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -540,6 +644,10 @@ def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + self.ref_checkpointer.load(step=job_config.checkpoint.load_step) + + # import fbvscode + # fbvscode.set_trace() logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( From 728d3e23d0f2758dd360a45ea10eb95d6dc69cb9 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 25 Aug 2025 15:02:25 -0700 Subject: [PATCH 4/5] repro numeric diffrences before/after dcp load Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/models/llama3/__init__.py | 3 ++- torchtitan/train.py | 26 ++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 558e13108..3cf92a022 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -29,7 +29,8 @@ llama3_configs = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 + # dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/train.py b/torchtitan/train.py index a8016f256..90eb33892 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -541,7 +541,7 @@ def train_step( for microbatch in range(self.gradient_accumulation_steps): # import fbvscode # fbvscode.set_trace() - for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() try: @@ -555,12 +555,13 @@ def train_step( accumulated_losses.append(loss.detach()) ref_accumulated_losses.append(ref_loss.detach()) - for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() full_param_grad = param.grad.full_tensor() ref_full_param_grad = ref_param.grad.full_tensor() try: + assert torch.equal(full_param, ref_full_param) except: import fbvscode @@ -643,8 +644,8 @@ def train_step( def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) - self.ref_checkpointer.load(step=job_config.checkpoint.load_step) + # self.checkpointer.load(step=job_config.checkpoint.load_step) + # self.ref_checkpointer.load(step=job_config.checkpoint.load_step) # import fbvscode # fbvscode.set_trace() @@ -698,6 +699,23 @@ def train(self): self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + self.checkpointer.load(step=job_config.checkpoint.load_step) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() # Run validation if validator is available if ( From 5429c430d05b52c23b619e2dd886f5b789bdc53b Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 26 Aug 2025 15:39:34 -0700 Subject: [PATCH 5/5] repro at checkpoint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/components/checkpoint.py | 22 +++++++++++++++++++++- torchtitan/train.py | 19 +++++++++++-------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 98805fca9..f7f836780 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -397,6 +397,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + step: int = -1, ) -> None: """Load the checkpoint with dcp. Args: @@ -420,7 +421,15 @@ def dcp_load( state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) else: + before_load = state_dict['tok_embeddings.weight']._local_tensor dcp.load(state_dict, checkpoint_id=checkpoint_id) + after_load = state_dict['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + # dcp.load(state_dict, checkpoint_id=checkpoint_id) + except: + import fbvscode + fbvscode.set_trace() # TODO: Since we flatten the model states in state_dict, we need to # manually call load_state_dict() for the model. Need to fix this. @@ -490,6 +499,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: enable_garbage_collection=True, ) self._purge_stale_checkpoints() + # import fbvscode + # fbvscode.set_trace() + self.load(step=-1) logger.info( "Finished saving the checkpoint (or staging if async is enabled)" @@ -579,11 +591,18 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) + before_load = states['tok_embeddings.weight']._local_tensor self.dcp_load( states, checkpoint_id=checkpoint_id, from_hf=from_hf, ) + after_load = states['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + except: + import fbvscode + fbvscode.set_trace() GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -702,7 +721,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: k: v for k, v in self.states.items() if k not in self.exclude_from_loading } - states_to_load = self._flattened_model_states_sd(states_to_load) + # states_to_load = self._flattened_model_states_sd(states_to_load) + states_to_load = self._flattened_model_states_sd() if self.ft_manager: states_to_load.pop(DATALOADER) diff --git a/torchtitan/train.py b/torchtitan/train.py index 90eb33892..841784469 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -699,19 +699,22 @@ def train(self): self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) - for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - full_param = param.full_tensor() - ref_full_param = ref_param.full_tensor() - try: - assert torch.equal(full_param, ref_full_param) - except: - import fbvscode - fbvscode.set_trace() + # for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + # full_param = param.full_tensor() + # ref_full_param = ref_param.full_tensor() + # try: + # assert torch.equal(full_param, ref_full_param) + # except: + # import fbvscode + # fbvscode.set_trace() self.checkpointer.load(step=job_config.checkpoint.load_step) for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor try: + assert torch.equal(local_param, ref_local_param) assert torch.equal(full_param, ref_full_param) except: import fbvscode