Skip to content

Commit cd1af1a

Browse files
Sanket Jayant Purandaresanketpurandare
authored andcommitted
Fix DeepSeekV3Model for Configurable build pattern
stack-info: PR: #2725, branch: sanketpurandare/stack/4
1 parent d3d14c2 commit cd1af1a

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,31 @@ def get_sample_config() -> DeepSeekV3ModelArgs:
4747
v_head_dim=128,
4848
mscale=0.70,
4949
)
50+
51+
52+
def get_16b_sdpa_config() -> DeepSeekV3ModelArgs:
53+
return DeepSeekV3ModelArgs(
54+
vocab_size=102400,
55+
max_seq_len=4096,
56+
dim=2048,
57+
inter_dim=10944,
58+
moe_inter_dim=1408,
59+
n_layers=27,
60+
n_dense_layers=1,
61+
n_heads=16,
62+
moe_args=_MoEArgs(
63+
num_experts=64,
64+
num_shared_experts=2,
65+
top_k=6,
66+
score_func="softmax",
67+
route_norm=False,
68+
score_before_experts=False,
69+
mesh=None,
70+
),
71+
q_lora_rank=0,
72+
kv_lora_rank=512,
73+
qk_nope_head_dim=128,
74+
qk_rope_head_dim=64,
75+
v_head_dim=128,
76+
mscale=0.70,
77+
)

torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
# Need to share same base class with torchtitan models
1616
class DeepSeekV3Model(_DeepSeekV3Model, BaseModel):
17-
def __init__(self, model_args: DeepSeekV3ModelArgs):
18-
# Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__
19-
# Note: We don't call BaseModel.__init__ separately because:
20-
# 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__
21-
# 2. Calling BaseModel.__init__ after would reset all module state
22-
# (nn.Module.__init__ clears _modules, _parameters, etc.)
23-
_DeepSeekV3Model.__init__(self, model_args)
17+
def __init__(self, config: DeepSeekV3ModelArgs):
18+
_DeepSeekV3Model.__init__(self, config)
19+
20+
def verify_module_protocol(self) -> None:
21+
# Autoparallel submodules are standard nn.Modules,
22+
# not torchtitan Module instances — skip the check.
23+
pass
24+
25+
26+
# Wire Configurable pattern: build() calls DeepSeekV3Model(config=...)
27+
DeepSeekV3ModelArgs._owner = DeepSeekV3Model

0 commit comments

Comments
 (0)