
Commit 78d4314

[Qwen3] Add 1.7b Config (#1671)
Contents:
1. Minor fix in the Qwen3 0.6B training config
2. Add the Qwen3 1.7B training config
3. Fix the number-of-parameters calculation when weight tying is enabled
1 parent 07c00d9 commit 78d4314

File tree

4 files changed: +70 -2 lines changed


torchtitan/experiments/qwen3/README.md

Lines changed: 3 additions & 1 deletion
@@ -9,8 +9,9 @@ QWEN3 0.6B Dense model is available for:
 Other model sizes are added to the args, but toml file configs need to be added and tested.
 
 #### Download Qwen3 tokenizer
+```python scripts/download_hf_assets.py --repo_id <hf_repo_name> --assets tokenizer```
 
-```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer```
+e.g., for the Qwen3 0.6B model, the HF repo name is `Qwen/Qwen3-0.6B`; for the 1.7B model, it is `Qwen/Qwen3-1.7B`.
 
 #### Parity with HF
 
@@ -22,5 +23,6 @@ Model parity test has been done and results suggest parity with HF implementation
 - MoE alternatives
 
 - Testing
+- Learning rate verification: verify the learning rate and schedule with real training jobs (e.g., 3k steps), or find official references.
 - The model should be tested against established performance benchmarks
 - CI integration
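
For instance, fetching the tokenizer for the new 1.7B config uses the repo name listed above:

```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-1.7B --assets tokenizer```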

torchtitan/experiments/qwen3/model/args.py

Lines changed: 4 additions & 0 deletions
@@ -64,4 +64,8 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         )
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
 
+        if self.enable_weight_tying:
+            # exclude model.token_embedding parameters from nparams
+            nparams = nparams - nparams_embedding
+
         return nparams, num_flops_per_token
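
For context: with weight tying, the output projection reuses the token-embedding matrix, so a count that tallies the embedding and the head separately reports the shared weights twice. A minimal standalone sketch of the same correction, assuming a naive per-module tally (the toy model and names are illustrative, not torchtitan's internals):

```python
import torch.nn as nn

vocab_size, dim = 1000, 64

token_embedding = nn.Embedding(vocab_size, dim)
lm_head = nn.Linear(dim, vocab_size, bias=False)
lm_head.weight = token_embedding.weight  # weight tying: one shared tensor

# A per-module tally counts the shared matrix once for each module.
nparams = sum(
    p.numel() for m in (token_embedding, lm_head) for p in m.parameters()
)
nparams_embedding = token_embedding.weight.numel()

enable_weight_tying = True
if enable_weight_tying:
    # the head reuses the embedding weight, so drop the duplicate copy
    nparams = nparams - nparams_embedding

assert nparams == nparams_embedding  # only one real copy of the matrix exists
```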

torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ context_parallel_degree = 1
 enable = false
 folder = "checkpoint"
 interval = 500
-last_save_model_weights_only = false
+last_save_model_only = false
 export_dtype = "float16"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 

torchtitan/experiments/qwen3/train_configs/qwen3_1.7b.toml (new file)

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+[job]
+dump_folder = "./outputs"
+description = "Qwen 3 1.7B training"
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = false
+save_tb_folder = "tb"
+
+[model]
+name = "qwen3"
+flavor = "1.7B"
+hf_assets_path = "./assets/hf/Qwen3-1.7B"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 3e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 20 # lr scheduler warm up, 20% total steps
+
+[training]
+local_batch_size = 4
+seq_len = 4096
+max_norm = 1.0 # grad norm clipping
+steps = 100
+dataset = "c4"
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 1
+context_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 500
+last_save_model_only = false
+export_dtype = "float16"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "selective" # ["none", "selective", "full"]
+selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
+
+[compile]
+enable = false
+components = ["model", "loss"]
+
+[float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output"]
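
With the config in place, a training run would look something like the following, using torchtitan's usual CONFIG_FILE launch convention (the config path assumes the file name above):

```CONFIG_FILE="./torchtitan/experiments/qwen3/train_configs/qwen3_1.7b.toml" ./run_train.sh```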
