Commit e0717a5

msaurabhsaumishr authored and committed
DCP: Dequantization and expert grouping for DSv3
1 parent ad06609

File tree

7 files changed: +1290 −120 lines


torchtitan/components/checkpoint.py

Lines changed: 21 additions & 5 deletions
@@ -19,10 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint import (
-    HuggingFaceStorageReader,
-    HuggingFaceStorageWriter,
-)
+from torch.distributed.checkpoint import HuggingFaceStorageWriter
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -37,6 +34,12 @@
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP
+from torchtitan.models.deepseek_v3.model.deepseek_v3_storage_reader import (
+    DeepSeekV3HuggingFaceStorageReader,
+)
+from torchtitan.models.deepseek_v3.model.deepseek_v3_planner import (
+    DeepSeekV3LoadPlanner,
+)
 from torchtitan.protocols import BaseStateDictAdapter
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import GarbageCollection
@@ -412,9 +415,22 @@ def dcp_load(
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
             hf_state_dict = self.sd_adapter.to_hf(state_dict)
 
+            storage_reader = DeepSeekV3HuggingFaceStorageReader(
+                path=checkpoint_id,
+                block_size=128,
+                thread_count=4,
+            )
+
+            # Use custom planner for key mapping between TorchTitan and HuggingFace formats
+            planner = DeepSeekV3LoadPlanner()
+
+            # Let DCP handle the metadata reading internally. The planner will access
+            # the metadata in create_local_plan() after DCP calls read_metadata().
+
             dcp.load(
                 hf_state_dict,
-                storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
+                storage_reader=storage_reader,
+                planner=planner,
             )
 
             state_dict = self.sd_adapter.from_hf(hf_state_dict)
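
For orientation, here is a minimal sketch of how this new load path is wired up outside of CheckpointManager, assuming a local directory of HF safetensors shards. The checkpoint path and the empty hf_state_dict are placeholders (in dcp_load the state dict comes from self.sd_adapter.to_hf), and block_size=128 presumably matches the block granularity the reader uses when dequantizing FP8 weights:

import torch.distributed.checkpoint as dcp

from torchtitan.models.deepseek_v3.model.deepseek_v3_planner import DeepSeekV3LoadPlanner
from torchtitan.models.deepseek_v3.model.deepseek_v3_storage_reader import (
    DeepSeekV3HuggingFaceStorageReader,
)

# Hypothetical checkpoint directory containing HF safetensors shards.
checkpoint_id = "/tmp/deepseek-v3-hf"

# Placeholder; in CheckpointManager this is sd_adapter.to_hf(state_dict).
hf_state_dict = {}

reader = DeepSeekV3HuggingFaceStorageReader(
    path=checkpoint_id,
    block_size=128,  # dequantization block size, matching the dcp_load call above
    thread_count=4,
)
dcp.load(hf_state_dict, storage_reader=reader, planner=DeepSeekV3LoadPlanner())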
torchtitan/models/deepseek_v3/model/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""DeepSeek V3 model package."""
+
+
+from .metadata import DeepSeekV3Metadata
+from .deepseek_v3_storage_reader import DeepSeekV3HuggingFaceStorageReader
+from .deepseek_v3_planner import DeepSeekV3LoadPlanner
+from .state_dict_adapter import DeepSeekV3StateDictAdapter
+from . import key_mappings
+
+__all__ = [
+    "DeepSeekV3Metadata",
+    "DeepSeekV3HuggingFaceStorageReader",
+    "DeepSeekV3LoadPlanner",
+    "DeepSeekV3StateDictAdapter",
+    "key_mappings",
+]
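
Assuming this __init__.py sits at torchtitan/models/deepseek_v3/model/ (consistent with the absolute imports added to checkpoint.py above), the re-exports let callers pull the new components from the package root rather than the individual modules:

from torchtitan.models.deepseek_v3.model import (
    DeepSeekV3HuggingFaceStorageReader,
    DeepSeekV3LoadPlanner,
)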
torchtitan/models/deepseek_v3/model/deepseek_v3_planner.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+DeepSeek V3 Load Planner for DCP that handles grouped expert tensors.
+
+This planner validates that grouped expert tensors can be formed from individual experts
+in the checkpoint before creating read items.
+"""
+
+import re
+from typing import Any, List, Optional
+
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, _create_read_items
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.checkpoint.planner import LoadPlan
+from torchtitan.models.deepseek_v3.model.metadata import (
+    DeepSeekV3Metadata,
+)
+
+class DeepSeekV3LoadPlanner(DefaultLoadPlanner):
+    """Load planner for DeepSeek V3 that handles grouped expert tensor validation."""
+
+    def __init__(self):
+        """Initialize the DeepSeek V3 load planner."""
+        super().__init__()
+        self.valid_grouped_experts = set()
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Optional[DeepSeekV3Metadata] = None,
+        is_coordinator: bool = False,
+    ) -> None:
+        super().set_up_planner(state_dict, metadata.sd_metadata, is_coordinator)
+        # Build cache of valid grouped expert FQNs once during setup
+        self.metadata = metadata.sd_metadata
+        self.io_metadata = metadata.io_metadata
+        self.valid_grouped_experts = self._build_valid_grouped_experts()
+
+    def _build_valid_grouped_experts(self) -> set:
+        """Build cache of valid grouped expert FQNs from checkpoint metadata."""
+        # Group individual experts by (layer, weight_type)
+        experts_by_group = {}
+        # Match only weight tensors; explicitly exclude scale tensors
+        expert_pattern = r'model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(\w+)\.weight$'
+        hf_to_tt_weight_map = {'gate_proj': 'w1', 'down_proj': 'w2', 'up_proj': 'w3'}
+
+        # Count total expert entries
+        total_expert_entries = 0
+
+        for idx in self.io_metadata.storage_data.keys():
+            match = re.match(expert_pattern, idx.fqn)
+            if match:
+                total_expert_entries += 1
+                layer_idx, expert_idx, hf_weight_type = match.groups()
+                tt_weight_type = hf_to_tt_weight_map.get(hf_weight_type)
+
+                if tt_weight_type:
+                    group_key = (layer_idx, tt_weight_type)
+                    if group_key not in experts_by_group:
+                        experts_by_group[group_key] = []
+                    experts_by_group[group_key].append(int(expert_idx))
+
+        # If no expert entries were found, the checkpoint might not have individual
+        # experts; they may already be grouped or use a different naming pattern
+        if total_expert_entries == 0:
+            return set()
+
+        # Determine which grouped expert FQNs are valid: at least one expert per
+        # weight type in each layer is enough
+        valid_fqns = set()
+
+        if len(experts_by_group) == 0:
+            return set()
+
+        for (layer_idx, tt_weight_type), expert_indices in experts_by_group.items():
+            expert_indices = sorted(expert_indices)
+
+            # As long as we have at least one expert, we can create a grouped tensor
+            if len(expert_indices) > 0:
+                grouped_fqn = f"layers.{layer_idx}.moe.experts.{tt_weight_type}"
+                valid_fqns.add(grouped_fqn)
+
+        return valid_fqns
+
+    def create_local_plan(self) -> LoadPlan:
+        """Create a local load plan starting from the model's state_dict."""
+        requests = []
+
+        # Process each tensor in the model's state_dict
+        for fqn, obj in self.state_dict.items():
+            if self._is_grouped_expert_tensor(fqn) and fqn not in self.valid_grouped_experts:
+                raise RuntimeError(f"Grouped expert tensor {fqn} cannot be loaded from checkpoint")
+
+            # Create read items for all tensors (both regular and grouped)
+            self._validate_and_create_read_items(fqn, obj, self.metadata, requests)
+
+        return LoadPlan(requests)
+
+    def _validate_and_create_read_items(self, fqn: str, obj: Any, metadata: Any, requests: List) -> None:
+        """Validate tensor and add read items to requests."""
+        if fqn not in metadata.state_dict_metadata:
+            raise RuntimeError(f"Missing key in checkpoint metadata: {fqn}")
+
+        md = metadata.state_dict_metadata[fqn]
+
+        # Create read items; skip DTensors whose rank is not in the submesh
+        if isinstance(obj, DTensor):
+            if obj.device_mesh.get_coordinate() is not None:
+                requests += _create_read_items(fqn, md, obj)
+        else:
+            requests += _create_read_items(fqn, md, obj)
+
+    def _is_grouped_expert_tensor(self, fqn: str) -> bool:
+        """Check if this FQN represents a grouped expert tensor."""
+        # Match grouped expert tensors but exclude shared expert tensors
+        return 'moe.experts' in fqn
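
To make the grouping heuristic concrete, here is a self-contained sketch of what _build_valid_grouped_experts derives, using the same pattern and HF-to-TT weight map but hypothetical checkpoint keys in place of the real storage metadata. Note the scale tensor is skipped because the pattern anchors on .weight$:

import re

expert_pattern = r'model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(\w+)\.weight$'
hf_to_tt_weight_map = {'gate_proj': 'w1', 'down_proj': 'w2', 'up_proj': 'w3'}

# Hypothetical checkpoint FQNs: two experts plus one FP8 scale tensor.
checkpoint_fqns = [
    'model.layers.3.mlp.experts.0.gate_proj.weight',
    'model.layers.3.mlp.experts.1.gate_proj.weight',
    'model.layers.3.mlp.experts.0.gate_proj.weight_scale_inv',
]

valid_fqns = set()
for fqn in checkpoint_fqns:
    match = re.match(expert_pattern, fqn)
    if match:
        layer_idx, _expert_idx, hf_weight_type = match.groups()
        tt_weight_type = hf_to_tt_weight_map.get(hf_weight_type)
        if tt_weight_type:
            # Both experts collapse into one grouped TorchTitan FQN.
            valid_fqns.add(f'layers.{layer_idx}.moe.experts.{tt_weight_type}')

print(valid_fqns)  # {'layers.3.moe.experts.w1'}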
