diff --git a/repro.sh b/repro.sh new file mode 100644 index 000000000..46d13d0c8 --- /dev/null +++ b/repro.sh @@ -0,0 +1,2 @@ +NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh +rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index fcec60185..f7f836780 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -397,6 +397,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + step: int = -1, ) -> None: """Load the checkpoint with dcp. Args: @@ -420,11 +421,21 @@ def dcp_load( state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) else: + before_load = state_dict['tok_embeddings.weight']._local_tensor dcp.load(state_dict, checkpoint_id=checkpoint_id) + after_load = state_dict['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + # dcp.load(state_dict, checkpoint_id=checkpoint_id) + except: + import fbvscode + fbvscode.set_trace() # TODO: Since we flatten the model states in state_dict, we need to # manually call load_state_dict() for the model. Need to fix this. if MODEL in self.states: + # import fbvscode + # fbvscode.set_trace() self.states[MODEL].load_state_dict(state_dict) @torch.no_grad() @@ -488,6 +499,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: enable_garbage_collection=True, ) self._purge_stale_checkpoints() + # import fbvscode + # fbvscode.set_trace() + self.load(step=-1) logger.info( "Finished saving the checkpoint (or staging if async is enabled)" @@ -522,6 +536,7 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False + # torch.distributed.breakpoint() if not os.path.exists(self.folder): model_only = self.initial_load_model_only from_hf = self.initial_load_in_hf @@ -576,11 +591,18 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) + before_load = states['tok_embeddings.weight']._local_tensor self.dcp_load( states, checkpoint_id=checkpoint_id, from_hf=from_hf, ) + after_load = states['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + except: + import fbvscode + fbvscode.set_trace() GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -699,7 +721,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: k: v for k, v in self.states.items() if k not in self.exclude_from_loading } - states_to_load = self._flattened_model_states_sd(states_to_load) + # states_to_load = self._flattened_model_states_sd(states_to_load) + states_to_load = self._flattened_model_states_sd() if self.ft_manager: states_to_load.pop(DATALOADER) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d..f141b3972 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -205,8 +205,8 @@ def _init_backend(cls) -> None: # Add CuDNN on B200 w/ highest priority cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, + # SDPBackend.FLASH_ATTENTION, + # SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] if has_cuda_capability(10, 0): diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index a34b4463f..3cf92a022 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -30,6 +30,7 @@ llama3_configs = { "debugmodel": TransformerModelArgs( dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 + # dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index ecabf6e5d..c57c73905 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -42,7 +42,7 @@ min_lr_factor = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 30 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -55,9 +55,9 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" -interval = 10 +interval = 15 last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/train.py b/torchtitan/train.py index 758a5a699..841784469 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,9 @@ from typing import Any, Generator, Iterable, Optional import torch +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -154,8 +157,11 @@ def __init__(self, job_config: JobConfig): logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): + with torch.device("cuda"): + # import fbvscode + # fbvscode.set_trace() model = self.train_spec.model_cls(model_args) + # model = torch.nn.Linear(1024, 1024, device="cuda") # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) @@ -257,15 +263,19 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + import copy + model.init_weights() + ref_model = copy.deepcopy(model) + ref_model = self.train_spec.parallelize_fn(ref_model, parallel_dims, job_config) + ref_model.train() + self.ref_model_parts = [ref_model] - model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.train() - self.model_parts = [model] + + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation @@ -294,6 +304,19 @@ def __init__(self, job_config: JobConfig): self.model_parts ) ) + + self.ref_optimizers = self.train_spec.build_optimizers_fn( + self.ref_model_parts, job_config.optimizer, parallel_dims, self.ft_manager + ) + self.ref_lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.ref_optimizers, job_config.lr_scheduler, job_config.training.steps + ) + self.ref_optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.ref_model_parts + ) + ) + self.metrics_processor.optimizers = self.optimizers self.metrics_processor.model_parts = self.model_parts @@ -320,6 +343,24 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + self.ref_checkpointer = CheckpointManager( + dataloader=self.dataloader, + model_parts=self.ref_model_parts, + optimizers=self.ref_optimizers, + lr_schedulers=self.ref_lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=self.ft_manager, + ) + loss_parallel_enabled = ( parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel ) @@ -460,16 +501,32 @@ def forward_backward_step( with self.maybe_enable_amp: pred = model_parts[0](inputs) loss = self.loss_fn(pred, labels) + + import copy + ref_inputs = copy.deepcopy(inputs) + ref_pred = self.ref_model_parts[0](ref_inputs) + ref_loss = self.loss_fn(ref_pred, labels) + + try: + assert torch.equal(pred, ref_pred) + assert torch.equal(loss, ref_loss) + except: + import fbvscode + fbvscode.set_trace() + # need to free to before bwd to avoid peaking memory del pred + del ref_pred loss.backward() + ref_loss.backward() - return loss + return loss, ref_loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() + self.ref_optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -478,25 +535,73 @@ def train_step( parallel_dims = self.parallel_dims accumulated_losses = [] + ref_accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): + # import fbvscode + # fbvscode.set_trace() + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + input_dict, labels = next(data_iterator) - loss = self.forward_backward_step(input_dict, labels) + loss, ref_loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + ref_accumulated_losses.append(ref_loss.detach()) + + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + full_param_grad = param.grad.full_tensor() + ref_full_param_grad = ref_param.grad.full_tensor() + try: + + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + try: + assert torch.equal(full_param_grad, ref_full_param_grad) + except: + import fbvscode + fbvscode.set_trace() + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, - foreach=True, + foreach=False, pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), ep_enabled=parallel_dims.ep_enabled, ) + ref_grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.ref_model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=False, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_enabled=parallel_dims.ep_enabled, + ) + try: + assert torch.equal(grad_norm, ref_grad_norm) + except: + import fbvscode + fbvscode.set_trace() self.checkpointer.maybe_wait_for_staging() + self.ref_checkpointer.maybe_wait_for_staging() self.optimizers.step() self.lr_schedulers.step() + self.ref_optimizers.step() + self.ref_lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -539,7 +644,11 @@ def train_step( def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) + # self.checkpointer.load(step=job_config.checkpoint.load_step) + # self.ref_checkpointer.load(step=job_config.checkpoint.load_step) + + # import fbvscode + # fbvscode.set_trace() logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( @@ -590,6 +699,26 @@ def train(self): self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) + # for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + # full_param = param.full_tensor() + # ref_full_param = ref_param.full_tensor() + # try: + # assert torch.equal(full_param, ref_full_param) + # except: + # import fbvscode + # fbvscode.set_trace() + self.checkpointer.load(step=job_config.checkpoint.load_step) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + assert torch.equal(local_param, ref_local_param) + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() # Run validation if validator is available if (