Commit 0e99626

Commit message: tests
1 parent bf1ce57 commit 0e99626

3 files changed: +15 additions, −6 deletions

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -39,9 +39,9 @@
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
-            num_experts=8,
+            num_experts=2,
             num_shared_experts=2,
-            top_k=3,
+            top_k=1,
             score_func="softmax",
             route_norm=True,
             score_before_experts=False,
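The MoE shrink lines up with the expert-parallel settings introduced in the TOML below: with 2 routed experts and expert_parallel_degree = 2, each EP rank owns exactly one expert, and top-1 routing keeps the debug run cheap. A minimal sanity-check sketch of that arithmetic (an illustration, not part of the commit, assuming torchtitan wants the routed experts to divide evenly across EP ranks):

    num_experts = 2             # routed experts after this change
    top_k = 1                   # each token routes to a single expert
    expert_parallel_degree = 2  # from [parallelism] in debug_model.toml

    # A top-k router cannot select more experts than exist, and sharding
    # the expert dimension presumes an even split across EP ranks.
    assert top_k <= num_experts
    assert num_experts % expert_parallel_degree == 0
    print(num_experts // expert_parallel_degree)  # 1 expert per EP rank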

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 4 additions & 4 deletions
@@ -46,15 +46,15 @@ dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
 data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_shard_degree = 4
 fsdp_reshard_after_forward = "default"  # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "1F1B"
 context_parallel_degree = 1
-expert_parallel_degree = 1
-expert_tensor_parallel_degree = 1
+expert_parallel_degree = 2
+expert_tensor_parallel_degree = 2
 
 [checkpoint]
 enable_checkpoint = false
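These degrees compose into an 8-rank debug job: the dense mesh is dp_shard (4) × tp (2), and, as I understand torchtitan's layout, expert parallelism re-slices those same ranks for the MoE layers rather than adding new ones. A quick sketch of the arithmetic (my illustration, not torchtitan code):

    # Degrees from the [parallelism] table above.
    dp_replicate, dp_shard, tp = 1, 4, 2
    pp, cp = 1, 1

    # The dense model's device mesh is the product of these degrees.
    world_size = dp_replicate * dp_shard * tp * pp * cp
    print(world_size)  # 8 ranks

    # Assumption about the MoE layout: the ep x etp submesh is carved out
    # of the existing dp_shard x tp ranks, so it must tile them evenly.
    ep, etp = 2, 2
    assert (dp_shard * tp) % (ep * etp) == 0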

torchtitan/train.py

Lines changed: 9 additions & 0 deletions
@@ -530,6 +530,15 @@ def train_step(
     def train(self):
         job_config = self.job_config
 
+        logger.info("self.parallel_dims: {}".format(self.parallel_dims))
+
+        w1 = self.model_parts[0].layers["1"].moe.experts.w1
+        w2 = self.model_parts[0].layers["1"].moe.experts.w2
+        w3 = self.model_parts[0].layers["1"].moe.experts.w3
+
+        logger.info(f"the GroupedExpert weights w1 shard is: {w1.placements}")
+        logger.info(f"the GroupedExpert weights w2 shard is: {w2.placements}")
+        logger.info(f"the GroupedExpert weights w3 shard is: {w3.placements}")
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}")
