Commit 714cc5b

[dsv3] disable MoE while we fix local_map, works up until optimizer
1 parent: 4f8677b

2 files changed: +37 -38 lines

torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py

Lines changed: 34 additions & 37 deletions

@@ -77,43 +77,40 @@ def input_fn():
         mp_policy=mp_policy,
         compile=job_config.training.compile,
     ) as autop:
-        # currently errors due to missing sharding prop rules
-        torch.distributed.breakpoint()
-
-        # autop.add_parameter_memory_constraint(low=None, high=None)
-
-        # possible_input_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_replicate": Shard(0),
-        #     "dp_shard": Shard(0),
-        #     "tp": Replicate(),
-        # }
-        # # only used if loss parallel is enabled
-        # possible_output_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_shard": Shard(0),
-        #     "tp": Shard(2),
-        # }
-        # assert all(
-        #     name in possible_input_shardings for name in world_mesh.mesh_dim_names
-        # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
-        # x_sharding = tuple(
-        #     possible_input_shardings[name] for name in world_mesh.mesh_dim_names
-        # )
-        # out_sharding = x_sharding
-        # if parallel_dims.loss_parallel_enabled:
-        #     out_sharding = tuple(
-        #         possible_output_shardings[name]
-        #         for name in world_mesh.mesh_dim_names
-        #         if name != "dp_replicate"
-        #     )
-        # autop.add_input_constraints([x_sharding])
-        # autop.add_output_constraints([out_sharding])
-        # t0 = time.time()
-        # sharding_placement = autop.optimize_placement()
-        # t1 = time.time()
-        # logger.info(f"AutoParallel took {t1 - t0} seconds")
-        # parallel_mod = autop.apply_placement(sharding_placement)
+        autop.add_parameter_memory_constraint(low=None, high=None)
+
+        possible_input_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_replicate": Shard(0),
+            "dp_shard": Shard(0),
+            "tp": Replicate(),
+        }
+        # only used if loss parallel is enabled
+        possible_output_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_shard": Shard(0),
+            "tp": Shard(2),
+        }
+        assert all(
+            name in possible_input_shardings for name in world_mesh.mesh_dim_names
+        ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
+        x_sharding = tuple(
+            possible_input_shardings[name] for name in world_mesh.mesh_dim_names
+        )
+        out_sharding = x_sharding
+        if parallel_dims.loss_parallel_enabled:
+            out_sharding = tuple(
+                possible_output_shardings[name]
+                for name in world_mesh.mesh_dim_names
+                if name != "dp_replicate"
+            )
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([out_sharding])
+        t0 = time.time()
+        sharding_placement = autop.optimize_placement()
+        t1 = time.time()
+        logger.info(f"AutoParallel took {t1 - t0} seconds")
+        parallel_mod = autop.apply_placement(sharding_placement)
 
     if parallel_dims.loss_parallel_enabled:
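For context on the placement vocabulary in the code being uncommented: Shard(0) and Replicate() are DTensor placements, and each constraint handed to autop.add_input_constraints is a tuple with one placement per mesh dimension, in mesh order. Below is a minimal standalone sketch of that mapping, assuming a two-dim ("dp_shard", "tp") mesh purely for illustration; neither the mesh itself nor AutoParallel is needed to see it.

# Standalone sketch (not from the commit): how the per-mesh-dim tables above
# become a placement tuple. Mesh dim names follow torchtitan's conventions;
# the two-dim ("dp_shard", "tp") mesh is an assumed example.
from torch.distributed.tensor import Replicate, Shard

possible_input_shardings = {
    "dp_replicate": Shard(0),  # shard the batch dim across replica groups
    "dp_shard": Shard(0),      # shard the batch dim across FSDP groups
    "tp": Replicate(),         # replicate inputs across tensor-parallel ranks
}

mesh_dim_names = ("dp_shard", "tp")  # stand-in for world_mesh.mesh_dim_names
x_sharding = tuple(possible_input_shardings[name] for name in mesh_dim_names)
print(x_sharding)  # (Shard(dim=0), Replicate())

The loss-parallel branch builds out_sharding the same way from possible_output_shardings, skipping the "dp_replicate" dim; there, Shard(2) keeps the output sharded along its last dimension on the tp ranks (presumably the vocabulary dimension of [batch, seq, vocab] logits).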

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 3 additions & 1 deletion

@@ -270,7 +270,9 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe_enabled = layer_id >= model_args.n_dense_layers
+        # self.moe_enabled = layer_id >= model_args.n_dense_layers
+        # TODO: enable me when local_map works
+        self.moe_enabled = False
 
         if self.moe_enabled:
             self.moe = MoE(model_args)
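The effect of the model change is easiest to see by evaluating the deleted expression by hand. A tiny worked example, using assumed illustration values (n_dense_layers=1, four layers) rather than the real DeepSeekV3ModelArgs defaults:

# Worked example of the deleted expression; n_dense_layers=1 and n_layers=4
# are illustration values, not the actual DeepSeekV3ModelArgs defaults.
n_dense_layers, n_layers = 1, 4

for layer_id in range(n_layers):
    original = layer_id >= n_dense_layers  # pre-patch: dense first, then MoE
    patched = False                        # post-patch: every layer is dense
    print(layer_id, original, patched)
# 0 False False
# 1 True  False
# 2 True  False
# 3 True  False

With the flag pinned to False, the `if self.moe_enabled:` branch never constructs an MoE, so every block presumably takes the dense feed-forward path until local_map support lands.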
