
Commit 81f698a

add to_hf
1 parent 17ef753 commit 81f698a

7 files changed (+184, -24 lines)

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@
         dim=7168,
         inter_dim=18432,
         moe_inter_dim=2048,
-        n_layers=61,
+        n_layers=4,
         n_dense_layers=3,
         n_heads=128,
         moe_args=MoEArgs(

torchtitan/models/deepseek_v3/model/state_dict_adapter.py

Lines changed: 153 additions & 9 deletions
@@ -4,24 +4,36 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from logging import raiseExceptions
 import re
-from typing import Any
+from typing import Any, Dict
 
 import torch
 
+from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 
 from .args import DeepSeekV3ModelArgs
 from .quantization import calculate_scale_shape, dequantize_from_fp8
 
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Shard,
+    Replicate
+)
+
+from torch.distributed.tensor import DTensor
+
 
 class DeepSeekV3StateDictAdapter(StateDictAdapter):
     """
     StateDictAdapter for DeepSeekV3 model.
     """
 
-    def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
+    def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims):
+        super().__init__(model_args, hf_assets_path, parallel_dims)
         self.model_args = model_args
+        self.parallel_dims = parallel_dims
         self.from_hf_map = {
             "model.embed_tokens.weight": "tok_embeddings.weight",
             # Attention Module
@@ -52,7 +64,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
             "lm_head.weight": "output.weight",
         }
 
-    def _split_experts_weights(
+    def _split_experts_weight(
         self, weight: torch.Tensor, n_experts: int
     ) -> list[torch.Tensor]:
         """
@@ -84,6 +96,134 @@ def _concatenate_expert_weights(
 
         return None
 
+    def _get_local_experts_weights(
+        self, abstract_key: str, layer_id: str, grouped_expert_weight: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Split the GroupedExperts weight and find the corresponding individual expert's weight in the local tensor.
+
+        Potential expert weight shard placements:
+        - FSDP + EP when dp_mod_ep * ep <= num_experts:
+            - StridedShard(0)Shard(0)
+        - FSDP + EP when dp_mod_ep * ep > num_experts:
+            - Shard(1)Shard(0)
+        - FSDP + ETP + EP when dp_mod_ep * ep <= num_experts:
+            - w1/w3: StridedShard(0)Shard(0)Shard(1)
+            - w2: StridedShard(0)Shard(0)Shard(2)
+        - FSDP + ETP + EP when dp_mod_ep * ep > num_experts:
+            - w1/w3: StridedShard(1)Shard(0)Shard(1)
+            - w2: Shard(1)Shard(0)Shard(2)
+        """
+        world_mesh = self.parallel_dims.world_mesh
+        num_experts = grouped_expert_weight.shape[0]
+
+        # Match the DTensor sharding placements against the device mesh dims to
+        # find the mesh dims that shard on dim-0 (the num_experts dim)
+        original_placements = grouped_expert_weight.placements
+        world_mesh_names = []
+        dim_0_placements = []
+        for i, name in enumerate(world_mesh.mesh_dim_names):
+            placement = original_placements[i]
+            if placement.dim == 0:
+                world_mesh_names.append(name)
+                dim_0_placements.append(placement)
+
+        start_index, end_index = None, None
+        # StridedShard(0)Shard(0)
+        if len(dim_0_placements) == 2:
+            assert isinstance(dim_0_placements[0], _StridedShard)
+            strided_shard_mesh = world_mesh[world_mesh_names[0]]
+            strided_degree, strided_rank = strided_shard_mesh.size(), strided_shard_mesh.get_local_rank()
+            shard_mesh = world_mesh[world_mesh_names[1]]
+            shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank()
+            start_index, end_index = self._get_strided_shard_shard_slice(strided_degree, strided_rank, shard_degree, shard_rank, num_experts)
+        # Shard(0)
+        elif len(dim_0_placements) == 1:
+            assert not isinstance(dim_0_placements[0], _StridedShard)
+            shard_mesh = world_mesh[world_mesh_names[0]]
+            shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank()
+            block_size = num_experts // shard_degree
+            if block_size * shard_degree != num_experts:
+                raise ValueError("Not supported: num_experts cannot be evenly divided by the Shard(0) dimension degree.")
+
+            start_index = block_size * shard_rank
+            end_index = start_index + block_size
+        else:
+            raise NotImplementedError(f"The DTensor placements {original_placements} for GroupedExperts are not supported in StateDictAdapter")
+
+        # Calculate the new placements for individual expert weights
+        new_placements = []
+        for i, name in enumerate(world_mesh.mesh_dim_names):
+            placement = original_placements[i]
+            if placement.dim == 0:
+                new_placements.append(Replicate())
+            elif isinstance(placement, Shard):
+                # Individual expert weight has only 2 dimensions
+                new_placements.append(Shard(placement.dim - 1))
+            elif isinstance(placement, _StridedShard):
+                new_placements.append(_StridedShard(placement.dim - 1, placement.split_factor))
+            else:
+                raise ValueError("Not supported new placements!")
+        print(f"Original placements: {original_placements}, new placements {new_placements}")
+
+        assert isinstance(grouped_expert_weight, DTensor), "GroupedExperts weight is not a DTensor"
+        local_grouped_weights = grouped_expert_weight._local_tensor
+        assert local_grouped_weights.shape[0] == int(end_index - start_index), "Local tensor shape mismatch!"
+
+        # Create a new DTensor for each individual expert weight
+        local_expert_fqn = {}
+        for expert_id in range(start_index, end_index):
+            new_key = abstract_key.format(layer_id, expert_id)
+            new_value = local_grouped_weights[expert_id - start_index, :, :].squeeze()
+            local_expert_fqn[new_key] = DTensor.from_local(new_value, world_mesh, new_placements, run_check=False)
+
+        return local_expert_fqn
+
+
+    def _get_strided_shard_shard_slice(
+        self,
+        strided_shard_dim_degree: int,
+        strided_shard_dim_rank: int,
+        shard_dim_degree: int,
+        shard_dim_rank: int,
+        dim_size_to_split: int,
+    ) -> tuple[int, int]:
+        """
+        Given a [StridedShard(dim=i), Shard(dim=i)] placement, calculate the start index
+        and end index on dim-i for GPU rank (strided_shard_dim_rank, shard_dim_rank).
+
+        GPU layout (strided_shard_rank, shard_rank):
+
+            StridedShard rank               Shard rank
+                        ┌─────────────────┐
+                    0   │    GPU(0, 0)    │    0
+                  ──────┼─────────────────┤
+                    1   │    GPU(1, 0)    │
+                  ──────┼─────────────────┤
+                    2   │    GPU(2, 0)    │
+                  ──────┼─────────────────┼────
+                    0   │    GPU(0, 1)    │    1
+                  ──────┼─────────────────┤
+                    1   │    GPU(1, 1)    │
+                  ──────┼─────────────────┤
+                    2   │    GPU(2, 1)    │
+                        └─────────────────┘
+
+        Calculate the start_index from the inner dimension (Shard(dim=i)) to the outer dimension (StridedShard(dim=i)).
+        """
+
+        block_size = dim_size_to_split // (strided_shard_dim_degree * shard_dim_degree)
+
+        # Error out if the dimension cannot be evenly divided
+        if block_size * (strided_shard_dim_degree * shard_dim_degree) != dim_size_to_split:
+            raise ValueError(f"Not supported split for strided_shard_dim_degree {strided_shard_dim_degree}, shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}")
+
+        start_index = block_size * (strided_shard_dim_degree * shard_dim_rank + strided_shard_dim_rank)
+        end_index = start_index + block_size
+
+        return start_index, end_index
+
+
     def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         Dequantize the weights from float8 to float32.
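
The index arithmetic in _get_strided_shard_shard_slice can be sanity-checked in isolation. Below is a small standalone sketch (not part of the commit) that mirrors that formula with made-up degrees: 12 experts split across a StridedShard degree of 3 (e.g. FSDP) and a Shard degree of 2 (e.g. EP).

def strided_shard_shard_slice(
    strided_degree: int, strided_rank: int, shard_degree: int, shard_rank: int, dim_size: int
) -> tuple[int, int]:
    # Same math as _get_strided_shard_shard_slice: the Shard dim owns coarse
    # contiguous blocks, and the StridedShard dim sub-divides each block.
    block_size = dim_size // (strided_degree * shard_degree)
    assert block_size * strided_degree * shard_degree == dim_size, "dim must divide evenly"
    start = block_size * (strided_degree * shard_rank + strided_rank)
    return start, start + block_size

for shard_rank in range(2):
    for strided_rank in range(3):
        print((strided_rank, shard_rank), strided_shard_shard_slice(3, strided_rank, 2, shard_rank, 12))
# (0, 0) -> (0, 2)   (1, 0) -> (2, 4)   (2, 0) -> (4, 6)
# (0, 1) -> (6, 8)   (1, 1) -> (8, 10)  (2, 1) -> (10, 12)

This matches the GPU layout in the docstring above: the Shard rank picks the coarse half of the expert range, and the StridedShard rank picks the sub-block inside it.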
@@ -149,14 +289,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 layer_num = re.search(r"\d+", key).group(0)
                 new_abstract_key = to_hf_map[abstract_key]
 
-                # Split expert weights into separate expert weights
-                split_values = self._split_experts_weights(
-                    value, self.model_args.moe_args.num_experts
+                # # Split expert weights into separate expert weights
+                # split_values = self._split_experts_weights(
+                #     value, self.model_args.moe_args.num_experts
+                # )
+                local_expert_fqn = self._get_local_experts_weights(
+                    new_abstract_key, layer_num, value
                 )
+                print(f"groupedWeight placements {value.placements}, local experts keys {local_expert_fqn.keys()}")
 
-                for expert_num in range(0, self.model_args.moe_args.num_experts):
-                    new_key = new_abstract_key.format(layer_num, expert_num)
-                    hf_state_dict[new_key] = split_values[expert_num].squeeze()
+                hf_state_dict.update(local_expert_fqn)
 
             elif "layers" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
@@ -169,9 +311,11 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 new_key = to_hf_map[key]
                 hf_state_dict[new_key] = value
 
+        # Prepare for dequantization
         hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
             hf_state_dict
         )
+        print(f"[to_hf] state_dict keys before return: {hf_state_dict_with_scale_inv.keys()}")
         return hf_state_dict_with_scale_inv
 
     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
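
The to_hf change above stops materializing every expert on every rank and instead emits one DTensor per locally owned expert. A minimal standalone sketch of that idea, assuming a simple 1-D expert-parallel mesh on CPU and hypothetical FQNs (this is not the commit's code; run it under torchrun, e.g. --nproc-per-node 2):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard

dist.init_process_group("gloo")
# 1-D mesh whose single dim shards the experts dim; the mesh name "ep" is made up.
mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("ep",))

num_experts, d_in, d_out = 8, 4, 6
grouped = distribute_tensor(torch.randn(num_experts, d_in, d_out), mesh, [Shard(0)])

ep_degree, ep_rank = mesh.size(), mesh.get_local_rank()
block = num_experts // ep_degree
start, end = block * ep_rank, block * (ep_rank + 1)

# Slice the locally owned experts and re-wrap each as a DTensor whose former
# expert dim is Replicate(), mirroring what _get_local_experts_weights does.
local = grouped.to_local()
per_expert = {
    f"model.layers.1.mlp.experts.{eid}.weight": DTensor.from_local(
        local[eid - start], mesh, [Replicate()], run_check=False
    )
    for eid in range(start, end)
}
print(f"rank {ep_rank} holds expert keys: {sorted(per_expert)}")

dist.destroy_process_group()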

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 3 additions & 3 deletions
@@ -47,13 +47,13 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 4
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "1F1B"
 context_parallel_degree = 1
-expert_parallel_degree = 1
-expert_tensor_parallel_degree = 1
+expert_parallel_degree = 2
+expert_tensor_parallel_degree = 4
 
 [checkpoint]
 enable = false
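
As a quick orientation on these degrees: data_parallel_shard_degree = -1 tells torchtitan to infer the FSDP degree from whatever is left of the world size. A rough back-of-the-envelope sketch, assuming an 8-GPU debug run (the world size is an assumption, not stated in the commit):

# Sketch only; torchtitan's real validation lives in ParallelDims.
world_size = 8                              # assumed 8-GPU debug run
dp_replicate, cp, tp, pp = 1, 1, 4, 1       # degrees from the config above
dp_shard = world_size // (dp_replicate * cp * tp * pp)   # -1 in the config means "infer" -> 2
assert dp_replicate * dp_shard * cp * tp * pp == world_size
print(f"dp_shard resolves to {dp_shard}")
# expert_parallel_degree=2 and expert_tensor_parallel_degree=4 then describe how the MoE
# layers reuse these ranks, which produces the StridedShard/Shard placements the state
# dict adapter above has to unpick.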

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml

Lines changed: 7 additions & 5 deletions
@@ -38,28 +38,30 @@ min_lr_factor = 0.1
 local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
-steps = 10_000
+steps = 10
 compile = false
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 1
-expert_tensor_parallel_degree = 1
+expert_parallel_degree = 2
+expert_tensor_parallel_degree = 2
 
 [checkpoint]
 enable = false
 folder = "checkpoint"
-interval = 500
+interval = 10
 last_save_model_only = true
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base"
+initial_load_in_hf = true
 
 [activation_checkpoint]
 mode = "selective" # ["none", "selective", "full"]

torchtitan/models/llama3/model/state_dict_adapter.py

Lines changed: 3 additions & 2 deletions
@@ -10,14 +10,15 @@
 
 logger = logging.getLogger()
 
+from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 
 from .args import TransformerModelArgs
 
 
 class Llama3StateDictAdapter(StateDictAdapter):
-    def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None):
-        super().__init__(model_args, hf_assets_path)
+    def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims):
+        super().__init__(model_args, hf_assets_path, parallel_dims)
 
         self.model_args = model_args
         self.hf_assets_path = hf_assets_path

torchtitan/protocols/state_dict_adapter.py

Lines changed: 4 additions & 2 deletions
@@ -11,6 +11,8 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+from torchtitan.distributed.parallel_dims import ParallelDims
+
 logger = logging.getLogger()
 
 from .model import BaseModelArgs
@@ -27,7 +29,7 @@ class BaseStateDictAdapter(ABC):
     """
 
     @abstractmethod
-    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
+    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims):
         pass
 
     @abstractmethod
@@ -58,7 +60,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
 class StateDictAdapter(BaseStateDictAdapter):
     """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
 
-    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
+    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims):
         if hf_assets_path:
             mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
             try:

torchtitan/train.py

Lines changed: 13 additions & 2 deletions
@@ -23,7 +23,7 @@
     ensure_pp_loss_visible,
 )
 from torchtitan.config import ConfigManager, JobConfig
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.distributed import ParallelDims, parallel_dims, utils as dist_utils
 from torchtitan.models.attention import init_attention_mask
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
@@ -311,7 +311,7 @@ def __init__(self, job_config: JobConfig):
             checkpoint_config=job_config.checkpoint,
             sd_adapter=(
                 self.train_spec.state_dict_adapter(
-                    model_args, job_config.model.hf_assets_path
+                    model_args, job_config.model.hf_assets_path, self.parallel_dims
                )
                 if self.train_spec.state_dict_adapter
                 else None
@@ -539,6 +539,17 @@ def train_step(
     def train(self):
         job_config = self.job_config
 
+        # Following hacky print only works for debug_model
+        # w1 = self.model_parts[0].layers["1"].moe.experts.w1
+        # w2 = self.model_parts[0].layers["1"].moe.experts.w2
+        # w3 = self.model_parts[0].layers["1"].moe.experts.w3
+
+        # logger.info(f"w1 placements is: {w1.placements}, {type(w1.placements)}")
+        # logger.info(f"w2 placements is: {w2.placements}")
+        # logger.info(f"w3 placements is: {w3.placements}")
+        # logger.info(f"device mesh: {self.parallel_dims.world_mesh}, {self.parallel_dims.world_mesh.mesh_dim_names} {self.parallel_dims.world_mesh['dp_shard']}")
+
+
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}")
