
Commit 7f1fa48

[Qwen3] Switch to verified RoPE implementation + Add weight tying support (#1590)
## Context

1. The current Qwen3 RoPE used a trick to achieve numerical parity with HF. The trick comes from an unofficial source and is hard to reason about mathematically. Switch to the [torchtune-based implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py#L14), which was contributed directly by the Qwen team. Thanks @ebsmothers for pointing us to this implementation!
   - For the RoPE embedding, it is now handled the same way as the complex-representation RoPE in Llama3: we initialize and precompute the RoPE cos/sin values only once, and pass them into the Attention module during forward. This way, TP can be applied seamlessly.
   - In contrast, torchtune passes a RoPE module into the initialization of each layer's attention module.
2. Add weight tying support for Qwen3, verified with FSDP + TP.

## Numerical verification for RoPE

Running an end-to-end forward pass of the Qwen3 model, the outputs match (see also the standalone sketch below):

<img width="812" height="412" alt="Screenshot 2025-08-18 at 2 48 48 PM" src="https://github.com/user-attachments/assets/618dde58-6546-4cdf-bd8c-2b828a5afa91" />

## Weight tying verification

1. With vs. without weight tying on the torchtitan model (FSDP=4; losses are exactly the same):

<img width="772" height="412" alt="Screenshot 2025-08-18 at 6 19 13 PM" src="https://github.com/user-attachments/assets/c0cfa049-c5c9-42a9-9133-b6ee32e9b9b4" />

2. torchtitan with weight tying vs. HF:

<img width="732" height="507" alt="Screenshot 2025-08-18 at 9 37 50 PM" src="https://github.com/user-attachments/assets/c4a51310-df02-4ade-8721-a17678445d52" />

3. Weight tying memory address / id check (in train.py): passed.

```
assert id(model.tok_embeddings.weight) == id(model.output.weight), "id check 2"
assertEqual(model.tok_embeddings.weight, model.output.weight)
# model.forward()
assert id(model.tok_embeddings.weight.grad) == id(model.output.weight.grad), "id check 2"
assertEqual(model.tok_embeddings.weight.grad, model.output.weight.grad)
```

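The parity verified above can also be checked in isolation: in the half-split layout, `x * cos + rotate_half(x) * sin` rotates each pair `(x_i, x_{i + dim//2})` by the angle `t * theta_i`, which is exactly a complex rotation. A minimal standalone sketch of that identity (the helper names below are illustrative, not code from this PR):

```python
# Standalone sketch: the cos/sin + rotate_half formulation equals a complex
# rotation applied to the half-split pairs (x_i, x_{i + dim//2}).
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rope_cos_sin(dim: int, seq_len: int, base: float = 1_000_000.0):
    # Same construction as precompute_rope_cache in this PR, inlined for clarity
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # [seq_len, dim // 2]
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # [seq_len, dim]
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
    return cos, sin, angles


dim, seq_len = 128, 16
x = torch.randn(seq_len, dim)
cos, sin, angles = rope_cos_sin(dim, seq_len)

# New formulation used in this PR
out_new = x * cos + rotate_half(x) * sin

# Reference: rotate (x_i, x_{i + dim//2}) as a complex number by exp(i * t * theta_i)
x_complex = torch.complex(x[..., : dim // 2], x[..., dim // 2 :])
rotated = x_complex * torch.polar(torch.ones_like(angles), angles)
out_ref = torch.cat([rotated.real, rotated.imag], dim=-1)

torch.testing.assert_close(out_new, out_ref)
print("rotate_half formulation matches complex rotation")
```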
5 files changed: +83 additions, -96 deletions

torchtitan/experiments/qwen3/__init__.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -16,12 +16,12 @@
 
 from .infra.parallelize import parallelize_qwen3
 from .model.args import Qwen3ModelArgs
-from .model.model import Transformer
+from .model.model import Qwen3Model
 
 __all__ = [
     "parallelize_qwen3",
     "Qwen3ModelArgs",
-    "Transformer",
+    "Qwen3Model",
    "qwen3_configs",
 ]
 
@@ -107,7 +107,7 @@
 register_train_spec(
     TrainSpec(
         name="qwen3",
-        model_cls=Transformer,
+        model_cls=Qwen3Model,
         model_args=qwen3_configs,  # Change from dict to Mapping
         parallelize_fn=parallelize_qwen3,
         pipelining_fn=None,
```

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -120,6 +120,10 @@ def parallelize_qwen3(
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
+    # Enable weight tying after applying parallelisms
+    if model.model_args.enable_weight_tying:
+        model.output.weight = model.tok_embeddings.weight
+
     return model
 
 
```

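The assignment above makes `output` and `tok_embeddings` share a single `nn.Parameter`, which is what the id and gradient checks in the commit message rely on. A minimal single-process sketch of the same semantics on a toy module (not torchtitan code):

```python
# Toy sketch of weight-tying semantics: after the assignment, embedding and
# output head reference one nn.Parameter, so ids match and gradients
# accumulate into the same tensor.
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Same tying assignment as in parallelize_qwen3 above
        self.output.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


model = TinyTiedLM()
assert id(model.tok_embeddings.weight) == id(model.output.weight)

tokens = torch.randint(0, 32, (2, 8))
loss = model(tokens).sum()
loss.backward()

# Both views of the tied parameter see the same (single) gradient tensor
assert model.tok_embeddings.weight.grad is model.output.weight.grad
print("tied parameter and gradient are shared")
```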
torchtitan/experiments/qwen3/model/args.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -38,6 +38,8 @@ class Qwen3ModelArgs(BaseModelArgs):
     attn_mask_type: str = "causal"
     eos_id: int = 151645
 
+    enable_weight_tying: bool = False
+
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
         if seq_len > self.max_seq_len:
```

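With the new flag in place, a flavor that ties embeddings can opt in through its model args. A hypothetical entry is sketched below (illustrative values only; the actual `qwen3_configs` definitions are not part of this diff):

```python
# Hypothetical flavor entry; illustrative values only. enable_weight_tying
# mirrors HF's tie_word_embeddings setting for the small Qwen3 variants.
from torchtitan.experiments.qwen3.model.args import Qwen3ModelArgs

qwen3_configs = {
    "0.6B": Qwen3ModelArgs(
        dim=1024,
        n_layers=28,
        enable_weight_tying=True,
    ),
}
```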
torchtitan/experiments/qwen3/model/model.py

Lines changed: 71 additions & 89 deletions

```diff
@@ -16,99 +16,77 @@
 
 from .args import Qwen3ModelArgs
 
-
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
-    """
-    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
-
-    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
-    and the end index 'end'. The 'theta' parameter scales the frequencies.
-    The returned tensor contains complex values in complex64 data type.
-
-    Args:
-        dim (int): Dimension of the frequency tensor.
-        end (int): End index for precomputing frequencies.
-        theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0.
-
-    Returns:
-        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py
+def precompute_rope_cache(
+    dim: int, max_seq_len: int, base: float = 1_000_000.0
+) -> torch.Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+    t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device)
+
+    # Outer product of theta and position index; output tensor has
+    # a shape of [max_seq_len, dim // 2]
+    idx_theta = torch.outer(t, freqs).float()
+
+    # We cache the cos and sin embeddings instead of the IDs. This helps
+    # ensure we have correct behavior when training with bf16
+    # Size: [max_seq_len, (dim * 2)]
+    freqs = torch.cat([idx_theta, idx_theta], dim=-1)
+    rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
+    return rope_cache
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)
-    freqs = torch.outer(t, freqs).float()
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-    """
-    Reshape frequency tensor for broadcasting it with another tensor.
+    Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor.
 
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
+    The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2),
     and the first seqlen elements will be sliced, but dim must match x.
 
     Args:
-        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped.
         x (torch.Tensor): Target tensor for broadcasting compatibility.
 
     Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert ndim > 1
-    seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
+    _, seqlen, _, head_dim = x.shape
+    rope_cache = rope_cache[0:seqlen]
+    # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin
+    assert rope_cache.shape == (seqlen, head_dim * 2)
+    shape = [-1, seqlen, 1, head_dim * 2]
+    return rope_cache.view(*shape)
 
 
 def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
+    xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Apply rotary embeddings to input tensors using the given frequency tensor.
+    # input tensor x has shape [bsz, seq_len, num_heads, head_dim]
+    head_dim = xq.shape[-1]
 
-    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-    returned as real tensors.
+    # reshape for broadcast
+    rope_cache = reshape_for_broadcast(rope_cache, xq)
 
-    Args:
-        xq (torch.Tensor): Query tensor to apply rotary embeddings.
-        xk (torch.Tensor): Key tensor to apply rotary embeddings.
-        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
-
-    Returns:
-        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-    Note:
-        This function adds .transpose(-2,-1) to match HF implementation. This method assumes that last
-        dimension is [real_0, real_1, ..., real_{N-1}, imag_0, imag_1, ..., imag_{N-1}] while Rope in Llama3
-        has [real_0, imag_0, real_1, imag_1, ..., real_{N-1}, imag_{N-1}]. This is the main difference
-        between Llama3 and Qwen3 Rope which is under investigation.
-    """
-    xk_complex = torch.view_as_complex(
-        xk.view(*xk.shape[:-1], 2, xk.shape[-1] // 2)
-        .transpose(-2, -1)
-        .contiguous()
-        .float()
-    )
-    xq_complex = torch.view_as_complex(
-        xq.view(*xq.shape[:-1], 2, xq.shape[-1] // 2)
-        .transpose(-2, -1)
-        .contiguous()
-        .float()
-    )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)
-
-    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)
+    # [bsz, seq_len, 1, head_dim]
+    cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
+    sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device)
 
+    # xq: [bsz, seq_len, num_heads, head_dim]
+    # xk: [bsz, seq_len, num_kv_heads, head_dim]
+    xq_out = (xq * cos) + (rotate_half(xq) * sin)
+    xk_out = (xk * cos) + (rotate_half(xk) * sin)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
@@ -189,14 +167,13 @@ def init_weights(self, init_std: float):
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        rope_cache: torch.Tensor,
     ):
         """
         Forward pass of the attention module.
 
         Args:
             x (torch.Tensor): Input tensor.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor.
 
         Returns:
             torch.Tensor: Output tensor after attention.
@@ -220,9 +197,10 @@ def forward(
         if self.k_norm:
             xk = self.k_norm(xk)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+        # Apply rotary embedding
+        xq, xk = apply_rotary_emb(xq, xk, rope_cache)
 
+        # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
 
@@ -318,7 +296,7 @@ def __init__(self, layer_id: int, model_args: Qwen3ModelArgs):
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        rope_cache: torch.Tensor,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -331,7 +309,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        h = x + self.attention(self.attention_norm(x), rope_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
@@ -342,9 +320,9 @@ def init_weights(self):
         self.feed_forward.init_weights(self.weight_init_std)
 
 
-class Transformer(nn.Module, ModelProtocol):
+class Qwen3Model(nn.Module, ModelProtocol):
     """
-    Transformer Module
+    Qwen3Model Module
 
     Args:
         model_args (TransformerModelArgs): Model configuration arguments.
@@ -370,13 +348,18 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.head_dim = model_args.head_dim
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+
+        self.register_buffer(
+            "rope_cache", self._precompute_rope_cache(), persistent=False
+        )
 
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
         self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+
         self.init_weights()
 
     def init_weights(
@@ -394,9 +377,9 @@ def init_weights(
            ``init_weights``. We only call it in the constructor of this
            ``Transformer`` root module to avoid reinitializing tensors.
        """
-        buffer_device = buffer_device or self.freqs_cis.device
+        buffer_device = buffer_device or self.rope_cache.device
        with torch.device(buffer_device):
-            self.freqs_cis = self._precompute_freqs_cis()
+            self.rope_cache = self._precompute_rope_cache()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
@@ -406,6 +389,8 @@ def init_weights(
             self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5
         cutoff_factor = 3
+
+        # If weight tying is enabled, we don't need to initialize the output layer
         if self.output is not None:
             nn.init.trunc_normal_(
                 self.output.weight,
@@ -415,12 +400,9 @@ def init_weights(
                 b=cutoff_factor * final_out_std,
             )
 
-    def _precompute_freqs_cis(self) -> torch.Tensor:
-        return precompute_freqs_cis(
-            self.head_dim,
-            # Need to compute until at least the max token limit for generation
-            # TODO: explain in docs/composability.md why we removed the 2x
-            # relaxing in our CP enablement PR
+    def _precompute_rope_cache(self) -> torch.Tensor:
+        return precompute_rope_cache(
+            self.model_args.head_dim,
             self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )
@@ -459,7 +441,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.rope_cache)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
```

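A quick way to exercise the new helpers, assuming torchtitan is installed at this commit so they are importable: one `rope_cache` of shape `[max_seq_len, head_dim * 2]` serves both queries and keys, even with different head counts (GQA).

```python
# Sketch exercising the new RoPE helpers with GQA-style shapes
# (assumes torchtitan at this commit is importable; otherwise copy the
# functions from the diff above).
import torch

from torchtitan.experiments.qwen3.model.model import (
    apply_rotary_emb,
    precompute_rope_cache,
)

bsz, seq_len, n_heads, n_kv_heads, head_dim = 2, 32, 8, 4, 128

# One cache for the whole model: [max_seq_len, head_dim * 2] (cos || sin)
rope_cache = precompute_rope_cache(head_dim, max_seq_len=64)
assert rope_cache.shape == (64, head_dim * 2)

xq = torch.randn(bsz, seq_len, n_heads, head_dim)
xk = torch.randn(bsz, seq_len, n_kv_heads, head_dim)

# The cache is sliced to seq_len and broadcast over batch and heads,
# so queries and keys with different head counts share the same cache.
xq_out, xk_out = apply_rotary_emb(xq, xk, rope_cache)
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
```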
torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml

Lines changed: 3 additions & 4 deletions

```diff
@@ -8,14 +8,14 @@ save_traces_folder = "profile_trace"
 profile_freq = 100
 
 [metrics]
-log_freq = 10
+log_freq = 1
 enable_tensorboard = false
 save_tb_folder = "tb"
 
 [model]
 name = "qwen3"
 flavor = "0.6B"
-tokenizer_path = "./assets/tokenizer/Qwen3-0.6B"
+hf_assets_path = "./assets/hf/Qwen3-0.6B"
 # converters = ["float8"]
 
 [optimizer]
@@ -24,7 +24,7 @@ lr = 3e-4
 eps = 1e-8
 
 [lr_scheduler]
-warmup_steps = 1  # lr scheduler warm up
+warmup_steps = 2  # lr scheduler warm up, 20% total steps
 
 [training]
 local_batch_size = 4
@@ -34,7 +34,6 @@ steps = 10
 compile = false
 dataset = "c4"
 
-
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
```
