# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
- from typing import Any
+ from typing import Any, Dict

import torch

+ from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import DeepSeekV3ModelArgs
from .quantization import calculate_scale_shape, dequantize_from_fp8

+ from torch.distributed.tensor.placement_types import (
+     _StridedShard,
+     Shard,
+     Replicate,
+ )
+
+ from torch.distributed.tensor import DTensor
+


class DeepSeekV3StateDictAdapter(StateDictAdapter):
    """
    StateDictAdapter for DeepSeekV3 model.
    """

-     def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
+     def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None, parallel_dims: ParallelDims):
+         super().__init__(model_args, hf_assets_path, parallel_dims)
        self.model_args = model_args
+         self.parallel_dims = parallel_dims
        self.from_hf_map = {
            "model.embed_tokens.weight": "tok_embeddings.weight",
            # Attention Module
@@ -52,7 +64,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
            "lm_head.weight": "output.weight",
        }

-     def _split_experts_weights(
+     def _split_experts_weight(
        self, weight: torch.Tensor, n_experts: int
    ) -> list[torch.Tensor]:
        """
@@ -84,6 +96,134 @@ def _concatenate_expert_weights(

        return None

+     def _get_local_experts_weights(
+         self, abstract_key: str, layer_id: str, grouped_expert_weight: torch.Tensor
+     ) -> Dict[str, torch.Tensor]:
+ """
103
+ Spliting the GroupedExperts weight and find the corresponding individual expert's weight in local tensor.
104
+
105
+ Potential experts weights shard placements:
106
+ - FSDP + EP when dp_mod_ep * ep <= num_experts:
107
+ - StridedShard(0)Shard(0)
108
+ - FSDP + EP when dp_mod_ep * ep <= num_experts:
109
+ - Shard(1)Shard(0)
110
+ - FSDP + ETP + EP when dp_mod_ep * ep <= num_experts:
111
+ - w1/w3: StridedShard(0)Shard(0)Shard(1)
112
+ - w2: StridedShard(0)Shard(0)Shard(2)
113
+ - FSDP + ETP + EP when dp_mod_ep * ep > num_experts:
114
+ - w1/w3: StridedShard(1)Shard(0)Shard(1)
115
+ - w2: Shard(1)Shard(0)Shard(2)
116
+ """
+         world_mesh = self.parallel_dims.world_mesh
+         num_experts = grouped_expert_weight.shape[0]
+
+         # Match the DTensor sharding placements with the device mesh dims and
+         # find the mesh dims that shard on dim-0 (the num_experts dim).
+         original_placements = grouped_expert_weight.placements
+         world_mesh_names = []
+         dim_0_placements = []
+         for i, name in enumerate(world_mesh.mesh_dim_names):
+             placement = original_placements[i]
+             if isinstance(placement, Shard) and placement.dim == 0:
+                 world_mesh_names.append(name)
+                 dim_0_placements.append(placement)
+
+         start_index, end_index = None, None
+         # StridedShard(0)Shard(0)
+         if len(dim_0_placements) == 2:
+             assert isinstance(dim_0_placements[0], _StridedShard)
+             strided_shard_mesh = world_mesh[world_mesh_names[0]]
+             strided_degree, strided_rank = strided_shard_mesh.size(), strided_shard_mesh.get_local_rank()
+             shard_mesh = world_mesh[world_mesh_names[1]]
+             shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank()
+             start_index, end_index = self._get_strided_shard_shard_slice(
+                 strided_degree, strided_rank, shard_degree, shard_rank, num_experts
+             )
+         # Shard(0)
+         elif len(dim_0_placements) == 1:
+             assert not isinstance(dim_0_placements[0], _StridedShard)
+             shard_mesh = world_mesh[world_mesh_names[0]]
+             shard_degree, shard_rank = shard_mesh.size(), shard_mesh.get_local_rank()
+             block_size = num_experts // shard_degree
+             if block_size * shard_degree != num_experts:
+                 raise ValueError(
+                     "Not supported: num_experts cannot be evenly divided by the Shard(0) mesh dim degree."
+                 )
+
+             start_index = block_size * shard_rank
+             end_index = start_index + block_size
+         else:
+             raise NotImplementedError(
+                 f"DTensor placements {original_placements} for GroupedExperts are not supported in StateDictAdapter."
+             )
+
+         # Calculate the new placement for individual expert weights
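+         # For example (illustrative shapes only): a Shard(1) over a grouped weight
+         # of shape [num_experts, dim_ff, dim] maps to Shard(0) over one expert's
+         # [dim_ff, dim] weight, since dropping dim-0 shifts every later dim by one.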
+         new_placements = []
+         for i, name in enumerate(world_mesh.mesh_dim_names):
+             placement = original_placements[i]
+             if isinstance(placement, Shard) and placement.dim == 0:
+                 new_placements.append(Replicate())
+             elif isinstance(placement, _StridedShard):
+                 new_placements.append(
+                     _StridedShard(placement.dim - 1, split_factor=placement.split_factor)
+                 )
+             elif isinstance(placement, Shard):
+                 # Individual expert weight has only 2 dimensions
+                 new_placements.append(Shard(placement.dim - 1))
+             elif isinstance(placement, Replicate):
+                 new_placements.append(Replicate())
+             else:
+                 raise ValueError(f"Unsupported placement {placement} for GroupedExperts weight!")
+         print(f"Original placements: {original_placements}, new placements {new_placements}")
+
+         assert isinstance(grouped_expert_weight, DTensor), "GroupedExperts weight is not a DTensor"
+         local_grouped_weights = grouped_expert_weight._local_tensor
+         assert local_grouped_weights.shape[0] == int(end_index - start_index), "Local tensor shape mismatch!"
+
+         # Create a new DTensor for each individual expert weight
+         local_expert_fqn = {}
+         for expert_id in range(start_index, end_index):
+             new_key = abstract_key.format(layer_id, expert_id)
+             new_value = local_grouped_weights[expert_id - start_index, :, :].squeeze()
+             local_expert_fqn[new_key] = DTensor.from_local(
+                 new_value, world_mesh, new_placements, run_check=False
+             )
+
+         return local_expert_fqn
+
+
+     def _get_strided_shard_shard_slice(
+         self,
+         strided_shard_dim_degree: int,
+         strided_shard_dim_rank: int,
+         shard_dim_degree: int,
+         shard_dim_rank: int,
+         dim_size_to_split: int,
+     ) -> tuple[int, int]:
+ """
192
+ Given a [StridedShard(dim=i), Shard(dim=i)] placement, caculate the start index
193
+ and end index on dim-i for GPU rank (strided_shard_dim_degree, shard_dim_rank)
194
+
195
+ GPU Layout (strided_shard_rank, shard_rank):
196
+
197
+ StridedShard Rank Shard rank
198
+ ┌─────────────────┐
199
+ 0 │ GPU(0, 0) │ 0
200
+ ────┼─────────────────┤
201
+ 1 │ GPU(1, 0) │
202
+ ────┼─────────────────┤
203
+ 2 │ GPU(2, 0) │
204
+ ──────┼─────────────────┼────
205
+ 0 │ GPU(0, 1) │ 1
206
+ ────┼─────────────────┤
207
+ 1 │ GPU(1, 1) │
208
+ ────┼─────────────────┤
209
+ 2 │ GPU(2, 1) │
210
+ └─────────────────┘
211
+
212
+ Calulate the start_index from inner dimesion (Shard(dim=i)) to outer demension (StridedShard(dim=i)).
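+
+         Worked example (hypothetical sizes, not from this PR): with
+         strided_shard_dim_degree=3, shard_dim_degree=2, and dim_size_to_split=12,
+         block_size = 12 // (3 * 2) = 2, so GPU(strided_rank=1, shard_rank=0) gets
+         start_index = 2 * (3 * 0 + 1) = 2 and end_index = 4, while
+         GPU(strided_rank=0, shard_rank=1) gets start_index = 2 * (3 * 1 + 0) = 6
+         and end_index = 8, matching the layout above.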
+         """
+
+         block_size = dim_size_to_split // (strided_shard_dim_degree * shard_dim_degree)
+
+         # Error out if the dimension cannot be evenly divided
+         if block_size * (strided_shard_dim_degree * shard_dim_degree) != dim_size_to_split:
+             raise ValueError(
+                 f"Unsupported split: strided_shard_dim_degree {strided_shard_dim_degree}, "
+                 f"shard_dim_degree {shard_dim_degree}, dim_size_to_split {dim_size_to_split}"
+             )
+
+         start_index = block_size * (strided_shard_dim_degree * shard_dim_rank + strided_shard_dim_rank)
+         end_index = start_index + block_size
+
+         return start_index, end_index
+
+
    def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Dequantize the weights from float8 to float32.
@@ -149,14 +289,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                layer_num = re.search(r"\d+", key).group(0)
                new_abstract_key = to_hf_map[abstract_key]

-                 # Split expert weights into separate expert weights
-                 split_values = self._split_experts_weights(
-                     value, self.model_args.moe_args.num_experts
+                 # # Split expert weights into separate expert weights
+                 # split_values = self._split_experts_weights(
+                 #     value, self.model_args.moe_args.num_experts
+                 # )
+                 local_expert_fqn = self._get_local_experts_weights(
+                     new_abstract_key, layer_num, value
                )
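+                 # local_expert_fqn maps per-expert HF keys (e.g.
+                 # "model.layers.<layer>.mlp.experts.<idx>.up_proj.weight"; names are
+                 # illustrative) to 2-D DTensors holding only this rank's local experts.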
+                 print(f"groupedWeight placements {value.placements}, local experts keys {local_expert_fqn.keys()}")

-                 for expert_num in range(0, self.model_args.moe_args.num_experts):
-                     new_key = new_abstract_key.format(layer_num, expert_num)
-                     hf_state_dict[new_key] = split_values[expert_num].squeeze()
+                 hf_state_dict.update(local_expert_fqn)

            elif "layers" in key:
                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
@@ -169,9 +311,11 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                new_key = to_hf_map[key]
                hf_state_dict[new_key] = value

+         # Prepare for dequantization
        hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
            hf_state_dict
        )
+         print(f"[to_hf] state_dict keys before return: {hf_state_dict_with_scale_inv.keys()}")
        return hf_state_dict_with_scale_inv

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: