From 81f698a14522a6b4c9a48f1c744fbea79e75eb37 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 22 Aug 2025 18:24:34 -0700 Subject: [PATCH 01/10] add to_hf --- torchtitan/models/deepseek_v3/__init__.py | 2 +- .../deepseek_v3/model/state_dict_adapter.py | 162 +++++++++++++++++- .../train_configs/debug_model.toml | 6 +- .../train_configs/deepseek_v3_671b.toml | 12 +- .../models/llama3/model/state_dict_adapter.py | 5 +- torchtitan/protocols/state_dict_adapter.py | 6 +- torchtitan/train.py | 15 +- 7 files changed, 184 insertions(+), 24 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1c3d2b19d..be17ecf5e 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -135,7 +135,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 0bdf456ef..95526b0d1 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -4,24 +4,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from logging import raiseExceptions import re -from typing import Any +from typing import Any, Dict import torch +from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import DeepSeekV3ModelArgs from .quantization import calculate_scale_shape, dequantize_from_fp8 +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Shard, + Replicate +) + +from torch.distributed.tensor import DTensor + class DeepSeekV3StateDictAdapter(StateDictAdapter): """ StateDictAdapter for DeepSeekV3 model. """ - def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None): + def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): + super().__init__(model_args, hf_assets_path, parallel_dims) self.model_args = model_args + self.parallel_dims = parallel_dims self.from_hf_map = { "model.embed_tokens.weight": "tok_embeddings.weight", # Attention Module @@ -52,7 +64,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None): "lm_head.weight": "output.weight", } - def _split_experts_weights( + def _split_experts_weight( self, weight: torch.Tensor, n_experts: int ) -> list[torch.Tensor]: """ @@ -84,6 +96,134 @@ def _concatenate_expert_weights( return None + def _get_local_experts_weights( + self, abstract_key: str, layer_id: str, grouped_expert_weight: torch.Tensor + ) -> Dict[str, torch.Tensor]: + """ + Spliting the GroupedExperts weight and find the corresponding individual expert's weight in local tensor. 
+ + Potential experts weights shard placements: + - FSDP + EP when dp_mod_ep * ep <= num_experts: + - StridedShard(0)Shard(0) + - FSDP + EP when dp_mod_ep * ep <= num_experts: + - Shard(1)Shard(0) + - FSDP + ETP + EP when dp_mod_ep * ep <= num_experts: + - w1/w3: StridedShard(0)Shard(0)Shard(1) + - w2: StridedShard(0)Shard(0)Shard(2) + - FSDP + ETP + EP when dp_mod_ep * ep > num_experts: + - w1/w3: StridedShard(1)Shard(0)Shard(1) + - w2: Shard(1)Shard(0)Shard(2) + """ + world_mesh = self.parallel_dims.world_mesh + num_experts = grouped_expert_weight.shape[0] + + # Matching DTensor sharding placement and device mesh dims, + # find the dtensor dims that shard on dim-0 (num_experts dim) + original_placements = grouped_expert_weight.placements + world_mesh_names = [] + dim_0_placements = [] + for i, name in enumerate(world_mesh.mesh_dim_names): + placement = original_placements[i] + if placement.dim == 0: + world_mesh_names.append(name) + dim_0_placements.append(placement) + + start_index, end_index = None, None + # StridedShard(0)Shard(0) + if len(dim_0_placements) == 2: + assert isinstance(dim_0_placements[0], _StridedShard) + strided_shard_mesh = world_mesh[world_mesh_names[0]] + strided_degree, strided_rank = strided_shard_mesh.size(), strided_shard_mesh.get_local_rank() + shard_mesh = world_mesh[world_mesh_names[1]] + shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank() + start_index, end_index = self._get_strided_shard_shard_slice(strided_degree, strided_rank, shard_degree, shard_rank, num_experts) + # Shard(0) + elif len(dim_0_placements) == 1: + assert not isinstance(dim_0_placements[0], _StridedShard) + shard_mesh = world_mesh[world_mesh_names[0]] + shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank() + block_size = num_experts // shard_degree + if block_size * shard_degree != num_experts: + raise ValueError("Not supported. num_experts can not be evenly divided by Shard(0) dimension degree.") + + start_index = block_size * shard_rank + end_index = start_index + block_size + else: + raise NotImplementedError(f"The DTensor placements {original_placements} for GroupedExperts is not supported in StateDictAdapter") + + # Calculate the new placement for individual expert weights + new_placements = [] + for i, name in enumerate(world_mesh.mesh_dim_names): + placement = original_placements[i] + if placement.dim == 0: + new_placements.append(Replicate()) + elif isinstance(placement, Shard): + # Individual expert weight has only 2 dimensions + new_placements.append(Shard(placement.dim-1)) + elif isinstance(placement, _StridedShard): + new_placements.append(_StridedShard(placement.dim-1, placement.split_factor)) + else: + raise ValueError("Not supported new placements!") + print(f"Original placements: {original_placements}, new placements {new_placements}") + + assert isinstance(grouped_expert_weight, DTensor), "GroupedExperts weight is not a DTensor" + local_grouped_weights = grouped_expert_weight._local_tensor + assert local_grouped_weights.shape[0] == int(end_index - start_index), "Local tensor shape mismatch!" 
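        # Illustrative sketch (not part of the patch): the slice arithmetic used by
        # _get_strided_shard_shard_slice above, restated standalone with made-up sizes.
        # Shard(0) picks the coarse block and StridedShard(0) the sub-block inside it,
        # so each rank owns a contiguous run of experts.
        #
        #     def _sketch_strided_shard_slice(strided_degree, strided_rank,
        #                                     shard_degree, shard_rank, n):
        #         block = n // (strided_degree * shard_degree)   # experts per rank
        #         start = block * (strided_degree * shard_rank + strided_rank)
        #         return start, start + block
        #
        #     # e.g. 12 experts, StridedShard degree 3, Shard degree 2:
        #     _sketch_strided_shard_slice(3, 0, 2, 0, 12)   # -> (0, 2)   GPU(0, 0)
        #     _sketch_strided_shard_slice(3, 2, 2, 1, 12)   # -> (10, 12) GPU(2, 1)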
+ + # Create new DTensor for each individual expert weights + local_expert_fqn = {} + for expert_id in range(start_index, end_index): + new_key = abstract_key.format(layer_id, expert_id) + new_value = local_grouped_weights[expert_id - start_index, :, :].squeeze + local_expert_fqn[new_key] = DTensor.from_local(new_value, world_mesh, new_placements, run_check=False) + + return local_expert_fqn + + + def _get_strided_shard_shard_slice( + self, + strided_shard_dim_degree: int, + strided_shard_dim_rank: int, + shard_dim_degree: int, + shard_dim_rank: int, + dim_size_to_split: int, + ) -> tuple[int, int]: + """ + Given a [StridedShard(dim=i), Shard(dim=i)] placement, caculate the start index + and end index on dim-i for GPU rank (strided_shard_dim_degree, shard_dim_rank) + + GPU Layout (strided_shard_rank, shard_rank): + + StridedShard Rank Shard rank + ┌─────────────────┐ + 0 │ GPU(0, 0) │ 0 + ────┼─────────────────┤ + 1 │ GPU(1, 0) │ + ────┼─────────────────┤ + 2 │ GPU(2, 0) │ + ──────┼─────────────────┼──── + 0 │ GPU(0, 1) │ 1 + ────┼─────────────────┤ + 1 │ GPU(1, 1) │ + ────┼─────────────────┤ + 2 │ GPU(2, 1) │ + └─────────────────┘ + + Calulate the start_index from inner dimesion (Shard(dim=i)) to outer demension (StridedShard(dim=i)). + """ + + block_size = dim_size_to_split // (strided_shard_dim_degree * shard_dim_degree) + + # Error out if can not evenly divded + if block_size * (strided_shard_dim_degree * shard_dim_degree) != dim_size_to_split: + raise ValueError(f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}") + + start_index = block_size * (strided_shard_dim_degree * shard_dim_rank + strided_shard_dim_rank) + end_index = start_index + block_size + + return start_index, end_index + + def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ Dequantize the weights from float8 to float32. 
@@ -149,14 +289,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] - # Split expert weights into separate expert weights - split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + # # Split expert weights into separate expert weights + # split_values = self._split_experts_weights( + # value, self.model_args.moe_args.num_experts + # ) + local_expert_fqn = self._get_local_experts_weights( + new_abstract_key, layer_num, value ) + print(f"groupedWeight placements {value.placements}, local experts keys {local_expert_fqn.keys()}") - for expert_num in range(0, self.model_args.moe_args.num_experts): - new_key = new_abstract_key.format(layer_num, expert_num) - hf_state_dict[new_key] = split_values[expert_num].squeeze() + hf_state_dict.update(local_expert_fqn) elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) @@ -169,9 +311,11 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value + # Prepare for dequantization hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( hf_state_dict ) + print(f"[to_hf] state_dict keys before return: {hf_state_dict_with_scale_inv.keys()}") return hf_state_dict_with_scale_inv def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index dc9f37f44..35e7e5108 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -47,13 +47,13 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 4 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "1F1B" context_parallel_degree = 1 -expert_parallel_degree = 1 -expert_tensor_parallel_degree = 1 +expert_parallel_degree = 2 +expert_tensor_parallel_degree = 4 [checkpoint] enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index ad238839a..6e01dc477 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -38,7 +38,7 @@ min_lr_factor = 0.1 local_batch_size = 4 seq_len = 4096 max_norm = 1.0 # grad norm clipping -steps = 10_000 +steps = 10 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) @@ -46,20 +46,22 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 1 -expert_tensor_parallel_degree = 1 +expert_parallel_degree = 2 +expert_tensor_parallel_degree = 2 [checkpoint] enable = false folder = "checkpoint" -interval = 500 +interval = 10 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" 
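# (Descriptive note, not part of the patch: the two keys added below appear to point
# the initial checkpoint load at a local HuggingFace-format copy of the model, and
# initial_load_in_hf routes that load through the state-dict adapter's from_hf
# conversion; the path is the author's local copy.)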
+initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" +initial_load_in_hf=true [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index cae0b4c17..14e79bc4e 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -10,14 +10,15 @@ logger = logging.getLogger() +from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs class Llama3StateDictAdapter(StateDictAdapter): - def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None): - super().__init__(model_args, hf_assets_path) + def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): + super().__init__(model_args, hf_assets_path, parallel_dims) self.model_args = model_args self.hf_assets_path = hf_assets_path diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 916368cb9..8684a6cd8 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -11,6 +11,8 @@ from abc import ABC, abstractmethod from typing import Any +from torchtitan.distributed.parallel_dims import ParallelDims + logger = logging.getLogger() from .model import BaseModelArgs @@ -27,7 +29,7 @@ class BaseStateDictAdapter(ABC): """ @abstractmethod - def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): + def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): pass @abstractmethod @@ -58,7 +60,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" - def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): + def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): if hf_assets_path: mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json") try: diff --git a/torchtitan/train.py b/torchtitan/train.py index 9b69fd679..9e6d63334 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -23,7 +23,7 @@ ensure_pp_loss_visible, ) from torchtitan.config import ConfigManager, JobConfig -from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed import ParallelDims, parallel_dims, utils as dist_utils from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils @@ -311,7 +311,7 @@ def __init__(self, job_config: JobConfig): checkpoint_config=job_config.checkpoint, sd_adapter=( self.train_spec.state_dict_adapter( - model_args, job_config.model.hf_assets_path + model_args, job_config.model.hf_assets_path, self.parallel_dims ) if self.train_spec.state_dict_adapter else None @@ -539,6 +539,17 @@ def train_step( def train(self): job_config = self.job_config + # Following hacky print only works for debug_model + # w1 = self.model_parts[0].layers["1"].moe.experts.w1 + # w2 = self.model_parts[0].layers["1"].moe.experts.w2 + # w3 = self.model_parts[0].layers["1"].moe.experts.w3 + + # logger.info(f"w1 placements is: {w1.placements}, 
{type(w1.placements)}") + # logger.info(f"w2 placements is: {w2.placements}") + # logger.info(f"w3 placements is: {w3.placements}") + # logger.info(f"device mesh: {self.parallel_dims.world_mesh}, {self.parallel_dims.world_mesh.mesh_dim_names} {self.parallel_dims.world_mesh['dp_shard']}") + + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}") From e7c39f60d05aaa10983042c30ffdcbc270fae49a Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 25 Aug 2025 16:47:21 -0700 Subject: [PATCH 02/10] debugging --- torchtitan/components/checkpoint.py | 10 + .../models/deepseek_v3/infra/parallelize.py | 1 + .../deepseek_v3/model/state_dict_adapter.py | 476 ++++++++++++------ .../models/llama3/model/state_dict_adapter.py | 7 +- torchtitan/protocols/state_dict_adapter.py | 14 +- torchtitan/train.py | 10 +- 6 files changed, 352 insertions(+), 166 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index fcec60185..70e20913d 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -418,6 +418,16 @@ def dcp_load( ) state_dict = self.sd_adapter.from_hf(hf_state_dict) + + # [rank0]:after sd converter, placement is DeviceMesh((dp_shard_mod_ep=2, dp_shard_in_ep=2, tp=2), device: 'cuda', stride: (4, 2, 1)) + print( + f"after sd converter, placement is {state_dict['layers.3.moe.experts.w3'].device_mesh}, type {type(state_dict['layers.3.moe.experts.w3'])}, placement {state_dict['layers.3.moe.experts.w3'].placements}" + ) + + # [rank0]:after sd converter, model placement is DeviceMesh((dp_shard_mod_ep=2, ep=2, tp=2), device: 'cuda', stride: (4, 2, 1)) + # model_state_dict = self.states[MODEL].state_dict() + # print(f"after sd converter, model placement is {model_state_dict['layers.3.moe.experts.w3'].device_mesh}") + self.states[MODEL].load_state_dict(state_dict) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c77250d0f..b2ef2790c 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -36,6 +36,7 @@ def parallelize_deepseekv3( job_config: JobConfig, ): world_mesh = parallel_dims.world_mesh + print(f"In parallelize_deepseekv3, world mesh is {world_mesh}") # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 95526b0d1..8c6ca060a 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -4,36 +4,37 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from logging import raiseExceptions + import re from typing import Any, Dict import torch +from torch.distributed.device_mesh import DeviceMesh + +from torch.distributed.tensor import DTensor + +from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.tools.logging import logger from .args import DeepSeekV3ModelArgs from .quantization import calculate_scale_shape, dequantize_from_fp8 -from torch.distributed.tensor.placement_types import ( - _StridedShard, - Shard, - Replicate -) - -from torch.distributed.tensor import DTensor - class DeepSeekV3StateDictAdapter(StateDictAdapter): """ StateDictAdapter for DeepSeekV3 model. """ - def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): - super().__init__(model_args, hf_assets_path, parallel_dims) + def __init__( + self, + model_args: DeepSeekV3ModelArgs, + hf_assets_path: str | None, + ): + super().__init__(model_args, hf_assets_path) self.model_args = model_args - self.parallel_dims = parallel_dims self.from_hf_map = { "model.embed_tokens.weight": "tok_embeddings.weight", # Attention Module @@ -64,165 +65,311 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None, "lm_head.weight": "output.weight", } - def _split_experts_weight( - self, weight: torch.Tensor, n_experts: int - ) -> list[torch.Tensor]: - """ - Split the weights of the experts into a list of tensors. - """ - split_weight = torch.split(weight, weight.shape[0] // n_experts, dim=0) - return split_weight + # Store metadata for GroupedExperts <-> individual experts conversion + self.grouped_expert_weight_placements = {} # {titan_abstract_key: placements} + self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} + self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} - def _concatenate_expert_weights( - self, expert_weights_by_layer: dict[str, Any], n_experts: int - ) -> torch.Tensor: + def _calculate_strided_shard_shard_indices( + self, + strided_shard_dim_degree: int, + strided_shard_dim_rank: int, + shard_dim_degree: int, + shard_dim_rank: int, + dim_size_to_split: int, + ) -> tuple[int, int]: """ - Concatenate the weights of separate experts into GroupedExpert weights. + Given a [StridedShard(dim=i), Shard(dim=i)] placement, caculate the start index + and end index on dim-i for GPU rank (strided_shard_dim_degree, shard_dim_rank) + + GPU Layout (strided_shard_rank, shard_rank): + + StridedShard Rank Shard rank + ┌─────────────────┐ + 0 │ GPU(0, 0) │ 0 + ────┼─────────────────┤ + 1 │ GPU(1, 0) │ + ────┼─────────────────┤ + 2 │ GPU(2, 0) │ + ──────┼─────────────────┼──── + 0 │ GPU(0, 1) │ 1 + ────┼─────────────────┤ + 1 │ GPU(1, 1) │ + ────┼─────────────────┤ + 2 │ GPU(2, 1) │ + └─────────────────┘ + + Calulate the start_index from inner dimesion (Shard(dim=i)) to outer demension (StridedShard(dim=i)). 
""" - for layer, abstract_keys in list(expert_weights_by_layer.items()): - for abstract_key, experts in list(abstract_keys.items()): - # If we have all the experts for this abstract_key, concatenate them - if len(experts) == n_experts: - sorted_expert_ids = sorted(experts.keys()) - sorted_experts = [experts[i] for i in sorted_expert_ids] - stacked_tensor = torch.stack(sorted_experts, dim=0) - # Remove these experts from the tracking dict to free memory - del expert_weights_by_layer[layer][abstract_key] - if not expert_weights_by_layer[layer]: - del expert_weights_by_layer[layer] + block_size = dim_size_to_split // (strided_shard_dim_degree * shard_dim_degree) - return stacked_tensor + # Error out if can not evenly divded + if ( + block_size * (strided_shard_dim_degree * shard_dim_degree) + != dim_size_to_split + ): + raise ValueError( + f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}" + ) + + start_index = block_size * ( + strided_shard_dim_degree * shard_dim_rank + strided_shard_dim_rank + ) + end_index = start_index + block_size - return None + return start_index, end_index + + def _caculate_indices_from_placements( + self, + dim: int, + dim_size: int, + dtensor_placements: tuple, + device_mesh: DeviceMesh, + ): + + mesh_names = [] + dim_i_placements = [] + + # Find all the device mesh dimensios that shard on dim-i + for i, name in enumerate(device_mesh.mesh_dim_names): + placement = dtensor_placements[i] + print( + f"In _caculate_indices_from_placements, placement dim = {placement.dim} {type(placement.dim)}, {dim} {type(dim)}" + ) + if placement.dim == dim: + mesh_names.append(name) + dim_i_placements.append(placement) + + # Calculate local expert indices based on sharding strategy + start_index, end_index = None, None + if len(dim_i_placements) == 2: + # Handle StridedShard(i) + Shard(i) case + assert isinstance( + dim_i_placements[0], _StridedShard + ), "Expected StridedShard as first placement" + + strided_shard_mesh = device_mesh[mesh_names[0]] + shard_mesh = device_mesh[mesh_names[1]] + + strided_degree = strided_shard_mesh.size() + strided_rank = strided_shard_mesh.get_local_rank() + shard_degree = shard_mesh.size() + shard_rank = shard_mesh.get_local_rank() + + start_index, end_index = self._calculate_strided_shard_shard_indices( + strided_degree, strided_rank, shard_degree, shard_rank, dim_size + ) + + return start_index, end_index + + elif len(dim_i_placements) == 1: + # Handle single Shard(i) case + assert not isinstance( + dim_i_placements[0], _StridedShard + ), "Expected regular Shard, not StridedShard" + + shard_mesh = device_mesh[mesh_names[0]] + shard_degree = shard_mesh.size() + shard_rank = shard_mesh.get_local_rank() + + block_size = dim_size // shard_degree + if block_size * shard_degree != dim_size: + raise ValueError( + f"Dim {dim} size ({dim_size}) cannot be evenly divided by shard degree ({shard_degree})" + ) + + start_index = block_size * shard_rank + end_index = start_index + block_size + + return start_index, end_index + + elif len(dim_i_placements) == 0: + # No need to split on this dimension + return start_index, end_index + + else: + raise NotImplementedError( + f"Unsupported DTensor placements for GroupedExperts: {dtensor_placements} {dim_i_placements} {mesh_names}" + ) def _get_local_experts_weights( - self, abstract_key: str, layer_id: str, grouped_expert_weight: torch.Tensor + self, + abstract_key: str, + titan_abstract_key: str, + layer_id: str, + 
grouped_expert_weight: torch.Tensor, ) -> Dict[str, torch.Tensor]: """ - Spliting the GroupedExperts weight and find the corresponding individual expert's weight in local tensor. - - Potential experts weights shard placements: - - FSDP + EP when dp_mod_ep * ep <= num_experts: - - StridedShard(0)Shard(0) - - FSDP + EP when dp_mod_ep * ep <= num_experts: - - Shard(1)Shard(0) - - FSDP + ETP + EP when dp_mod_ep * ep <= num_experts: - - w1/w3: StridedShard(0)Shard(0)Shard(1) - - w2: StridedShard(0)Shard(0)Shard(2) - - FSDP + ETP + EP when dp_mod_ep * ep > num_experts: - - w1/w3: StridedShard(1)Shard(0)Shard(1) - - w2: Shard(1)Shard(0)Shard(2) + Split GroupedExperts weight into individual expert weights for local processing. + + This method handles various sharding strategies for expert weights: + - FSDP + EP: StridedShard(0)Shard(0) or Shard(0) + - FSDP + ETP + EP: StridedShard(0)Shard(0)Shard(1/2) or StridedShard(1)Shard(0)Shard(1/2) + + Args: + abstract_key: HuggingFace templage key with {} placeholders for layer and expert IDs + titan_abstract_key: TorchTitan templage key with {} placeholders for layer and expert IDs + layer_id: Layer identifier + grouped_expert_weight: DTensor containing all experts' weights + + Returns: + Dictionary mapping individual expert keys to their DTensor weights """ - world_mesh = self.parallel_dims.world_mesh + device_mesh = grouped_expert_weight.device_mesh + dtensor_placements = grouped_expert_weight.placements + + # Step 1: Extract dimension-0 placement information num_experts = grouped_expert_weight.shape[0] + start_index, end_index = self._caculate_indices_from_placements( + dim=0, + dim_size=num_experts, + dtensor_placements=dtensor_placements, + device_mesh=device_mesh, + ) + assert ( + start_index is not None and end_index is not None + ), "Start index and end index can not be None on dim-0!" - # Matching DTensor sharding placement and device mesh dims, - # find the dtensor dims that shard on dim-0 (num_experts dim) - original_placements = grouped_expert_weight.placements - world_mesh_names = [] - dim_0_placements = [] - for i, name in enumerate(world_mesh.mesh_dim_names): - placement = original_placements[i] - if placement.dim == 0: - world_mesh_names.append(name) - dim_0_placements.append(placement) - - start_index, end_index = None, None - # StridedShard(0)Shard(0) - if len(dim_0_placements) == 2: - assert isinstance(dim_0_placements[0], _StridedShard) - strided_shard_mesh = world_mesh[world_mesh_names[0]] - strided_degree, strided_rank = strided_shard_mesh.size(), strided_shard_mesh.get_local_rank() - shard_mesh = world_mesh[world_mesh_names[1]] - shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank() - start_index, end_index = self._get_strided_shard_shard_slice(strided_degree, strided_rank, shard_degree, shard_rank, num_experts) - # Shard(0) - elif len(dim_0_placements) == 1: - assert not isinstance(dim_0_placements[0], _StridedShard) - shard_mesh = world_mesh[world_mesh_names[0]] - shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank() - block_size = num_experts // shard_degree - if block_size * shard_degree != num_experts: - raise ValueError("Not supported. 
num_experts can not be evenly divided by Shard(0) dimension degree.") - - start_index = block_size * shard_rank - end_index = start_index + block_size - else: - raise NotImplementedError(f"The DTensor placements {original_placements} for GroupedExperts is not supported in StateDictAdapter") + # Step 2: Store indices for potential future use in from_hf() + self.local_experts_indices[titan_abstract_key] = (start_index, end_index) - # Calculate the new placement for individual expert weights + # Step 3: Create new placements for individual expert weights new_placements = [] - for i, name in enumerate(world_mesh.mesh_dim_names): - placement = original_placements[i] + for i, name in enumerate(device_mesh.mesh_dim_names): + placement = dtensor_placements[i] if placement.dim == 0: + # Convert dim-0 sharding to replication for individual experts new_placements.append(Replicate()) elif isinstance(placement, Shard): - # Individual expert weight has only 2 dimensions - new_placements.append(Shard(placement.dim-1)) + # Keep other shard dimensions (individual expert weight has 2D) + new_placements.append(Shard(placement.dim)) elif isinstance(placement, _StridedShard): - new_placements.append(_StridedShard(placement.dim-1, placement.split_factor)) + # Keep strided shard with same parameters + new_placements.append( + _StridedShard(placement.dim, placement.split_factor) + ) else: - raise ValueError("Not supported new placements!") - print(f"Original placements: {original_placements}, new placements {new_placements}") - - assert isinstance(grouped_expert_weight, DTensor), "GroupedExperts weight is not a DTensor" + raise ValueError(f"Unsupported placement type: {type(placement)}") + + # Step 4: Create individual expert DTensors + assert isinstance( + grouped_expert_weight, DTensor + ), "Expected DTensor for grouped expert weight" + local_grouped_weights = grouped_expert_weight._local_tensor - assert local_grouped_weights.shape[0] == int(end_index - start_index), "Local tensor shape mismatch!" 
+ expected_local_experts = end_index - start_index - # Create new DTensor for each individual expert weights - local_expert_fqn = {} + if local_grouped_weights.shape[0] != expected_local_experts: + raise ValueError( + f"Local tensor shape mismatch: expected {expected_local_experts} experts, " + f"got {local_grouped_weights.shape[0]}" + ) + + local_expert_tensors = {} for expert_id in range(start_index, end_index): - new_key = abstract_key.format(layer_id, expert_id) - new_value = local_grouped_weights[expert_id - start_index, :, :].squeeze - local_expert_fqn[new_key] = DTensor.from_local(new_value, world_mesh, new_placements, run_check=False) - - return local_expert_fqn - - - def _get_strided_shard_shard_slice( + expert_key = abstract_key.format(layer_id, expert_id) + local_expert_index = expert_id - start_index + + # Extract individual expert weight and add batch dimension temporarily + expert_weight = local_grouped_weights[local_expert_index, :, :].unsqueeze(0) + + # Create DTensor and remove batch dimension (experts dimension is removed) + expert_dtensor = DTensor.from_local( + expert_weight, device_mesh, new_placements, run_check=False + ).squeeze(0) + + local_expert_tensors[expert_key] = expert_dtensor + + return local_expert_tensors + + def _chunk_local_expert_weights( self, - strided_shard_dim_degree: int, - strided_shard_dim_rank: int, - shard_dim_degree: int, - shard_dim_rank: int, - dim_size_to_split: int, - ) -> tuple[int, int]: + local_tensor: torch.Tensor, + dtensor_placements: tuple, + dtensor_shape: tuple, + device_mesh: DeviceMesh, + ): """ - Given a [StridedShard(dim=i), Shard(dim=i)] placement, caculate the start index - and end index on dim-i for GPU rank (strided_shard_dim_degree, shard_dim_rank) - - GPU Layout (strided_shard_rank, shard_rank): + Chunk the local individual experts weight, assemble back to GroupedExperts weights DTensor. - StridedShard Rank Shard rank - ┌─────────────────┐ - 0 │ GPU(0, 0) │ 0 - ────┼─────────────────┤ - 1 │ GPU(1, 0) │ - ────┼─────────────────┤ - 2 │ GPU(2, 0) │ - ──────┼─────────────────┼──── - 0 │ GPU(0, 1) │ 1 - ────┼─────────────────┤ - 1 │ GPU(1, 1) │ - ────┼─────────────────┤ - 2 │ GPU(2, 1) │ - └─────────────────┘ + This method is a placeholder for future implementation of expert weight concatenation. - Calulate the start_index from inner dimesion (Shard(dim=i)) to outer demension (StridedShard(dim=i)). 
+ Args: + local_tensor: Concatenated local individual expert weights """ - block_size = dim_size_to_split // (strided_shard_dim_degree * shard_dim_degree) - - # Error out if can not evenly divded - if block_size * (strided_shard_dim_degree * shard_dim_degree) != dim_size_to_split: - raise ValueError(f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}") + # Calculate the index range on dim-i to chunk + for i in range(1, len(dtensor_placements)): + dim_size = dtensor_shape[i] + start_index, end_index = self._caculate_indices_from_placements( + dim=i, + dim_size=dim_size, + dtensor_placements=dtensor_placements, + device_mesh=device_mesh, + ) + # No need to chunk on current dimension + if start_index is None or end_index is None: + continue + + # Chunk local_tensor on dim-i + local_tensor = local_tensor.narrow(i, start_index, end_index - start_index) + + # Assemble DTensor + grouped_expert_weights = DTensor.from_local( + local_tensor, device_mesh, dtensor_placements, run_check=False + ) - start_index = block_size * (strided_shard_dim_degree * shard_dim_rank + strided_shard_dim_rank) - end_index = start_index + block_size + return grouped_expert_weights - return start_index, end_index + def _concatenate_local_expert_weights( + self, + expert_weights_by_layer: dict[str, Any], + abstract_key: str, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + """ + Concatenate the weights of separate experts into GroupedExperts weights. + """ + logger.info(f"Concatenating for key {abstract_key} ") + for layer in expert_weights_by_layer.keys(): + # If we have all the experts for this abstract_key, concatenate them + experts = expert_weights_by_layer[layer][abstract_key] + expected_n_experts = ( + self.local_experts_indices[abstract_key][1] + - self.local_experts_indices[abstract_key][0] + ) + if len(experts) == expected_n_experts: + sorted_expert_ids = sorted(experts.keys()) + sorted_experts = [experts[i] for i in sorted_expert_ids] + local_tensor = torch.stack(sorted_experts, dim=0) + + assert ( + abstract_key in self.grouped_expert_weight_placements + and abstract_key in self.grouped_expert_weight_shape + ), f"GroupedExperts weight metadata {self.grouped_expert_weight_placements} {self.grouped_expert_weight_shape} can not be None!" 
+ + stacked_dtensor = self._chunk_local_expert_weights( + local_tensor, + dtensor_placements=self.grouped_expert_weight_placements[ + abstract_key + ], + dtensor_shape=self.grouped_expert_weight_shape[abstract_key], + device_mesh=device_mesh, + ) + # Remove these experts from the tracking dict to free memory + del expert_weights_by_layer[layer][abstract_key] + if not expert_weights_by_layer[layer]: + del expert_weights_by_layer[layer] + + logger.info(f"Concatenated for key {abstract_key} at layer {layer}") + + return stacked_dtensor + else: + logger.info("no enough experts to concate") + + return None def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -233,11 +380,13 @@ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, weight in state_dict.items(): if key.endswith(".weight") and key + "_scale_inv" in state_dict: scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight + # dequantized_weight = dequantize_from_fp8( + # weight, scale_inv, dtype=torch.float32 + # ) + # # update the weight and remove the scale_inv tensor + # state_dict[key] = dequantized_weight + + state_dict[key] = weight scale_inv_keys.append(key + "_scale_inv") for key in scale_inv_keys: @@ -289,14 +438,17 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] - # # Split expert weights into separate expert weights - # split_values = self._split_experts_weights( - # value, self.model_args.moe_args.num_experts - # ) + # Store the GroupedExperts Weight metadata for from_hf() + self.grouped_expert_weight_placements[abstract_key] = value.placements + self.grouped_expert_weight_shape[abstract_key] = value.shape + + # Split GroupedExperts weight to local individual expert weights local_expert_fqn = self._get_local_experts_weights( - new_abstract_key, layer_num, value + new_abstract_key, + abstract_key, + layer_num, + value, ) - print(f"groupedWeight placements {value.placements}, local experts keys {local_expert_fqn.keys()}") hf_state_dict.update(local_expert_fqn) @@ -315,7 +467,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( hf_state_dict ) - print(f"[to_hf] state_dict keys before return: {hf_state_dict_with_scale_inv.keys()}") return hf_state_dict_with_scale_inv def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: @@ -324,7 +475,11 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 2. Convert between the HF shape and the torchtitan shape. 3. Concate separate expert's wegiht into GroupedExperts' weight. 
""" + print( + f"At the beginning of from_hf, the loaded state_dict is {hf_state_dict.keys()}" + ) # dequantize the tensor in state_dict and remove the scale_inv tensor + hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} @@ -334,19 +489,24 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: if "mlp.experts" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=2) layer_num, expert_num = re.findall(r"\d+", key) - new_key = self.from_hf_map[abstract_key] - new_key = new_key.format(layer_num) + titan_abstract_key = self.from_hf_map[abstract_key] + new_key = titan_abstract_key.format(layer_num) # Store the expert's weight in expert_weights_by_layer for concatenating later. if layer_num not in expert_weights_by_layer: expert_weights_by_layer[layer_num] = {} - if abstract_key not in expert_weights_by_layer[layer_num]: - expert_weights_by_layer[layer_num][abstract_key] = {} - expert_weights_by_layer[layer_num][abstract_key][expert_num] = value + if titan_abstract_key not in expert_weights_by_layer[layer_num]: + expert_weights_by_layer[layer_num][titan_abstract_key] = {} + expert_weights_by_layer[layer_num][titan_abstract_key][ + expert_num + ] = value # try to concat the expert's weight into GroupedExperts' weight. - stacked_value = self._concatenate_expert_weights( - expert_weights_by_layer, self.model_args.moe_args.num_experts + # stacked_value = self._concatenate_expert_weights( + # expert_weights_by_layer, self.model_args.moe_args.num_experts + # ) + stacked_value = self._concatenate_local_expert_weights( + expert_weights_by_layer, titan_abstract_key, value.device_mesh ) if stacked_value is not None: state_dict[new_key] = stacked_value diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 14e79bc4e..b59b95de9 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -17,7 +17,12 @@ class Llama3StateDictAdapter(StateDictAdapter): - def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): + def __init__( + self, + model_args: TransformerModelArgs, + hf_assets_path: str | None, + parallel_dims: ParallelDims, + ): super().__init__(model_args, hf_assets_path, parallel_dims) self.model_args = model_args diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 8684a6cd8..4221861b1 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -11,8 +11,6 @@ from abc import ABC, abstractmethod from typing import Any -from torchtitan.distributed.parallel_dims import ParallelDims - logger = logging.getLogger() from .model import BaseModelArgs @@ -29,7 +27,11 @@ class BaseStateDictAdapter(ABC): """ @abstractmethod - def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): + def __init__( + self, + model_args: BaseModelArgs, + hf_assets_path: str | None, + ): pass @abstractmethod @@ -60,7 +62,11 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" - def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims): + def __init__( + self, + model_args: BaseModelArgs, + hf_assets_path: str | None, + ): if hf_assets_path: 
mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json") try: diff --git a/torchtitan/train.py b/torchtitan/train.py index 9e6d63334..ff6f55c41 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -23,7 +23,7 @@ ensure_pp_loss_visible, ) from torchtitan.config import ConfigManager, JobConfig -from torchtitan.distributed import ParallelDims, parallel_dims, utils as dist_utils +from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils @@ -107,6 +107,7 @@ def __init__(self, job_config: JobConfig): ) world_mesh = parallel_dims.world_mesh + print(f"Worldmesh in trainer init : {world_mesh}") if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() @@ -257,6 +258,9 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + print( + f"the world mesh before applying parallelize_fn {parallel_dims.world_mesh}" + ) model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) @@ -311,7 +315,8 @@ def __init__(self, job_config: JobConfig): checkpoint_config=job_config.checkpoint, sd_adapter=( self.train_spec.state_dict_adapter( - model_args, job_config.model.hf_assets_path, self.parallel_dims + model_args, + job_config.model.hf_assets_path, ) if self.train_spec.state_dict_adapter else None @@ -549,7 +554,6 @@ def train(self): # logger.info(f"w3 placements is: {w3.placements}") # logger.info(f"device mesh: {self.parallel_dims.world_mesh}, {self.parallel_dims.world_mesh.mesh_dim_names} {self.parallel_dims.world_mesh['dp_shard']}") - self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}") From 4bd34fbb7c1aab32eddfb21dc9d7394f93b0b05f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 25 Aug 2025 16:52:25 -0700 Subject: [PATCH 03/10] debugging --- .../deepseek_v3/model/state_dict_adapter.py | 21 ++++--------------- .../models/llama3/model/state_dict_adapter.py | 3 +-- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 8c6ca060a..982e29ad1 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -15,9 +15,7 @@ from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard -from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.protocols.state_dict_adapter import StateDictAdapter -from torchtitan.tools.logging import logger from .args import DeepSeekV3ModelArgs from .quantization import calculate_scale_shape, dequantize_from_fp8 @@ -134,9 +132,6 @@ def _caculate_indices_from_placements( # Find all the device mesh dimensios that shard on dim-i for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] - print( - f"In _caculate_indices_from_placements, placement dim = {placement.dim} {type(placement.dim)}, {dim} {type(dim)}" - ) if placement.dim == dim: mesh_names.append(name) dim_i_placements.append(placement) @@ -161,8 +156,6 @@ def _caculate_indices_from_placements( strided_degree, strided_rank, shard_degree, 
shard_rank, dim_size ) - return start_index, end_index - elif len(dim_i_placements) == 1: # Handle single Shard(i) case assert not isinstance( @@ -182,8 +175,6 @@ def _caculate_indices_from_placements( start_index = block_size * shard_rank end_index = start_index + block_size - return start_index, end_index - elif len(dim_i_placements) == 0: # No need to split on this dimension return start_index, end_index @@ -193,6 +184,9 @@ def _caculate_indices_from_placements( f"Unsupported DTensor placements for GroupedExperts: {dtensor_placements} {dim_i_placements} {mesh_names}" ) + return start_index, end_index + + def _get_local_experts_weights( self, abstract_key: str, @@ -331,7 +325,6 @@ def _concatenate_local_expert_weights( """ Concatenate the weights of separate experts into GroupedExperts weights. """ - logger.info(f"Concatenating for key {abstract_key} ") for layer in expert_weights_by_layer.keys(): # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer][abstract_key] @@ -363,11 +356,7 @@ def _concatenate_local_expert_weights( if not expert_weights_by_layer[layer]: del expert_weights_by_layer[layer] - logger.info(f"Concatenated for key {abstract_key} at layer {layer}") - return stacked_dtensor - else: - logger.info("no enough experts to concate") return None @@ -475,9 +464,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 2. Convert between the HF shape and the torchtitan shape. 3. Concate separate expert's wegiht into GroupedExperts' weight. """ - print( - f"At the beginning of from_hf, the loaded state_dict is {hf_state_dict.keys()}" - ) + # dequantize the tensor in state_dict and remove the scale_inv tensor hf_state_dict = self._dequantize(hf_state_dict) diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index b59b95de9..8e631a8af 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -21,9 +21,8 @@ def __init__( self, model_args: TransformerModelArgs, hf_assets_path: str | None, - parallel_dims: ParallelDims, ): - super().__init__(model_args, hf_assets_path, parallel_dims) + super().__init__(model_args, hf_assets_path) self.model_args = model_args self.hf_assets_path = hf_assets_path From 943c0a379d185ed16447029988889a9cdcd53901 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 26 Aug 2025 13:39:43 -0700 Subject: [PATCH 04/10] fix loading error --- torchtitan/components/checkpoint.py | 11 +--- .../deepseek_v3/model/state_dict_adapter.py | 57 +++---------------- 2 files changed, 8 insertions(+), 60 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 70e20913d..ca24420c1 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -418,16 +418,7 @@ def dcp_load( ) state_dict = self.sd_adapter.from_hf(hf_state_dict) - - # [rank0]:after sd converter, placement is DeviceMesh((dp_shard_mod_ep=2, dp_shard_in_ep=2, tp=2), device: 'cuda', stride: (4, 2, 1)) - print( - f"after sd converter, placement is {state_dict['layers.3.moe.experts.w3'].device_mesh}, type {type(state_dict['layers.3.moe.experts.w3'])}, placement {state_dict['layers.3.moe.experts.w3'].placements}" - ) - - # [rank0]:after sd converter, model placement is DeviceMesh((dp_shard_mod_ep=2, ep=2, tp=2), device: 'cuda', stride: (4, 2, 1)) - # model_state_dict = self.states[MODEL].state_dict() - # print(f"after sd converter, model 
placement is {model_state_dict['layers.3.moe.experts.w3'].device_mesh}") - + self.states[MODEL].load_state_dict(state_dict) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 982e29ad1..666583f27 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -6,6 +6,7 @@ import re +from threading import local from typing import Any, Dict import torch @@ -276,46 +277,7 @@ def _get_local_experts_weights( local_expert_tensors[expert_key] = expert_dtensor return local_expert_tensors - - def _chunk_local_expert_weights( - self, - local_tensor: torch.Tensor, - dtensor_placements: tuple, - dtensor_shape: tuple, - device_mesh: DeviceMesh, - ): - """ - Chunk the local individual experts weight, assemble back to GroupedExperts weights DTensor. - - This method is a placeholder for future implementation of expert weight concatenation. - - Args: - local_tensor: Concatenated local individual expert weights - """ - - # Calculate the index range on dim-i to chunk - for i in range(1, len(dtensor_placements)): - dim_size = dtensor_shape[i] - start_index, end_index = self._caculate_indices_from_placements( - dim=i, - dim_size=dim_size, - dtensor_placements=dtensor_placements, - device_mesh=device_mesh, - ) - # No need to chunk on current dimension - if start_index is None or end_index is None: - continue - - # Chunk local_tensor on dim-i - local_tensor = local_tensor.narrow(i, start_index, end_index - start_index) - - # Assemble DTensor - grouped_expert_weights = DTensor.from_local( - local_tensor, device_mesh, dtensor_placements, run_check=False - ) - - return grouped_expert_weights - + def _concatenate_local_expert_weights( self, expert_weights_by_layer: dict[str, Any], @@ -323,7 +285,7 @@ def _concatenate_local_expert_weights( device_mesh: DeviceMesh, ) -> torch.Tensor: """ - Concatenate the weights of separate experts into GroupedExperts weights. + Try to concatenate the weights of separate experts into GroupedExperts weights. """ for layer in expert_weights_by_layer.keys(): # If we have all the experts for this abstract_key, concatenate them @@ -335,20 +297,15 @@ def _concatenate_local_expert_weights( if len(experts) == expected_n_experts: sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] - local_tensor = torch.stack(sorted_experts, dim=0) - + local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor + assert ( abstract_key in self.grouped_expert_weight_placements and abstract_key in self.grouped_expert_weight_shape ), f"GroupedExperts weight metadata {self.grouped_expert_weight_placements} {self.grouped_expert_weight_shape} can not be None!" 
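# Illustrative sketch (not part of the patch): the change below drops the
# per-dimension re-narrowing and instead wraps the locally stacked experts
# directly with the placements recorded in to_hf(). A minimal standalone
# version of that idea, assuming an already-initialized device mesh (all
# names here are made up):
#
#     import torch
#     from torch.distributed.device_mesh import DeviceMesh
#     from torch.distributed.tensor import DTensor
#
#     def _sketch_regroup_local_experts(
#         local_experts: list[torch.Tensor],  # this rank's experts, in expert-id order
#         mesh: DeviceMesh,
#         grouped_placements: tuple,  # placements saved when the grouped weight was split
#     ) -> DTensor:
#         # Stacking restores this rank's shard of the (num_experts, ...) grouped
#         # weight; each entry already holds only the local slice of the
#         # non-expert dimensions.
#         local = torch.stack(local_experts, dim=0)
#         # from_local() attaches sharding metadata without moving data between
#         # ranks, so the placements must match how the weight was originally split.
#         return DTensor.from_local(local, mesh, grouped_placements, run_check=False)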
- stacked_dtensor = self._chunk_local_expert_weights( - local_tensor, - dtensor_placements=self.grouped_expert_weight_placements[ - abstract_key - ], - dtensor_shape=self.grouped_expert_weight_shape[abstract_key], - device_mesh=device_mesh, + stacked_dtensor = DTensor.from_local( + local_tensor, device_mesh, self.grouped_expert_weight_placements[abstract_key], run_check=False ) # Remove these experts from the tracking dict to free memory From fb3b3fcb14e7e9016761e37bce1dffb97bf5ab9d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 26 Aug 2025 21:08:35 -0700 Subject: [PATCH 05/10] fix assemble algo --- .../deepseek_v3/model/state_dict_adapter.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 666583f27..9e2e6f952 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -326,13 +326,13 @@ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, weight in state_dict.items(): if key.endswith(".weight") and key + "_scale_inv" in state_dict: scale_inv = state_dict[key + "_scale_inv"] - # dequantized_weight = dequantize_from_fp8( - # weight, scale_inv, dtype=torch.float32 - # ) - # # update the weight and remove the scale_inv tensor - # state_dict[key] = dequantized_weight + dequantized_weight = dequantize_from_fp8( + weight, scale_inv, dtype=torch.float32 + ) + # update the weight and remove the scale_inv tensor + state_dict[key] = dequantized_weight - state_dict[key] = weight + # state_dict[key] = weight scale_inv_keys.append(key + "_scale_inv") for key in scale_inv_keys: @@ -452,7 +452,15 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: stacked_value = self._concatenate_local_expert_weights( expert_weights_by_layer, titan_abstract_key, value.device_mesh ) + if stacked_value is not None: + local_tensor = stacked_value._local_tensor + + tensor_list = local_tensor.tolist() + # Save to JSON file + import json + with open(f'my_implementation_tensor_{new_key}.json', 'w') as f: + json.dump(tensor_list, f) state_dict[new_key] = stacked_value elif "layers" in key: From b4d614d452047306dedc51777552f3b3ec9f06b6 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 26 Aug 2025 21:09:14 -0700 Subject: [PATCH 06/10] test --- .../models/deepseek_v3/hf_implementation.py | 177 ++++++++++++++++++ torchtitan/models/deepseek_v3/model/model.py | 12 +- .../deepseek_v3/model/state_dict_adapter.py | 18 +- .../train_configs/deepseek_v3_671b.toml | 4 +- torchtitan/models/moe.py | 28 ++- torchtitan/train.py | 59 +++++- 6 files changed, 282 insertions(+), 16 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/hf_implementation.py diff --git a/torchtitan/models/deepseek_v3/hf_implementation.py b/torchtitan/models/deepseek_v3/hf_implementation.py new file mode 100644 index 000000000..c34f379db --- /dev/null +++ b/torchtitan/models/deepseek_v3/hf_implementation.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Hugging Face implementation for DeepSeek-V3 model inference. 
+""" + +import argparse +import gc +import os +import time + +import torch + + +def print_gpu_memory_usage(message=""): + """Print current GPU memory usage.""" + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / (1024**3) + reserved = torch.cuda.memory_reserved() / (1024**3) + print( + f"GPU Memory ({message}): Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB" + ) + + +def run_huggingface_implementation(args, _): + """Run the DeepSeek-V3 model using Hugging Face Transformers.""" + # Disable Hugging Face cache + from transformers import AutoConfig, AutoModelForCausalLM + + # We're not using the tokenizer anymore, using fake inputs instead + # Use local path for model weights if specified, otherwise use model_name + model_path = args.model_path + print(f"Loading model from local path: {model_path}") + start_time = time.time() + + quantization_config = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", # Updated from fp8 to fbgemm_fp8 + "weight_block_size": [128, 128], + } + print(f"Using quantization config: {quantization_config}") + + # ============= Change config to only use a few layers ============= + config = None + if args.num_layers > 0: + # Try to load config from local path first, fall back to model_name if needed + try: + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + except Exception as e: + print(f"Could not load config from local path: {e}") + print(f"Falling back to loading config from {args.model_name}") + config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) + + config.n_group = 1 # make n_groups = a huge group + config.topk_group = 1 # make topk_group = a huge group + # tailer the first several layers + config.num_hidden_layers = args.num_layers + # Explicitly set rope_interleaved to True to use the interleaved rope implementation + config.rope_interleaved = True + print(f"Modified config to use only {args.num_layers} layers") + print(f"Config of Deepseek: {config}") + + # Load the model from local path + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="cuda", # Try with specific device first + config=config, + trust_remote_code=True, + # Disable features that can cause issues with device mapping + attn_implementation="eager", # Use standard attention instead of flash attention + quantization_config=quantization_config, + local_files_only=True, # Only use local files, don't fetch from cache + use_auth_token=False, # Don't try to authenticate with HF + ) + + print(f"Model loaded in {time.time() - start_time:.2f} seconds") + print_gpu_memory_usage("After loading model") + + # Get the device where the model is loaded + device = next(model.parameters()).device + print(f"Model is on device: {device}") + + # Create fake input directly on the correct device + print("\nCreating fake input with the same shape as tokenized input") + + # Define sequence length for fake input + seq_length = 2048 # You can adjust this based on your needs + vocab_size = 50000 + + with torch.no_grad(): + # Create fake input_ids directly on the device - using random integers between 0 and 50000 (typical vocab size) + torch.manual_seed(42) + tokens = torch.randint( + 0, vocab_size, (1, seq_length), dtype=torch.long, device="cuda" + ) + + # Create fake attention_mask directly on the device - all 1s for full attention + attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=device) + + # Create inputs dictionary similar to what 
tokenizer would produce + inputs = {"input_ids": tokens, "attention_mask": attention_mask} + + # Print input information + print(f"Fake input token IDs: {inputs['input_ids'][0][:10].cpu().numpy()}...") + print(f"Fake input shape: {inputs['input_ids'].shape}") + print(f"Input tensors device: {inputs['input_ids'].device}") + + # Run a single forward pass + print("\nRunning single forward pass...") + start_time = time.time() + + with torch.no_grad(): + # Forward pass through the model with output_hidden_states=True and output_attentions=True + outputs = model( + **inputs, output_hidden_states=True, output_attentions=True, use_cache=False + ) + + forward_time = time.time() - start_time + + # Get the logits from the output + logits = outputs.logits if hasattr(outputs, "logits") else outputs + + # Get the predictions for the next token (highest probability) + next_token_logits = logits[:, -1, :] + print(f"\nNext token logits : {next_token_logits}") + next_token_probs = torch.softmax(next_token_logits, dim=-1) + print(f"\nNext token probabilities: {next_token_probs}") + top_k_values, top_k_indices = torch.topk(next_token_probs, 5, dim=-1) + + print("\nForward Pass Results:") + print(f"- Output logits shape: {logits.shape}") + print(f"- Sequence length: {logits.shape[1]}") + print(f"- Vocabulary size: {logits.shape[2]}") + + print( + "\nTop 5 predicted next tokens (showing IDs only since we're not using tokenizer):" + ) + for i, (value, index) in enumerate(zip(top_k_values[0], top_k_indices[0])): + print(f" {i+1}. Token ID: {index} - Probability: {value.item():.4f}") + + print(f"\nForward pass stats:") + print(f"- Time: {forward_time:.4f} seconds") + print(f"- Input tokens: {inputs['input_ids'].shape[1]}") + print(f"- Tokens per second: {inputs['input_ids'].shape[1] / forward_time:.2f}") + print_gpu_memory_usage("After forward pass") + + +def main(): + parser = argparse.ArgumentParser(description="Load and test DeepSeek-V3 model") + parser.add_argument( + "--num_layers", + type=int, + default=5, # tailered to 5 layers for 671B model + help="Number of layers to use (0 for all layers)", + ) + + # Hugging Face specific arguments + parser.add_argument( + "--model_path", + type=str, + default="/data/users/jianiw/model/DeepSeek-V3.1-Base", + help="Hugging Face model name or path", + ) + + args = parser.parse_args() + run_huggingface_implementation(args, None) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index e2c4bbeda..4cf46999a 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -11,7 +11,7 @@ from torch import nn from torchtitan.models.attention import build_attention -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import FeedForward, MoE, print_tensor_stats from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -295,9 +295,12 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. 
""" + print_tensor_stats(f"input of TransformerBlock {self.layer_id}: ", x) x = x + self.attention(self.attention_norm(x), freqs_cis) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + x = self.ffn_norm(x) + print_tensor_stats(f"After ffn_norm : ", x) + x = x + self.moe(x) else: x = x + self.feed_forward(self.ffn_norm(x)) return x @@ -385,8 +388,11 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + + token_inputs = h for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + # reset before each layer + h = layer(token_inputs, self.freqs_cis) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 9e2e6f952..1704b28b9 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -453,14 +453,18 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer, titan_abstract_key, value.device_mesh ) - if stacked_value is not None: - local_tensor = stacked_value._local_tensor - tensor_list = local_tensor.tolist() - # Save to JSON file - import json - with open(f'my_implementation_tensor_{new_key}.json', 'w') as f: - json.dump(tensor_list, f) + if stacked_value is not None: + if torch.distributed.get_rank() == 0: + print("saving tensor to json file") + local_tensor = stacked_value._local_tensor + print("stacked_value: ", stacked_value.shape, stacked_value.device_mesh, stacked_value.placements, "local_tensor: ", local_tensor.shape) + + tensor_list = local_tensor.tolist() + # Save to JSON file + import json + with open(f'my_imp_tensor_222_{new_key}.json', 'w') as f: + json.dump(tensor_list, f) state_dict[new_key] = stacked_value elif "layers" in key: diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 6e01dc477..d8911ae02 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -35,8 +35,8 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 4 -seq_len = 4096 +local_batch_size = 2 +seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 compile = false diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 8be14ecbf..f21e88a06 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -14,6 +14,14 @@ from torchtitan.distributed.expert_parallel import expert_parallel +def print_tensor_stats(name, tensor): + mean = tensor.mean().item() + std = tensor.std().item() + min_val = tensor.min().item() + max_val = tensor.max().item() + print( + f"{name} - Shape: {tensor.shape} Mean: {mean:.6f}, Min: {min_val:.6f}, Max: {max_val:.6f}, Std: {std:.6f}, First 10 values: {tensor.flatten()[:10].tolist()}" + ) @dataclass class MoEArgs: num_experts: int = 8 @@ -367,9 +375,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. 
""" + + print_tensor_stats("input of MoE module: ", x) + bs, slen, dim = x.shape x = x.view(-1, dim) - + # top_scores and selected_experts_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) ( @@ -378,6 +389,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens_per_expert, ) = self.router(x, self.expert_bias) + print_tensor_stats("top_scores of router: ", top_scores) + # tokens_per_expert will be used to update the expert bias for load balancing. # and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- @@ -400,6 +413,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens_per_expert, ) = self.reorderer(top_scores, selected_experts_indices) + # print_tensor_stats("selected_experts_indices of reorderer: ", selected_experts_indices) + # Print first 10 elements of selected_experts_indices + print(f"First 10 elements of selected_experts_indices: {selected_experts_indices.flatten()[:10].tolist()}") + + # shape (bs*slen*top_k, dim) token_indices_experts_sorted = token_indices_experts_sorted.reshape( -1, 1 @@ -414,9 +432,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) + print_tensor_stats("routed_input of GroupedExperts module: ", routed_input) + # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + print_tensor_stats("routed_output of GroupedExperts module: ", routed_output) + if not self.score_before_experts: routed_output = ( routed_output.to(torch.float32) @@ -426,13 +448,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shared expert if self.shared_experts is not None: out = self.shared_experts(x) + print_tensor_stats("out of Shard Experts module: ", out) else: out = torch.zeros_like(x) out = out.scatter_add( dim=0, index=token_indices_experts_sorted, src=routed_output ) + + out = out.reshape(bs, slen, dim) + print_tensor_stats("out of MoE module: ", out) return out def init_weights( diff --git a/torchtitan/train.py b/torchtitan/train.py index ff6f55c41..87c1caf18 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -464,10 +464,30 @@ def forward_backward_step( assert len(model_parts) == 1 with self.maybe_enable_amp: pred = model_parts[0](inputs) - loss = self.loss_fn(pred, labels) + + print("\nForward Pass Results:") + print(f"- Output logits shape: {pred.shape}") + print(f"- Sequence length: {pred.shape[1]}") + print(f"- Vocabulary size: {pred.shape[2]}") + + # Get the predictions for the next token (highest probability) + next_token_logits = pred[:, -1, :] + print(f"\nNext token logits : {next_token_logits}") + next_token_probs = torch.softmax(next_token_logits, dim=-1) + print(f"\nNext token probabilities: {next_token_probs}") + top_k_values, top_k_indices = torch.topk(next_token_probs, 5, dim=-1) + + print("Top K values: ", top_k_values) + print("Top K indices: ", top_k_indices) + + print("\nTop 5 predicted next tokens (showing IDs only since we're not using tokenizer):") + for i, (value, index) in enumerate(zip(top_k_values[0], top_k_indices[0])): + print(f" {i+1}. Token ID: {index} - Probability: {value.item():.4f}") + + # loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - loss.backward() + # loss.backward() return loss @@ -486,8 +506,41 @@ def train_step( # If data runs out during gradient accumulation, that # entire step will not be executed. 
for microbatch in range(self.gradient_accumulation_steps): - input_dict, labels = next(data_iterator) + # input_dict, labels = next(data_iterator) + + print("\nCreating fake input with the same shape as tokenized input") + + # Define sequence length for fake input + seq_length = self.job_config.training.seq_len + seq_length = 2048 + + with torch.no_grad(): + # Create fake input_ids directly on the device - using random integers between 0 and 50000 (typical vocab size) + torch.manual_seed(42) + input_ids = torch.randint(0, 50000, (1, seq_length), dtype=torch.long, device=self.device) + + # Create fake attention_mask directly on the device - all 1s for full attention + attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=self.device) + + # Create inputs dictionary similar to what tokenizer would produce + input_dict = { + "input": input_ids, + "attention_mask": attention_mask + } + + # Create fake labels (same as attention_mask for simplicity) + labels = attention_mask.clone() + + # Print input information + print(f"Fake input token IDs: {input_ids[0][:10].cpu().numpy()}...") + print(f"Fake input shape: {input_ids.shape}") + print(f"Input tensors device: {input_ids.device}") + + print("\nRunning single forward pass...") + loss = self.forward_backward_step(input_dict, labels) + + return accumulated_losses.append(loss.detach()) grad_norm = dist_utils.clip_grad_norm_( From 7e331abd1bcd64a6e1072ceced2ce68229a91cd6 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 28 Aug 2025 17:05:46 -0700 Subject: [PATCH 07/10] clean up --- .../checkpoint_conversion/convert_to_hf.py | 2 +- torchtitan/components/checkpoint.py | 2 +- torchtitan/models/deepseek_v3/__init__.py | 2 +- .../models/deepseek_v3/hf_implementation.py | 177 ------------------ .../models/deepseek_v3/infra/parallelize.py | 1 - torchtitan/models/deepseek_v3/model/model.py | 12 +- .../deepseek_v3/model/state_dict_adapter.py | 132 +++++++++---- .../train_configs/deepseek_v3_671b.toml | 8 +- .../models/llama3/model/state_dict_adapter.py | 1 - torchtitan/models/moe.py | 28 +-- torchtitan/train.py | 76 +------- 11 files changed, 106 insertions(+), 335 deletions(-) delete mode 100644 torchtitan/models/deepseek_v3/hf_implementation.py diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index f0ea17cc6..4cf9791d6 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -44,7 +44,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat storage_writer = HuggingFaceStorageWriter( path=output_dir, save_distributed=True, - fqn_to_index_mapping=sd_adapter.fqn_to_index_mapping, + fqn_to_index_mapping=None, enable_consolidation=True, thread_count_consolidation=5, ) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index ca24420c1..bf835d241 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -418,7 +418,7 @@ def dcp_load( ) state_dict = self.sd_adapter.from_hf(hf_state_dict) - + self.states[MODEL].load_state_dict(state_dict) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index be17ecf5e..106fbb381 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -46,7 +46,7 @@ route_norm=True, score_before_experts=False, ), - q_lora_rank=0, + q_lora_rank=256, # for test, 
original is 0 kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, diff --git a/torchtitan/models/deepseek_v3/hf_implementation.py b/torchtitan/models/deepseek_v3/hf_implementation.py deleted file mode 100644 index c34f379db..000000000 --- a/torchtitan/models/deepseek_v3/hf_implementation.py +++ /dev/null @@ -1,177 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Hugging Face implementation for DeepSeek-V3 model inference. -""" - -import argparse -import gc -import os -import time - -import torch - - -def print_gpu_memory_usage(message=""): - """Print current GPU memory usage.""" - if torch.cuda.is_available(): - allocated = torch.cuda.memory_allocated() / (1024**3) - reserved = torch.cuda.memory_reserved() / (1024**3) - print( - f"GPU Memory ({message}): Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB" - ) - - -def run_huggingface_implementation(args, _): - """Run the DeepSeek-V3 model using Hugging Face Transformers.""" - # Disable Hugging Face cache - from transformers import AutoConfig, AutoModelForCausalLM - - # We're not using the tokenizer anymore, using fake inputs instead - # Use local path for model weights if specified, otherwise use model_name - model_path = args.model_path - print(f"Loading model from local path: {model_path}") - start_time = time.time() - - quantization_config = { - "activation_scheme": "dynamic", - "fmt": "e4m3", - "quant_method": "fp8", # Updated from fp8 to fbgemm_fp8 - "weight_block_size": [128, 128], - } - print(f"Using quantization config: {quantization_config}") - - # ============= Change config to only use a few layers ============= - config = None - if args.num_layers > 0: - # Try to load config from local path first, fall back to model_name if needed - try: - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - except Exception as e: - print(f"Could not load config from local path: {e}") - print(f"Falling back to loading config from {args.model_name}") - config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True) - - config.n_group = 1 # make n_groups = a huge group - config.topk_group = 1 # make topk_group = a huge group - # tailer the first several layers - config.num_hidden_layers = args.num_layers - # Explicitly set rope_interleaved to True to use the interleaved rope implementation - config.rope_interleaved = True - print(f"Modified config to use only {args.num_layers} layers") - print(f"Config of Deepseek: {config}") - - # Load the model from local path - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - device_map="cuda", # Try with specific device first - config=config, - trust_remote_code=True, - # Disable features that can cause issues with device mapping - attn_implementation="eager", # Use standard attention instead of flash attention - quantization_config=quantization_config, - local_files_only=True, # Only use local files, don't fetch from cache - use_auth_token=False, # Don't try to authenticate with HF - ) - - print(f"Model loaded in {time.time() - start_time:.2f} seconds") - print_gpu_memory_usage("After loading model") - - # Get the device where the model is loaded - device = next(model.parameters()).device - print(f"Model is on device: {device}") - - # Create fake input directly on the correct device - print("\nCreating fake input 
with the same shape as tokenized input") - - # Define sequence length for fake input - seq_length = 2048 # You can adjust this based on your needs - vocab_size = 50000 - - with torch.no_grad(): - # Create fake input_ids directly on the device - using random integers between 0 and 50000 (typical vocab size) - torch.manual_seed(42) - tokens = torch.randint( - 0, vocab_size, (1, seq_length), dtype=torch.long, device="cuda" - ) - - # Create fake attention_mask directly on the device - all 1s for full attention - attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=device) - - # Create inputs dictionary similar to what tokenizer would produce - inputs = {"input_ids": tokens, "attention_mask": attention_mask} - - # Print input information - print(f"Fake input token IDs: {inputs['input_ids'][0][:10].cpu().numpy()}...") - print(f"Fake input shape: {inputs['input_ids'].shape}") - print(f"Input tensors device: {inputs['input_ids'].device}") - - # Run a single forward pass - print("\nRunning single forward pass...") - start_time = time.time() - - with torch.no_grad(): - # Forward pass through the model with output_hidden_states=True and output_attentions=True - outputs = model( - **inputs, output_hidden_states=True, output_attentions=True, use_cache=False - ) - - forward_time = time.time() - start_time - - # Get the logits from the output - logits = outputs.logits if hasattr(outputs, "logits") else outputs - - # Get the predictions for the next token (highest probability) - next_token_logits = logits[:, -1, :] - print(f"\nNext token logits : {next_token_logits}") - next_token_probs = torch.softmax(next_token_logits, dim=-1) - print(f"\nNext token probabilities: {next_token_probs}") - top_k_values, top_k_indices = torch.topk(next_token_probs, 5, dim=-1) - - print("\nForward Pass Results:") - print(f"- Output logits shape: {logits.shape}") - print(f"- Sequence length: {logits.shape[1]}") - print(f"- Vocabulary size: {logits.shape[2]}") - - print( - "\nTop 5 predicted next tokens (showing IDs only since we're not using tokenizer):" - ) - for i, (value, index) in enumerate(zip(top_k_values[0], top_k_indices[0])): - print(f" {i+1}. 
Token ID: {index} - Probability: {value.item():.4f}") - - print(f"\nForward pass stats:") - print(f"- Time: {forward_time:.4f} seconds") - print(f"- Input tokens: {inputs['input_ids'].shape[1]}") - print(f"- Tokens per second: {inputs['input_ids'].shape[1] / forward_time:.2f}") - print_gpu_memory_usage("After forward pass") - - -def main(): - parser = argparse.ArgumentParser(description="Load and test DeepSeek-V3 model") - parser.add_argument( - "--num_layers", - type=int, - default=5, # tailered to 5 layers for 671B model - help="Number of layers to use (0 for all layers)", - ) - - # Hugging Face specific arguments - parser.add_argument( - "--model_path", - type=str, - default="/data/users/jianiw/model/DeepSeek-V3.1-Base", - help="Hugging Face model name or path", - ) - - args = parser.parse_args() - run_huggingface_implementation(args, None) - - -if __name__ == "__main__": - main() diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index b2ef2790c..c77250d0f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -36,7 +36,6 @@ def parallelize_deepseekv3( job_config: JobConfig, ): world_mesh = parallel_dims.world_mesh - print(f"In parallelize_deepseekv3, world mesh is {world_mesh}") # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 4cf46999a..e2c4bbeda 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -11,7 +11,7 @@ from torch import nn from torchtitan.models.attention import build_attention -from torchtitan.models.moe import FeedForward, MoE, print_tensor_stats +from torchtitan.models.moe import FeedForward, MoE from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -295,12 +295,9 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. 
""" - print_tensor_stats(f"input of TransformerBlock {self.layer_id}: ", x) x = x + self.attention(self.attention_norm(x), freqs_cis) if self.moe_enabled: - x = self.ffn_norm(x) - print_tensor_stats(f"After ffn_norm : ", x) - x = x + self.moe(x) + x = x + self.moe(self.ffn_norm(x)) else: x = x + self.feed_forward(self.ffn_norm(x)) return x @@ -388,11 +385,8 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens - - token_inputs = h for layer in self.layers.values(): - # reset before each layer - h = layer(token_inputs, self.freqs_cis) + h = layer(h, self.freqs_cis) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 1704b28b9..80382d324 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -6,7 +6,6 @@ import re -from threading import local from typing import Any, Dict import torch @@ -37,9 +36,6 @@ def __init__( self.from_hf_map = { "model.embed_tokens.weight": "tok_embeddings.weight", # Attention Module - "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", - "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", - "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attention.wkv_a.weight", "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attention.kv_norm.weight", "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attention.wkv_b.weight", @@ -64,6 +60,22 @@ def __init__( "lm_head.weight": "output.weight", } + # Adjustments for from_hf_map based on model architecture + if model_args.q_lora_rank != 0: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", + } + ) + else: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + } + ) + # Store metadata for GroupedExperts <-> individual experts conversion self.grouped_expert_weight_placements = {} # {titan_abstract_key: placements} self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} @@ -109,7 +121,8 @@ def _calculate_strided_shard_shard_indices( != dim_size_to_split ): raise ValueError( - f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}" + f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, " + f"shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}" ) start_index = block_size * ( @@ -187,7 +200,6 @@ def _caculate_indices_from_placements( return start_index, end_index - def _get_local_experts_weights( self, abstract_key: str, @@ -277,7 +289,7 @@ def _get_local_experts_weights( local_expert_tensors[expert_key] = expert_dtensor return local_expert_tensors - + def _concatenate_local_expert_weights( self, expert_weights_by_layer: dict[str, Any], @@ -298,14 +310,17 @@ def _concatenate_local_expert_weights( sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] 
local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor - + assert ( abstract_key in self.grouped_expert_weight_placements and abstract_key in self.grouped_expert_weight_shape - ), f"GroupedExperts weight metadata {self.grouped_expert_weight_placements} {self.grouped_expert_weight_shape} can not be None!" + ), "GroupedExperts weight metadata (placements, shape) can not be None!" stacked_dtensor = DTensor.from_local( - local_tensor, device_mesh, self.grouped_expert_weight_placements[abstract_key], run_check=False + local_tensor, + device_mesh, + self.grouped_expert_weight_placements[abstract_key], + run_check=False, ) # Remove these experts from the tracking dict to free memory @@ -317,6 +332,38 @@ def _concatenate_local_expert_weights( return None + def _split_experts_weights( + self, weight: torch.Tensor, n_experts: int + ) -> list[torch.Tensor]: + """ + Split the weights of the experts into a list of tensors. + """ + split_weight = torch.split(weight, weight.shape[0] // n_experts, dim=0) + return split_weight + + def _concatenate_expert_weights( + self, expert_weights_by_layer: dict[str, Any], n_experts: int + ) -> torch.Tensor: + """ + Concatenate the weights of separate experts into GroupedExpert weights. + """ + for layer, abstract_keys in list(expert_weights_by_layer.items()): + for abstract_key, experts in list(abstract_keys.items()): + # If we have all the experts for this abstract_key, concatenate them + if len(experts) == n_experts: + sorted_expert_ids = sorted(experts.keys()) + sorted_experts = [experts[i] for i in sorted_expert_ids] + stacked_tensor = torch.stack(sorted_experts, dim=0) + + # Remove these experts from the tracking dict to free memory + del expert_weights_by_layer[layer][abstract_key] + if not expert_weights_by_layer[layer]: + del expert_weights_by_layer[layer] + + return stacked_tensor + + return None + def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ Dequantize the weights from float8 to float32. 
@@ -365,6 +412,9 @@ def _add_quantization_scale_inv_tensors( weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( expected_scale_shape, dtype=torch.float32 ) + print( + f"In _add_quantize_scale_inv_tensors, the added weight_scale_inv_state_dict: {weight_scale_inv_state_dict.keys()}" + ) state_dict.update(weight_scale_inv_state_dict) return state_dict @@ -385,18 +435,30 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_abstract_key = to_hf_map[abstract_key] # Store the GroupedExperts Weight metadata for from_hf() - self.grouped_expert_weight_placements[abstract_key] = value.placements - self.grouped_expert_weight_shape[abstract_key] = value.shape - - # Split GroupedExperts weight to local individual expert weights - local_expert_fqn = self._get_local_experts_weights( - new_abstract_key, - abstract_key, - layer_num, - value, - ) - - hf_state_dict.update(local_expert_fqn) + if isinstance(value, DTensor): + self.grouped_expert_weight_placements[ + abstract_key + ] = value.placements + self.grouped_expert_weight_shape[abstract_key] = value.shape + + # Split GroupedExperts weight to local individual expert weights + local_expert_fqn = self._get_local_experts_weights( + new_abstract_key, + abstract_key, + layer_num, + value, + ) + hf_state_dict.update(local_expert_fqn) + + else: + # keep this path for offline conversion + split_values = self._split_experts_weights( + value, self.model_args.moe_args.num_experts + ) + + for expert_num in range(0, self.model_args.moe_args.num_experts): + new_key = new_abstract_key.format(layer_num, expert_num) + hf_state_dict[new_key] = split_values[expert_num].squeeze() elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) @@ -445,26 +507,16 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_num ] = value - # try to concat the expert's weight into GroupedExperts' weight. 
- # stacked_value = self._concatenate_expert_weights( - # expert_weights_by_layer, self.model_args.moe_args.num_experts - # ) - stacked_value = self._concatenate_local_expert_weights( - expert_weights_by_layer, titan_abstract_key, value.device_mesh - ) - + if isinstance(value, DTensor): + stacked_value = self._concatenate_local_expert_weights( + expert_weights_by_layer, titan_abstract_key, value.device_mesh + ) + else: # keep this path to be compatibile with offline conversion + stacked_value = self._concatenate_expert_weights( + expert_weights_by_layer, self.model_args.moe_args.num_experts + ) if stacked_value is not None: - if torch.distributed.get_rank() == 0: - print("saving tensor to json file") - local_tensor = stacked_value._local_tensor - print("stacked_value: ", stacked_value.shape, stacked_value.device_mesh, stacked_value.placements, "local_tensor: ", local_tensor.shape) - - tensor_list = local_tensor.tolist() - # Save to JSON file - import json - with open(f'my_imp_tensor_222_{new_key}.json', 'w') as f: - json.dump(tensor_list, f) state_dict[new_key] = stacked_value elif "layers" in key: diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index d8911ae02..f4556f08b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -46,12 +46,12 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 2 -expert_tensor_parallel_degree = 2 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable = false @@ -60,8 +60,6 @@ interval = 10 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" -initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" -initial_load_in_hf=true [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 8e631a8af..2c386ece0 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -10,7 +10,6 @@ logger = logging.getLogger() -from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index f21e88a06..8be14ecbf 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -14,14 +14,6 @@ from torchtitan.distributed.expert_parallel import expert_parallel -def print_tensor_stats(name, tensor): - mean = tensor.mean().item() - std = tensor.std().item() - min_val = tensor.min().item() - max_val = tensor.max().item() - print( - f"{name} - Shape: {tensor.shape} Mean: {mean:.6f}, Min: {min_val:.6f}, Max: {max_val:.6f}, Std: {std:.6f}, First 10 values: {tensor.flatten()[:10].tolist()}" - ) @dataclass class MoEArgs: num_experts: int = 8 @@ -375,12 +367,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape 
``(bs, slen, dim)``. """ - - print_tensor_stats("input of MoE module: ", x) - bs, slen, dim = x.shape x = x.view(-1, dim) - + # top_scores and selected_experts_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) ( @@ -389,8 +378,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens_per_expert, ) = self.router(x, self.expert_bias) - print_tensor_stats("top_scores of router: ", top_scores) - # tokens_per_expert will be used to update the expert bias for load balancing. # and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- @@ -413,11 +400,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens_per_expert, ) = self.reorderer(top_scores, selected_experts_indices) - # print_tensor_stats("selected_experts_indices of reorderer: ", selected_experts_indices) - # Print first 10 elements of selected_experts_indices - print(f"First 10 elements of selected_experts_indices: {selected_experts_indices.flatten()[:10].tolist()}") - - # shape (bs*slen*top_k, dim) token_indices_experts_sorted = token_indices_experts_sorted.reshape( -1, 1 @@ -432,13 +414,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) - print_tensor_stats("routed_input of GroupedExperts module: ", routed_input) - # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) - print_tensor_stats("routed_output of GroupedExperts module: ", routed_output) - if not self.score_before_experts: routed_output = ( routed_output.to(torch.float32) @@ -448,17 +426,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shared expert if self.shared_experts is not None: out = self.shared_experts(x) - print_tensor_stats("out of Shard Experts module: ", out) else: out = torch.zeros_like(x) out = out.scatter_add( dim=0, index=token_indices_experts_sorted, src=routed_output ) - - out = out.reshape(bs, slen, dim) - print_tensor_stats("out of MoE module: ", out) return out def init_weights( diff --git a/torchtitan/train.py b/torchtitan/train.py index 87c1caf18..9b69fd679 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -107,7 +107,6 @@ def __init__(self, job_config: JobConfig): ) world_mesh = parallel_dims.world_mesh - print(f"Worldmesh in trainer init : {world_mesh}") if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() @@ -258,9 +257,6 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - print( - f"the world mesh before applying parallelize_fn {parallel_dims.world_mesh}" - ) model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) @@ -315,8 +311,7 @@ def __init__(self, job_config: JobConfig): checkpoint_config=job_config.checkpoint, sd_adapter=( self.train_spec.state_dict_adapter( - model_args, - job_config.model.hf_assets_path, + model_args, job_config.model.hf_assets_path ) if self.train_spec.state_dict_adapter else None @@ -464,30 +459,10 @@ def forward_backward_step( assert len(model_parts) == 1 with self.maybe_enable_amp: pred = model_parts[0](inputs) - - print("\nForward Pass Results:") - print(f"- Output logits shape: {pred.shape}") - print(f"- Sequence length: {pred.shape[1]}") - print(f"- Vocabulary size: {pred.shape[2]}") - - # Get the predictions for the next 
token (highest probability) - next_token_logits = pred[:, -1, :] - print(f"\nNext token logits : {next_token_logits}") - next_token_probs = torch.softmax(next_token_logits, dim=-1) - print(f"\nNext token probabilities: {next_token_probs}") - top_k_values, top_k_indices = torch.topk(next_token_probs, 5, dim=-1) - - print("Top K values: ", top_k_values) - print("Top K indices: ", top_k_indices) - - print("\nTop 5 predicted next tokens (showing IDs only since we're not using tokenizer):") - for i, (value, index) in enumerate(zip(top_k_values[0], top_k_indices[0])): - print(f" {i+1}. Token ID: {index} - Probability: {value.item():.4f}") - - # loss = self.loss_fn(pred, labels) + loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - # loss.backward() + loss.backward() return loss @@ -506,41 +481,8 @@ def train_step( # If data runs out during gradient accumulation, that # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): - # input_dict, labels = next(data_iterator) - - print("\nCreating fake input with the same shape as tokenized input") - - # Define sequence length for fake input - seq_length = self.job_config.training.seq_len - seq_length = 2048 - - with torch.no_grad(): - # Create fake input_ids directly on the device - using random integers between 0 and 50000 (typical vocab size) - torch.manual_seed(42) - input_ids = torch.randint(0, 50000, (1, seq_length), dtype=torch.long, device=self.device) - - # Create fake attention_mask directly on the device - all 1s for full attention - attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=self.device) - - # Create inputs dictionary similar to what tokenizer would produce - input_dict = { - "input": input_ids, - "attention_mask": attention_mask - } - - # Create fake labels (same as attention_mask for simplicity) - labels = attention_mask.clone() - - # Print input information - print(f"Fake input token IDs: {input_ids[0][:10].cpu().numpy()}...") - print(f"Fake input shape: {input_ids.shape}") - print(f"Input tensors device: {input_ids.device}") - - print("\nRunning single forward pass...") - + input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) - - return accumulated_losses.append(loss.detach()) grad_norm = dist_utils.clip_grad_norm_( @@ -597,16 +539,6 @@ def train_step( def train(self): job_config = self.job_config - # Following hacky print only works for debug_model - # w1 = self.model_parts[0].layers["1"].moe.experts.w1 - # w2 = self.model_parts[0].layers["1"].moe.experts.w2 - # w3 = self.model_parts[0].layers["1"].moe.experts.w3 - - # logger.info(f"w1 placements is: {w1.placements}, {type(w1.placements)}") - # logger.info(f"w2 placements is: {w2.placements}") - # logger.info(f"w3 placements is: {w3.placements}") - # logger.info(f"device mesh: {self.parallel_dims.world_mesh}, {self.parallel_dims.world_mesh.mesh_dim_names} {self.parallel_dims.world_mesh['dp_shard']}") - self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}") From 6e6d5cd68fe1e1a09aef2a1bb4279164b132529b Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 28 Aug 2025 17:16:29 -0700 Subject: [PATCH 08/10] further clean up --- scripts/checkpoint_conversion/convert_to_hf.py | 2 +- torchtitan/components/checkpoint.py | 1 - torchtitan/models/deepseek_v3/__init__.py | 4 ++-- torchtitan/models/deepseek_v3/model/state_dict_adapter.py | 6 ------ 
.../models/deepseek_v3/train_configs/debug_model.toml | 6 +++--- .../deepseek_v3/train_configs/deepseek_v3_671b.toml | 8 ++++---- 6 files changed, 10 insertions(+), 17 deletions(-) diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index 4cf9791d6..f0ea17cc6 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -44,7 +44,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat storage_writer = HuggingFaceStorageWriter( path=output_dir, save_distributed=True, - fqn_to_index_mapping=None, + fqn_to_index_mapping=sd_adapter.fqn_to_index_mapping, enable_consolidation=True, thread_count_consolidation=5, ) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index bf835d241..fcec60185 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -418,7 +418,6 @@ def dcp_load( ) state_dict = self.sd_adapter.from_hf(hf_state_dict) - self.states[MODEL].load_state_dict(state_dict) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 106fbb381..1c3d2b19d 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -46,7 +46,7 @@ route_norm=True, score_before_experts=False, ), - q_lora_rank=256, # for test, original is 0 + q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, @@ -135,7 +135,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 80382d324..7043a030e 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -378,8 +378,6 @@ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: ) # update the weight and remove the scale_inv tensor state_dict[key] = dequantized_weight - - # state_dict[key] = weight scale_inv_keys.append(key + "_scale_inv") for key in scale_inv_keys: @@ -411,10 +409,6 @@ def _add_quantization_scale_inv_tensors( # add weight_scale_inv to the state_dict weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( expected_scale_shape, dtype=torch.float32 - ) - print( - f"In _add_quantize_scale_inv_tensors, the added weight_scale_inv_state_dict: {weight_scale_inv_state_dict.keys()}" - ) state_dict.update(weight_scale_inv_state_dict) return state_dict diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 35e7e5108..dc9f37f44 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -47,13 +47,13 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 4 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "1F1B" context_parallel_degree = 1 -expert_parallel_degree = 2 -expert_tensor_parallel_degree = 4 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable = 
false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index f4556f08b..ad238839a 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -35,10 +35,10 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 2 -seq_len = 2048 +local_batch_size = 4 +seq_len = 4096 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 10_000 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) @@ -56,7 +56,7 @@ expert_tensor_parallel_degree = 1 [checkpoint] enable = false folder = "checkpoint" -interval = 10 +interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" From aa56404edc200eac440a5e3f932bfab92429eb4d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 28 Aug 2025 17:25:20 -0700 Subject: [PATCH 09/10] fix lint --- torchtitan/models/deepseek_v3/model/state_dict_adapter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 7043a030e..7886b4dc7 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -409,6 +409,7 @@ def _add_quantization_scale_inv_tensors( # add weight_scale_inv to the state_dict weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( expected_scale_shape, dtype=torch.float32 + ) state_dict.update(weight_scale_inv_state_dict) return state_dict From 78896d0ce3b59e9121f19f4260f06428c04b6114 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 29 Aug 2025 11:45:31 -0700 Subject: [PATCH 10/10] fix comments --- .../deepseek_v3/model/state_dict_adapter.py | 156 +++++++++++------- 1 file changed, 98 insertions(+), 58 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 7886b4dc7..e947d7069 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -6,15 +6,12 @@ import re -from typing import Any, Dict +from typing import Any import torch from torch.distributed.device_mesh import DeviceMesh - from torch.distributed.tensor import DTensor - from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard - from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import DeepSeekV3ModelArgs @@ -206,7 +203,7 @@ def _get_local_experts_weights( titan_abstract_key: str, layer_id: str, grouped_expert_weight: torch.Tensor, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """ Split GroupedExperts weight into individual expert weights for local processing. @@ -290,79 +287,116 @@ def _get_local_experts_weights( return local_expert_tensors - def _concatenate_local_expert_weights( + def _concatenate_expert_weights_dtensor( self, - expert_weights_by_layer: dict[str, Any], + expert_weights_by_layer: dict[str, dict[str, dict[int, torch.Tensor]]], abstract_key: str, + layer_num: str, device_mesh: DeviceMesh, - ) -> torch.Tensor: + ) -> torch.Tensor | None: """ - Try to concatenate the weights of separate experts into GroupedExperts weights. + Args: + expert_weights_by_layer: Dictionary tracking expert weights by layer, abstract key, and expert ID. 
+            Structure: {
+                layer_id: {
+                    abstract_key: {
+                        expert_id: tensor_weight
+                    }
+                }
+            }
+            Used to collect individual expert weights before concatenating them into GroupedExperts.
+            abstract_key: TorchTitan template key with {} placeholders for layer and expert IDs
+            layer_num: Layer identifier
+            device_mesh: DeviceMesh for the target GroupedExperts weight DTensor
+
+        Returns:
+            Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None
         """
-        for layer in expert_weights_by_layer.keys():
-            # If we have all the experts for this abstract_key, concatenate them
-            experts = expert_weights_by_layer[layer][abstract_key]
-            expected_n_experts = (
-                self.local_experts_indices[abstract_key][1]
-                - self.local_experts_indices[abstract_key][0]
-            )
-            if len(experts) == expected_n_experts:
-                sorted_expert_ids = sorted(experts.keys())
-                sorted_experts = [experts[i] for i in sorted_expert_ids]
-                local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor
-
-                assert (
-                    abstract_key in self.grouped_expert_weight_placements
-                    and abstract_key in self.grouped_expert_weight_shape
-                ), "GroupedExperts weight metadata (placements, shape) can not be None!"
-
-                stacked_dtensor = DTensor.from_local(
-                    local_tensor,
-                    device_mesh,
-                    self.grouped_expert_weight_placements[abstract_key],
-                    run_check=False,
-                )
+        # If we have all the experts for this abstract_key, concatenate them
+        experts = expert_weights_by_layer[layer_num][abstract_key]
+        expected_n_experts = (
+            self.local_experts_indices[abstract_key][1]
+            - self.local_experts_indices[abstract_key][0]
+        )
+        if len(experts) < expected_n_experts:
+            return None
+
+        sorted_expert_ids = sorted(experts.keys())
+        sorted_experts = [experts[i] for i in sorted_expert_ids]
+        local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor
 
-                # Remove these experts from the tracking dict to free memory
-                del expert_weights_by_layer[layer][abstract_key]
-                if not expert_weights_by_layer[layer]:
-                    del expert_weights_by_layer[layer]
+        assert (
+            abstract_key in self.grouped_expert_weight_placements
+            and abstract_key in self.grouped_expert_weight_shape
+        ), "GroupedExperts weight metadata (placements, shape) can not be None!"
+
+        stacked_dtensor = DTensor.from_local(
+            local_tensor,
+            device_mesh,
+            self.grouped_expert_weight_placements[abstract_key],
+            run_check=False,
+        )
 
-                return stacked_dtensor
+        del expert_weights_by_layer[layer_num][abstract_key]
+        if not expert_weights_by_layer[layer_num]:
+            del expert_weights_by_layer[layer_num]
 
-        return None
+        return stacked_dtensor
 
     def _split_experts_weights(
         self, weight: torch.Tensor, n_experts: int
     ) -> list[torch.Tensor]:
         """
-        Split the weights of the experts into a list of tensors.
+        Split the weights of the experts into a list of tensors. Used for offline conversion.
+
+        NOTE: If we use this function for online conversion, torch.split() might incur communication
+        to gather the weight, which can cause OOM.
+
         """
         split_weight = torch.split(weight, weight.shape[0] // n_experts, dim=0)
         return split_weight
 
     def _concatenate_expert_weights(
-        self, expert_weights_by_layer: dict[str, Any], n_experts: int
-    ) -> torch.Tensor:
+        self,
+        expert_weights_by_layer: dict[str, dict[str, dict[int, torch.Tensor]]],
+        abstract_key: str,
+        layer_num: str,
+        n_experts: int,
+    ) -> torch.Tensor | None:
         """
-        Concatenate the weights of separate experts into GroupedExpert weights.
+        Concatenate the weights of separate experts into a GroupedExperts weight using torch.stack(). Used for offline conversion.
+
+        Args:
+            expert_weights_by_layer: Dictionary tracking expert weights by layer, abstract key, and expert ID.
+            Structure: {
+                layer_id: {
+                    abstract_key: {
+                        expert_id: tensor_weight
+                    }
+                }
+            }
+            Used to collect individual expert weights before concatenating them into GroupedExperts.
+            abstract_key: TorchTitan template key with {} placeholders for layer and expert IDs
+            layer_num: Layer identifier
+            n_experts: Number of experts in the GroupedExperts module
+
+        Returns:
+            Concatenated GroupedExperts weight if all experts are available, otherwise None
         """
-        for layer, abstract_keys in list(expert_weights_by_layer.items()):
-            for abstract_key, experts in list(abstract_keys.items()):
-                # If we have all the experts for this abstract_key, concatenate them
-                if len(experts) == n_experts:
-                    sorted_expert_ids = sorted(experts.keys())
-                    sorted_experts = [experts[i] for i in sorted_expert_ids]
-                    stacked_tensor = torch.stack(sorted_experts, dim=0)
+        # If we have all the experts for this abstract_key, concatenate them
+        experts = expert_weights_by_layer[layer_num][abstract_key]
+        if len(experts) < n_experts:
+            return None
 
-                    # Remove these experts from the tracking dict to free memory
-                    del expert_weights_by_layer[layer][abstract_key]
-                    if not expert_weights_by_layer[layer]:
-                        del expert_weights_by_layer[layer]
+        sorted_expert_ids = sorted(experts.keys())
+        sorted_experts = [experts[i] for i in sorted_expert_ids]
+        stacked_tensor = torch.stack(sorted_experts, dim=0)
 
-                return stacked_tensor
+        del expert_weights_by_layer[layer_num][abstract_key]
+        if not expert_weights_by_layer[layer_num]:
+            del expert_weights_by_layer[layer_num]
 
-        return None
+        return stacked_tensor
 
     def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
@@ -503,12 +537,18 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
                     ] = value
 
                     if isinstance(value, DTensor):
-                        stacked_value = self._concatenate_local_expert_weights(
-                            expert_weights_by_layer, titan_abstract_key, value.device_mesh
+                        stacked_value = self._concatenate_expert_weights_dtensor(
+                            expert_weights_by_layer,
+                            titan_abstract_key,
+                            layer_num,
+                            value.device_mesh,
                         )
-                    else: # keep this path to be compatibile with offline conversion
+                    else:  # keep this path to be compatible with offline conversion
                         stacked_value = self._concatenate_expert_weights(
-                            expert_weights_by_layer, self.model_args.moe_args.num_experts
+                            expert_weights_by_layer,
+                            titan_abstract_key,
+                            layer_num,
+                            self.model_args.moe_args.num_experts,
                         )
                     if stacked_value is not None:
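
For reference, the offline conversion path that the final patch keeps (splitting a GroupedExperts weight into per-expert tensors and later re-stacking them) can be sketched as a minimal, standalone Python example. The HF-style key name and the tensor shapes below are illustrative only and are not taken from the adapter; the sketch only demonstrates the torch.split / torch.stack round trip under those assumptions.

    # Minimal sketch of the offline expert-weight round trip, assuming a
    # [n_experts, out_dim, in_dim] GroupedExperts weight; key names are made up.
    import torch

    n_experts, out_dim, in_dim = 4, 8, 16
    grouped = torch.randn(n_experts, out_dim, in_dim)

    # Split direction (to_hf): one [out_dim, in_dim] tensor per expert.
    chunks = torch.split(grouped, grouped.shape[0] // n_experts, dim=0)
    hf_sd = {
        f"model.layers.0.mlp.experts.{e}.gate_proj.weight": chunk.squeeze(0)
        for e, chunk in enumerate(chunks)
    }

    # Stack direction (from_hf): collect experts by id, then re-stack on dim 0
    # once all experts for the layer have been seen.
    collected = {
        int(key.split(".experts.")[1].split(".")[0]): w for key, w in hf_sd.items()
    }
    assert len(collected) == n_experts
    restacked = torch.stack([collected[i] for i in sorted(collected)], dim=0)
    assert torch.equal(restacked, grouped)

The DTensor path in the patch follows the same collect-sort-stack idea, but stacks only the locally owned experts and then rebuilds a DTensor from the local shard with the recorded placements.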