 from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
 
-from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
-from torchtitan.tools.logging import logger
 
 from .args import DeepSeekV3ModelArgs
 from .quantization import calculate_scale_shape, dequantize_from_fp8
@@ -134,9 +132,6 @@ def _caculate_indices_from_placements(
         # Find all the device mesh dimensios that shard on dim-i
         for i, name in enumerate(device_mesh.mesh_dim_names):
             placement = dtensor_placements[i]
-            print(
-                f"In _caculate_indices_from_placements, placement dim = {placement.dim} {type(placement.dim)}, {dim} {type(dim)}"
-            )
             if placement.dim == dim:
                 mesh_names.append(name)
                 dim_i_placements.append(placement)
@@ -161,8 +156,6 @@ def _caculate_indices_from_placements(
                 strided_degree, strided_rank, shard_degree, shard_rank, dim_size
             )
 
-            return start_index, end_index
-
         elif len(dim_i_placements) == 1:
             # Handle single Shard(i) case
             assert not isinstance(
@@ -182,8 +175,6 @@ def _caculate_indices_from_placements(
             start_index = block_size * shard_rank
             end_index = start_index + block_size
 
-            return start_index, end_index
-
         elif len(dim_i_placements) == 0:
             # No need to split on this dimension
             return start_index, end_index
@@ -193,6 +184,9 @@ def _caculate_indices_from_placements(
                 f"Unsupported DTensor placements for GroupedExperts: {dtensor_placements} {dim_i_placements} {mesh_names}"
             )
 
+        return start_index, end_index
+
+
     def _get_local_experts_weights(
         self,
         abstract_key: str,
@@ -331,7 +325,6 @@ def _concatenate_local_expert_weights(
         """
         Concatenate the weights of separate experts into GroupedExperts weights.
         """
-        logger.info(f"Concatenating for key {abstract_key}")
         for layer in expert_weights_by_layer.keys():
             # If we have all the experts for this abstract_key, concatenate them
             experts = expert_weights_by_layer[layer][abstract_key]
@@ -363,11 +356,7 @@ def _concatenate_local_expert_weights(
                 if not expert_weights_by_layer[layer]:
                     del expert_weights_by_layer[layer]
 
-                logger.info(f"Concatenated for key {abstract_key} at layer {layer}")
-
                 return stacked_dtensor
-            else:
-                logger.info("no enough experts to concate")
 
         return None
 
@@ -475,9 +464,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         2. Convert between the HF shape and the torchtitan shape.
         3. Concate separate expert's wegiht into GroupedExperts' weight.
         """
-        print(
-            f"At the beginning of from_hf, the loaded state_dict is {hf_state_dict.keys()}"
-        )
+
         # dequantize the tensor in state_dict and remove the scale_inv tensor
 
         hf_state_dict = self._dequantize(hf_state_dict)