
Commit 4bd34fb

Commit message: debugging
1 parent: e7c39f6
File tree: 2 files changed (+5, -19 lines)

torchtitan/models/deepseek_v3/model/state_dict_adapter.py (4 additions, 17 deletions)
```diff
@@ -15,9 +15,7 @@
 
 from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
 
-from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
-from torchtitan.tools.logging import logger
 
 from .args import DeepSeekV3ModelArgs
 from .quantization import calculate_scale_shape, dequantize_from_fp8
@@ -134,9 +132,6 @@ def _caculate_indices_from_placements(
         # Find all the device mesh dimensios that shard on dim-i
         for i, name in enumerate(device_mesh.mesh_dim_names):
             placement = dtensor_placements[i]
-            print(
-                f"In _caculate_indices_from_placements, placement dim = {placement.dim} {type(placement.dim)}, {dim} {type(dim)}"
-            )
             if placement.dim == dim:
                 mesh_names.append(name)
                 dim_i_placements.append(placement)
@@ -161,8 +156,6 @@ def _caculate_indices_from_placements(
                 strided_degree, strided_rank, shard_degree, shard_rank, dim_size
             )
 
-            return start_index, end_index
-
         elif len(dim_i_placements) == 1:
             # Handle single Shard(i) case
             assert not isinstance(
@@ -182,8 +175,6 @@ def _caculate_indices_from_placements(
             start_index = block_size * shard_rank
             end_index = start_index + block_size
 
-            return start_index, end_index
-
         elif len(dim_i_placements) == 0:
             # No need to split on this dimension
             return start_index, end_index
@@ -193,6 +184,9 @@ def _caculate_indices_from_placements(
                 f"Unsupported DTensor placements for GroupedExperts: {dtensor_placements} {dim_i_placements} {mesh_names}"
             )
 
+        return start_index, end_index
+
+
     def _get_local_experts_weights(
         self,
         abstract_key: str,
@@ -331,7 +325,6 @@ def _concatenate_local_expert_weights(
         """
         Concatenate the weights of separate experts into GroupedExperts weights.
         """
-        logger.info(f"Concatenating for key {abstract_key} ")
         for layer in expert_weights_by_layer.keys():
             # If we have all the experts for this abstract_key, concatenate them
             experts = expert_weights_by_layer[layer][abstract_key]
@@ -363,11 +356,7 @@ def _concatenate_local_expert_weights(
                 if not expert_weights_by_layer[layer]:
                     del expert_weights_by_layer[layer]
 
-                logger.info(f"Concatenated for key {abstract_key} at layer {layer}")
-
                 return stacked_dtensor
-            else:
-                logger.info("no enough experts to concate")
 
         return None
 
@@ -475,9 +464,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         2. Convert between the HF shape and the torchtitan shape.
         3. Concate separate expert's wegiht into GroupedExperts' weight.
         """
-        print(
-            f"At the beginning of from_hf, the loaded state_dict is {hf_state_dict.keys()}"
-        )
+
         # dequantize the tensor in state_dict and remove the scale_inv tensor
 
         hf_state_dict = self._dequantize(hf_state_dict)
```
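
Beyond deleting the debugging `print`/`logger.info` calls and the now-unused `ParallelDims` and `logger` imports, the one behavioral change in this file is in `_caculate_indices_from_placements`: the early `return start_index, end_index` statements inside the `_StridedShard`-plus-`Shard` branch and the single-`Shard` branch are removed, and a single return is added after the whole `if`/`elif` chain. A minimal sketch of the resulting control flow is below; the function name, branch conditions, defaults, and the placeholder for the strided-shard index math are assumptions for illustration, not torchtitan's actual code.

```python
# Hypothetical sketch of the control flow after this commit: no early returns
# inside the sharding branches, one exit point at the end of the function.
def calculate_indices(dim_i_placements: list, block_size: int,
                      shard_rank: int, dim_size: int) -> tuple[int, int]:
    # Defaults cover the case where the dimension is not sharded at all.
    start_index, end_index = 0, dim_size

    if len(dim_i_placements) == 2:
        # _StridedShard(i) + Shard(i): the real index math lives in a helper in
        # torchtitan; a placeholder keeps this sketch runnable.
        start_index, end_index = 0, block_size
    elif len(dim_i_placements) == 1:
        # Single Shard(i): each rank owns one contiguous block.
        start_index = block_size * shard_rank
        end_index = start_index + block_size
    elif len(dim_i_placements) == 0:
        # No need to split on this dimension; keep the defaults.
        pass
    else:
        raise NotImplementedError("Unsupported placements for GroupedExperts")

    # Single exit point, mirroring the `return start_index, end_index`
    # the commit adds at the end of _caculate_indices_from_placements.
    return start_index, end_index
```

With the early returns gone, every branch that does not raise falls through to the same exit, which is what the `+ return start_index, end_index` in the `@@ -193,6 +184,9 @@` hunk provides.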

torchtitan/models/llama3/model/state_dict_adapter.py (1 addition, 2 deletions)
```diff
@@ -21,9 +21,8 @@ def __init__(
         self,
         model_args: TransformerModelArgs,
         hf_assets_path: str | None,
-        parallel_dims: ParallelDims,
     ):
-        super().__init__(model_args, hf_assets_path, parallel_dims)
+        super().__init__(model_args, hf_assets_path)
 
         self.model_args = model_args
         self.hf_assets_path = hf_assets_path
```
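
The Llama 3 adapter mirrors the DeepSeek-V3 cleanup: `parallel_dims` is dropped from the constructor signature and from the `super().__init__` call. Below is a hedged sketch of a call site after this change; the import paths, the `Llama3StateDictAdapter` class name, and the `TransformerModelArgs()` defaults are inferred from the file paths and type hints in the diff, not confirmed by the commit.

```python
# Hypothetical call site after the signature change; module paths and class
# name are assumptions based on the changed file's path.
from torchtitan.models.llama3.model.args import TransformerModelArgs
from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter

adapter = Llama3StateDictAdapter(
    model_args=TransformerModelArgs(),  # assumed to construct with defaults
    hf_assets_path=None,                # or a path to local HF checkpoint assets
)
# Before this commit the constructor also required a parallel_dims argument;
# existing call sites must drop it.
```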
