Skip to content

Commit e7c39f6

Browse files
committed
debugging
1 parent 81f698a commit e7c39f6

File tree

6 files changed

+352
-166
lines changed

6 files changed

+352
-166
lines changed

torchtitan/components/checkpoint.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,16 @@ def dcp_load(
418418
)
419419

420420
state_dict = self.sd_adapter.from_hf(hf_state_dict)
421+
422+
# [rank0]:after sd converter, placement is DeviceMesh((dp_shard_mod_ep=2, dp_shard_in_ep=2, tp=2), device: 'cuda', stride: (4, 2, 1))
423+
print(
424+
f"after sd converter, placement is {state_dict['layers.3.moe.experts.w3'].device_mesh}, type {type(state_dict['layers.3.moe.experts.w3'])}, placement {state_dict['layers.3.moe.experts.w3'].placements}"
425+
)
426+
427+
# [rank0]:after sd converter, model placement is DeviceMesh((dp_shard_mod_ep=2, ep=2, tp=2), device: 'cuda', stride: (4, 2, 1))
428+
# model_state_dict = self.states[MODEL].state_dict()
429+
# print(f"after sd converter, model placement is {model_state_dict['layers.3.moe.experts.w3'].device_mesh}")
430+
421431
self.states[MODEL].load_state_dict(state_dict)
422432
else:
423433
dcp.load(state_dict, checkpoint_id=checkpoint_id)

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def parallelize_deepseekv3(
3636
job_config: JobConfig,
3737
):
3838
world_mesh = parallel_dims.world_mesh
39+
print(f"In parallelize_deepseekv3, world mesh is {world_mesh}")
3940
# TODO: TP currently cannot handle uneven seq_len because we set
4041
# `use_local_output=True` to use plain Tensors for legacy reasons.
4142
# Need to revisit this.

0 commit comments

Comments
 (0)