
Commit 00b1f7e

msaurabhsaumishr authored and committed
DCP: Dequantization and expert grouping for DSv3
1 parent 7354848 commit 00b1f7e

File tree

8 files changed: +1294 −139 lines

torchtitan/components/checkpoint.py

Lines changed: 21 additions & 5 deletions
@@ -19,10 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint import (
-    HuggingFaceStorageReader,
-    HuggingFaceStorageWriter,
-)
+from torch.distributed.checkpoint import HuggingFaceStorageWriter
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -37,6 +34,12 @@
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP
+from torchtitan.models.deepseek_v3.model.deepseek_v3_storage_reader import (
+    DeepSeekV3HuggingFaceStorageReader,
+)
+from torchtitan.models.deepseek_v3.model.deepseek_v3_planner import (
+    DeepSeekV3LoadPlanner,
+)
 from torchtitan.protocols import BaseStateDictAdapter
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import GarbageCollection
@@ -421,9 +424,22 @@ def dcp_load(
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
             hf_state_dict = self.sd_adapter.to_hf(state_dict)
 
+            storage_reader = DeepSeekV3HuggingFaceStorageReader(
+                path=checkpoint_id,
+                block_size=128,
+                thread_count=4
+            )
+
+            # Use custom planner for key mapping between TorchTitan and HuggingFace formats
+            planner = DeepSeekV3LoadPlanner()
+
+            # Let DCP handle the metadata reading internally
+            # The planner will access the metadata in create_local_plan() after DCP calls read_metadata()
+
             dcp.load(
                 hf_state_dict,
-                storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
+                storage_reader=storage_reader,
+                planner=planner,
             )
 
             state_dict = self.sd_adapter.from_hf(hf_state_dict)
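
For orientation, a minimal standalone sketch of the load path this hunk wires up. The function name and the bare-dict state_dict are illustrative placeholders (the real call site is CheckpointManager.dcp_load above), and the comment on block_size is an assumption, not something this diff states.

# Hypothetical sketch, not part of the commit.
import torch.distributed.checkpoint as dcp

from torchtitan.models.deepseek_v3.model.deepseek_v3_planner import DeepSeekV3LoadPlanner
from torchtitan.models.deepseek_v3.model.deepseek_v3_storage_reader import (
    DeepSeekV3HuggingFaceStorageReader,
)


def load_hf_deepseek_checkpoint(hf_state_dict: dict, checkpoint_id: str) -> None:
    # Custom reader; block_size=128 presumably corresponds to the FP8 block-quantization
    # tile of the DSv3 HF weights that the "dequantization" in the commit title targets.
    storage_reader = DeepSeekV3HuggingFaceStorageReader(
        path=checkpoint_id, block_size=128, thread_count=4
    )
    # Custom planner maps per-expert HF entries onto TorchTitan's grouped expert tensors;
    # it receives the checkpoint metadata via set_up_planner() inside dcp.load().
    dcp.load(hf_state_dict, storage_reader=storage_reader, planner=DeepSeekV3LoadPlanner())
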

torchtitan/models/deepseek_v3/model/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""DeepSeek V3 model package."""
+
+
+from .metadata import DeepSeekV3Metadata
+from .deepseek_v3_storage_reader import DeepSeekV3HuggingFaceStorageReader
+from .deepseek_v3_planner import DeepSeekV3LoadPlanner
+from .state_dict_adapter import DeepSeekV3StateDictAdapter
+from . import key_mappings
+
+__all__ = [
+    "DeepSeekV3Metadata",
+    "DeepSeekV3HuggingFaceStorageReader",
+    "DeepSeekV3LoadPlanner",
+    "DeepSeekV3StateDictAdapter",
+    "key_mappings",
+]
+
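
With these re-exports in place, callers can pull the new DCP pieces from the package root instead of the individual modules; a small hedged example (checkpoint.py above still uses the fully qualified module paths):

# Package-root imports enabled by the __init__ above.
from torchtitan.models.deepseek_v3.model import (
    DeepSeekV3HuggingFaceStorageReader,
    DeepSeekV3LoadPlanner,
)

# The planner takes no constructor arguments; DCP injects the checkpoint
# metadata later through set_up_planner().
planner = DeepSeekV3LoadPlanner()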

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 4 additions & 19 deletions
@@ -8,7 +8,6 @@
 
 
 from dataclasses import dataclass, field
-from typing import Literal
 
 from torch import nn
 
@@ -28,28 +27,24 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     Attributes:
         max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
-        dtype (Literal["bf16", "fp8"]): Data type for computations.
         vocab_size (int): Vocabulary size.
         dim (int): Model dimension.
         inter_dim (int): Intermediate dimension for MLP layers.
         moe_inter_dim (int): Intermediate dimension for MoE layers.
         n_layers (int): Number of transformer layers.
         n_dense_layers (int): Number of dense layers in the model.
         n_heads (int): Number of attention heads.
-        n_routed_experts (int): Number of routed experts for MoE layers.
-        n_shared_experts (int): Number of shared experts for MoE layers.
-        n_activated_experts (int): Number of activated experts in MoE layers.
+        norm_eps (float): Epsilon value used for RMSNorm.
+        moe_args (MoEArgs): MoE configuration.
         n_expert_groups (int): Number of expert groups.
         n_limited_groups (int): Number of limited groups for MoE routing.
-        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
-        route_scale (float): Scaling factor for routing scores.
-        use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
-        load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers.
         q_lora_rank (int): LoRA rank for query projections.
         kv_lora_rank (int): LoRA rank for key-value projections.
         qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
         qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
         v_head_dim (int): Dimension for value projections.
+        use_flex_attn (bool): Whether to use FlexAttention.
+        attn_mask_type (str): Type of attention mask.
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
@@ -59,7 +54,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
 
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
-    dtype: Literal["bf16", "fp8"] = "bf16"
     vocab_size: int = 102400
     dim: int = 2048
     inter_dim: int = 10944
@@ -111,15 +105,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
             )
 
-        if (
-            job_config.parallelism.pipeline_parallel_degree > 1
-            and self.use_flex_attn
-            and self.attn_mask_type == "block_causal"
-        ):
-            raise RuntimeError(
-                "PP + block causal FlexAttention support will be fixed soon."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         """
         Adopted from llama4 implementation.

torchtitan/models/deepseek_v3/model/deepseek_v3_planner.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+DeepSeek V3 Load Planner for DCP that handles grouped expert tensors.
+
+This planner validates that grouped expert tensors can be formed from individual experts
+in the checkpoint before creating read items.
+"""
+
+import re
+from typing import Any, List, Optional
+
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, _create_read_items
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.checkpoint.planner import LoadPlan
+from torchtitan.models.deepseek_v3.model.metadata import (
+    DeepSeekV3Metadata,
+)
+
+class DeepSeekV3LoadPlanner(DefaultLoadPlanner):
+    """Load planner for DeepSeek V3 that handles grouped expert tensor validation."""
+
+    def __init__(self):
+        """Initialize the DeepSeek V3 load planner."""
+        super().__init__()
+        self.valid_grouped_experts = set()
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Optional[DeepSeekV3Metadata] = None,
+        is_coordinator: bool = False,
+    ) -> None:
+        super().set_up_planner(state_dict, metadata.sd_metadata, is_coordinator)
+        # Build cache of valid grouped expert FQNs once during setup
+        self.metadata = metadata.sd_metadata
+        self.io_metadata = metadata.io_metadata
+        self.valid_grouped_experts = self._build_valid_grouped_experts()
+
+    def _build_valid_grouped_experts(self) -> set:
+        """Build cache of valid grouped expert FQNs from checkpoint metadata."""
+        # Group individual experts by (layer, weight_type)
+        experts_by_group = {}
+        # Match only weight tensors, explicitly exclude scale tensors
+        expert_pattern = r'model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(\w+)\.weight$'
+        hf_to_tt_weight_map = {'gate_proj': 'w1', 'down_proj': 'w2', 'up_proj': 'w3'}
+
+        # Count total expert entries
+        total_expert_entries = 0
+
+        for idx in self.io_metadata.storage_data.keys():
+            match = re.match(expert_pattern, idx.fqn)
+            if match:
+                total_expert_entries += 1
+                layer_idx, expert_idx, hf_weight_type = match.groups()
+                tt_weight_type = hf_to_tt_weight_map.get(hf_weight_type)
+
+                if tt_weight_type:
+                    group_key = (layer_idx, tt_weight_type)
+                    if group_key not in experts_by_group:
+                        experts_by_group[group_key] = []
+                    experts_by_group[group_key].append(int(expert_idx))
+
+        # If no expert entries found, the checkpoint might not have individual experts
+        # This could mean experts are already grouped or use a different naming pattern
+        if total_expert_entries == 0:
+            return set()
+
+        # Determine which grouped expert FQNs are valid
+        # We just need to have at least one expert for each weight type in each layer
+        valid_fqns = set()
+
+        if len(experts_by_group) == 0:
+            return set()
+
+        for (layer_idx, tt_weight_type), expert_indices in experts_by_group.items():
+            expert_indices = sorted(expert_indices)
+
+            # As long as we have at least one expert, we can create a grouped tensor
+            if len(expert_indices) > 0:
+                grouped_fqn = f"layers.{layer_idx}.moe.experts.{tt_weight_type}"
+                valid_fqns.add(grouped_fqn)
+
+        return valid_fqns
+
+    def create_local_plan(self) -> LoadPlan:
+        """Create a local load plan starting from the model's state_dict."""
+        requests = []
+
+        # Process each tensor in the model's state_dict
+        for fqn, obj in self.state_dict.items():
+            if self._is_grouped_expert_tensor(fqn) and fqn not in self.valid_grouped_experts:
+                raise RuntimeError(f"Grouped expert tensor {fqn} cannot be loaded from checkpoint")
+
+            # Create read items for all tensors (both regular and grouped)
+            self._validate_and_create_read_items(fqn, obj, self.metadata, requests)
+
+        return LoadPlan(requests)
+
+    def _validate_and_create_read_items(self, fqn: str, obj: Any, metadata: Any, requests: List) -> None:
+        """Validate tensor and add read items to requests."""
+        if fqn not in metadata.state_dict_metadata:
+            raise RuntimeError(f"Missing key in checkpoint metadata: {fqn}")
+
+        md = metadata.state_dict_metadata[fqn]
+
+        # Create read items (handle DTensor submesh)
+        if isinstance(obj, DTensor):
+            if obj.device_mesh.get_coordinate() is not None:
+                requests += _create_read_items(fqn, md, obj)
+        else:
+            requests += _create_read_items(fqn, md, obj)
+
+    def _is_grouped_expert_tensor(self, fqn: str) -> bool:
+        """Check if this FQN represents a grouped expert tensor."""
+        # Match grouped expert tensors but exclude shared expert tensors
+        return 'moe.experts' in fqn
+
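
To make the expert grouping concrete, a short self-contained sketch of the FQN mapping that _build_valid_grouped_experts performs. The checkpoint keys below are illustrative, not read from a real safetensors index:

import re

# Illustrative HF checkpoint keys for one MoE layer (indices are arbitrary).
hf_keys = [
    "model.layers.3.mlp.experts.0.gate_proj.weight",
    "model.layers.3.mlp.experts.1.gate_proj.weight",
    "model.layers.3.mlp.experts.0.down_proj.weight",
    "model.layers.3.mlp.experts.0.up_proj.weight",
    "model.layers.3.mlp.experts.0.gate_proj.weight_scale_inv",  # scale tensor, excluded by the pattern
]

expert_pattern = r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(\w+)\.weight$"
hf_to_tt = {"gate_proj": "w1", "down_proj": "w2", "up_proj": "w3"}

grouped = set()
for key in hf_keys:
    m = re.match(expert_pattern, key)
    if m:
        layer, _expert, proj = m.groups()
        # Many per-expert HF keys collapse onto one grouped TorchTitan FQN per (layer, projection).
        grouped.add(f"layers.{layer}.moe.experts.{hf_to_tt[proj]}")

print(sorted(grouped))
# ['layers.3.moe.experts.w1', 'layers.3.moe.experts.w2', 'layers.3.moe.experts.w3']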
