
Commit 91c5639

Hook up deepseekv3_auto_parallel
The following command should now run: `CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel`. However, it doesn't actually do anything with autoparallel yet. The next step is to attach local_map to the model so that autoparallel can run.
1 parent 8e50870 commit 91c5639
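The local_map step mentioned above is not part of this commit. As a rough indication of shape only, here is a hedged sketch of wrapping one region for auto-parallelism, assuming `torch.distributed.tensor.experimental.local_map` from recent PyTorch and an already-initialized 1-D "tp" DeviceMesh; the function and placements below are illustrative, not torchtitan code.

# Hypothetical sketch; not part of this commit. Assumes an initialized process
# group, a 1-D "tp" DeviceMesh, and torch.distributed.tensor.experimental.local_map.
import torch
from torch.distributed.tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard


def _expert_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # inside a local_map-wrapped region the inputs arrive as plain local tensors,
    # so ops the auto-parallel tracer cannot reason about can run as-is
    return x @ w


def wrap_expert_for_autoparallel(tp_mesh):
    # DTensor inputs are unwrapped to local shards before _expert_mm runs; the
    # output is rewrapped as a DTensor sharded on dim 1 (placements illustrative)
    return local_map(
        _expert_mm,
        out_placements=(Shard(1),),
        in_placements=((Replicate(),), (Shard(1),)),
        device_mesh=tp_mesh,
    )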

File tree

2 files changed: +153 -0 lines changed


torchtitan/experiments/auto_parallel/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -13,7 +13,11 @@
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers
+from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model
 from .parallelize_llama import parallelize_llama
+from .parallelize_deepseekv3 import parallelize_deepseekv3
+

 register_train_spec(
     TrainSpec(
@@ -29,3 +33,17 @@
         build_loss_fn=build_cross_entropy_loss,
     )
 )
+register_train_spec(
+    TrainSpec(
+        name="deepseekv3_auto_parallel",
+        cls=DeepSeekV3Model,
+        config=deepseekv3_configs,
+        parallelize_fn=parallelize_deepseekv3,
+        pipelining_fn=None,
+        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
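With this second `register_train_spec` call, the spec becomes resolvable by name, which is what `--model.name deepseekv3_auto_parallel` relies on. Below is a hedged sanity check; it assumes `get_train_spec` is the lookup counterpart of `register_train_spec` in the same module (not shown in this diff) and that importing the experiment package runs the registrations above.

# Assumes torchtitan.protocols.train_spec.get_train_spec exists as the lookup
# counterpart of register_train_spec shown in the diff above.
import torchtitan.experiments.auto_parallel  # noqa: F401  (runs the registrations)
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("deepseekv3_auto_parallel")
print(spec.cls.__name__, spec.parallelize_fn.__name__)
# expected: DeepSeekV3Model parallelize_deepseekv3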
torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


def parallelize_deepseekv3(
    model,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
            global_batch_size = job_config.training.local_batch_size * dp_degree
        return (
            torch.randint(
                0,
                # job_config.training.vocab_size,
                model.vocab_size,
                (global_batch_size, job_config.training.seq_len),
                device=torch.device("cuda"),
            ),
        )

    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
    assert parallel_dims.cp_enabled is False, "CP not supported yet"
    assert parallel_dims.pp_enabled is False, "PP not supported yet"

    # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
    #     lambda bucket_idx: 500 / parallel_dims.tp
    # )
    # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = (
    #     lambda bucket_idx: 1000 / parallel_dims.tp
    # )

    # bail out
    return model
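
    # NOTE: everything below this early return is currently unreachable; it sketches
    # the intended AutoParallel integration, and `parallel_mod` is only defined inside
    # the commented-out block, so the bail-out must stay until that block is enabled.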

    # if job_config.experimental.autop_force_bf16:
    #     logger.info("Forcing bf16 on model")
    #     model = model.bfloat16()

    # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
    # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    # with AutoParallel(
    #     model,
    #     input_fn,
    #     world_mesh,
    #     mp_policy=mp_policy,
    #     compile=job_config.training.compile,
    # ) as autop:
    #     autop.add_parameter_memory_constraint(low=None, high=None)

    #     possible_input_shardings = {
    #         # maps relative to mesh dim names used in torchtitan
    #         "dp_replicate": Shard(0),
    #         "dp_shard": Shard(0),
    #         "tp": Replicate(),
    #     }
    #     # only used if loss parallel is enabled
    #     possible_output_shardings = {
    #         # maps relative to mesh dim names used in torchtitan
    #         "dp_shard": Shard(0),
    #         "tp": Shard(2),
    #     }
    #     assert all(
    #         name in possible_input_shardings for name in world_mesh.mesh_dim_names
    #     ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
    #     x_sharding = tuple(
    #         possible_input_shardings[name] for name in world_mesh.mesh_dim_names
    #     )
    #     out_sharding = x_sharding
    #     if parallel_dims.loss_parallel_enabled:
    #         out_sharding = tuple(
    #             possible_output_shardings[name]
    #             for name in world_mesh.mesh_dim_names
    #             if name != "dp_replicate"
    #         )
    #     autop.add_input_constraints([x_sharding])
    #     autop.add_output_constraints([out_sharding])
    #     t0 = time.time()
    #     sharding_placement = autop.optimize_placement()
    #     t1 = time.time()
    #     logger.info(f"AutoParallel took {t1 - t0} seconds")
    # parallel_mod = autop.apply_placement(sharding_placement)

    if parallel_dims.loss_parallel_enabled:
        # the current PyTorch implementation of loss parallel assumes
        # that the DTensor has a 1d device mesh. This is not true
        # in our case, but we can work around it by casting the
        # output to a DTensor on a 1d device mesh.
        # We should just use AutoParallel to do this for us, but
        # it would require putting the loss inside the model as well
        def _return_as_dtensor_for_loss_parallel(module, args, output):
            return torch.distributed.tensor.DTensor.from_local(
                output, world_mesh["tp"], (Shard(2),)
            )

        # not keeping a reference to the hook, don't plan on
        # removing it at any point
        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

    return parallel_mod
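
Taken on its own, the forward hook above just rewraps the local logits as a DTensor on the 1-D "tp" sub-mesh, sharded along the vocab dimension (dim 2), which is the layout PyTorch's loss-parallel path expects. A standalone, hedged sketch of that rewrap follows; it is meant for a torchrun launch, and the shapes and vocab split are made up for illustration.

# Standalone sketch of the DTensor.from_local rewrap used by the hook above.
# Illustrative only: run under torchrun; shapes and names are made up.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

# pretend this is each rank's shard of the logits: (batch, seq, local_vocab)
local_logits = torch.randn(2, 16, 1024, device="cuda")

# same call as in _return_as_dtensor_for_loss_parallel: wrap on a 1-D mesh,
# sharded along the vocab dim (dim 2), so loss parallel can consume it
logits = DTensor.from_local(local_logits, tp_mesh, (Shard(2),))
print(logits.shape)  # global shape: (2, 16, 1024 * world_size)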
