diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py index 5a3e81bbc6..51f5c1fbda 100644 --- a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py @@ -47,3 +47,31 @@ def get_sample_config() -> DeepSeekV3ModelArgs: v_head_dim=128, mscale=0.70, ) + + +def get_16b_sdpa_config() -> DeepSeekV3ModelArgs: + return DeepSeekV3ModelArgs( + vocab_size=102400, + max_seq_len=4096, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + moe_args=_MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=None, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py index 359e45ab9f..02cfab090d 100644 --- a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py @@ -14,10 +14,14 @@ # Need to share same base class with torchtitan models class DeepSeekV3Model(_DeepSeekV3Model, BaseModel): - def __init__(self, model_args: DeepSeekV3ModelArgs): - # Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__ - # Note: We don't call BaseModel.__init__ separately because: - # 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__ - # 2. Calling BaseModel.__init__ after would reset all module state - # (nn.Module.__init__ clears _modules, _parameters, etc.) - _DeepSeekV3Model.__init__(self, model_args) + def __init__(self, config: DeepSeekV3ModelArgs): + _DeepSeekV3Model.__init__(self, config) + + def verify_module_protocol(self) -> None: + # Autoparallel submodules are standard nn.Modules, + # not torchtitan Module instances — skip the check. + pass + + +# Wire Configurable pattern: build() calls DeepSeekV3Model(config=...) +DeepSeekV3ModelArgs._owner = DeepSeekV3Model