Commit f9e8897

HosseinKaviani-H (Hossein Kavianihamedani) and wwwjn authored
Adding Qwen3 model to the experiments folder (#1429)
In this PR, I added the Qwen3 0.6B dense model to torchtitan under experiments. A parity test has been done, and the torchtitan-native results match the HF implementation. A profiler diagnostic has been attached; the computation/communication latency breakdown shows good performance. More explanation can be found in the README file. Thanks to Rohan Pandey (@KhoomeiK) for helping out with the RoPE implementation.

[Screenshot attachment: https://github.com/user-attachments/assets/d7cee483-e3f7-4008-8edc-773782ea0173]

Co-authored-by: Hossein Kavianihamedani <[email protected]>
Co-authored-by: Jiani Wang <[email protected]>
1 parent 0d1b80d commit f9e8897

File tree

7 files changed: +959, -0 lines changed


torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,4 +5,5 @@
 # LICENSE file in the root directory of this source tree.
 
 import torchtitan.experiments.llama4  # noqa: F401
+import torchtitan.experiments.qwen3
 import torchtitan.experiments.simple_fsdp  # noqa: F401
```
torchtitan/experiments/qwen3/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
**The Qwen3 model is still under development.**

#### Available features
The Qwen3 0.6B dense model is available with:

- FSDP/HSDP, TP, DDP, AC, and compile support

Other model sizes are defined in the model args, but their toml file configs still need to be added and tested. Further testing is also needed to check the consistency of the parallelism implementations.

#### Download Qwen3 tokenizer

```bash
python scripts/download_tokenizer.py --repo_id Qwen/Qwen3-0.6B
```

#### Parity with HF

A model parity test has been done, and the results suggest parity with the HF implementation. Further investigation is needed to check the sanity of the RoPE implementation.
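For illustration, here is a minimal sketch of such a parity check (not part of this commit). It assumes the `Transformer` constructor takes a `Qwen3ModelArgs` instance, that its forward pass returns logits directly, and that HF weights have already been converted into a torchtitan-compatible state dict (the conversion is omitted):

```python
# Hypothetical parity-check sketch: compare torchtitan Qwen3 logits against the
# HF reference on the same token batch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchtitan.experiments.qwen3 import Transformer, qwen3_configs

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.float32)
hf_model.eval()

tt_model = Transformer(qwen3_configs["0.6B"])  # assumed constructor signature
# tt_model.load_state_dict(converted_hf_state_dict)  # HF -> torchtitan mapping omitted
tt_model.eval()

input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
with torch.no_grad():
    hf_logits = hf_model(input_ids).logits
    tt_logits = tt_model(input_ids)  # assumed to return logits

print("max abs diff:", (hf_logits - tt_logits).abs().max().item())
```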
#### To be added
- Modeling
  - Variants of Dense models up to 32B
  - MoE alternatives
  - Weight tying
- Testing
  - The model should be tested against established performance benchmarks
  - CI integration
torchtitan/experiments/qwen3/__init__.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_qwen3
from .model.args import Qwen3ModelArgs
from .model.model import Transformer

__all__ = [
    "parallelize_qwen3",
    "Qwen3ModelArgs",
    "Transformer",
    "qwen3_configs",
]


# Adding different variants of the model

qwen3_configs = {
    "0.6B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=1024,
        n_layers=28,
        n_heads=16,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=3072,
        rope_theta=1000000,
    ),
    "1.7B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=2048,
        n_layers=28,
        n_heads=16,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=6144,
        rope_theta=1000000,
    ),
    "4B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=2560,
        n_layers=36,
        n_heads=32,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=9728,
        rope_theta=1000000,
    ),
    "8B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=4096,
        n_layers=36,
        n_heads=32,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=12288,
        rope_theta=1000000,
    ),
    "14B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=5120,
        n_layers=40,
        n_heads=40,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=17408,
        rope_theta=1000000,
    ),
    "32B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=5120,
        n_layers=64,
        n_heads=64,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=25600,
        rope_theta=1000000,
    ),
}


register_train_spec(
    TrainSpec(
        name="qwen3",
        model_cls=Transformer,
        model_args=qwen3_configs,  # Change from dict to Mapping
        parallelize_fn=parallelize_qwen3,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
    )
)
```
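As a quick illustration of how these registered configs can be used (a sketch, not part of this commit; it assumes `Transformer(model_args)` is the constructor signature), one variant can be instantiated on the meta device to inspect its size:

```python
# Hypothetical sketch: build the 0.6B variant on the meta device to count
# parameters without allocating real memory.
import torch

from torchtitan.experiments.qwen3 import Transformer, qwen3_configs

args = qwen3_configs["0.6B"]
with torch.device("meta"):
    model = Transformer(args)  # assumed constructor signature

n_params = sum(p.numel() for p in model.parameters())
print(f"Qwen3 0.6B: {n_params / 1e6:.1f}M parameters")
```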
torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Qwen3 model.

import torch
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.models.llama3.infra.parallelize import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
)
from torchtitan.tools.logging import logger


def parallelize_qwen3(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):

    world_mesh = parallel_dims.world_mesh
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
    """
    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if parallel_dims.fsdp_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), Replicate()),
                desired_input_layouts=(Replicate(), Replicate()),
            ),
            "attention.wq": colwise_parallel(use_local_output=False),
            "attention.wk": colwise_parallel(use_local_output=False),
            "attention.wv": colwise_parallel(use_local_output=False),
            "attention.q_norm": NoParallel(use_local_output=False),
            "attention.k_norm": NoParallel(use_local_output=False),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
```
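The colwise/rowwise plan that `apply_tp` builds for each transformer block can be illustrated on a toy two-layer MLP. This is a standalone sketch, not part of this commit: the module names `w1`/`w2` and the script name are made up, and it assumes a single-node `torchrun` launch with NCCL-capable GPUs:

```python
# Standalone sketch of the colwise -> rowwise TP pattern used in apply_tp.
# Launch with: torchrun --nproc_per_node=2 tp_toy.py  (filename is arbitrary)
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024, hidden: int = 3072):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


world_size = int(os.environ["WORLD_SIZE"])
tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

torch.manual_seed(0)  # identical weights on every rank before sharding
model = ToyMLP().cuda()

# Shard w1 column-wise and w2 row-wise, so the hidden activation stays sharded
# across ranks and only one all-reduce is needed at the end of the block.
parallelize_module(
    model,
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)

out = model(torch.randn(8, 1024, device="cuda"))
print(out.shape)  # torch.Size([8, 1024]) on every rank
```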
