from .args import Qwen3ModelArgs

- 
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
-     """
-     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
- 
-     This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
-     and the end index 'end'. The 'theta' parameter scales the frequencies.
-     The returned tensor contains complex values in complex64 data type.
- 
-     Args:
-         dim (int): Dimension of the frequency tensor.
-         end (int): End index for precomputing frequencies.
-         theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0.
- 
-     Returns:
-         torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py
+ def precompute_rope_cache(
+     dim: int, max_seq_len: int, base: float = 1_000_000.0
+ ) -> torch.Tensor:
+     freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+     t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device)
+ 
+     # Outer product of theta and position index; output tensor has
+     # a shape of [max_seq_len, dim // 2]
+     idx_theta = torch.outer(t, freqs).float()
+ 
+     # We cache the cos and sin embeddings instead of the IDs. This helps
+     # ensure we have correct behavior when training with bf16
+     # Size: [max_seq_len, (dim * 2)]
+     freqs = torch.cat([idx_theta, idx_theta], dim=-1)
+     rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
+     return rope_cache
+ 
+ 
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+ 
+ 
+ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device)
-     freqs = torch.outer(t, freqs).float()
-     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-     return freqs_cis
- 
- 
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-     """
-     Reshape frequency tensor for broadcasting it with another tensor.
+     Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

-     The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
+     The input rope_cache tensor is assumed to be of shape (max_seqlen, head_dim * 2),
    and the first seqlen elements will be sliced, but dim must match x.

    Args:
-         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+         rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert ndim > 1
-     seqlen = x.shape[1]
-     freqs_cis = freqs_cis[0:seqlen]
-     assert freqs_cis.shape == (seqlen, x.shape[-1])
-     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-     return freqs_cis.view(*shape)
+     _, seqlen, _, head_dim = x.shape
+     rope_cache = rope_cache[0:seqlen]
+     # The shape of rope_cache is (seqlen, head_dim * 2) because we concatenate cos and sin
+     assert rope_cache.shape == (seqlen, head_dim * 2)
+     shape = [-1, seqlen, 1, head_dim * 2]
+     return rope_cache.view(*shape)


def apply_rotary_emb(
-     xq: torch.Tensor,
-     xk: torch.Tensor,
-     freqs_cis: torch.Tensor,
+     xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
-     """
-     Apply rotary embeddings to input tensors using the given frequency tensor.
+     # input tensor x has shape [bsz, seq_len, num_heads, head_dim]
+     head_dim = xq.shape[-1]

-     This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
-     frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
-     is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
-     returned as real tensors.
+     # reshape for broadcast
+     rope_cache = reshape_for_broadcast(rope_cache, xq)

-     Args:
-         xq (torch.Tensor): Query tensor to apply rotary embeddings.
-         xk (torch.Tensor): Key tensor to apply rotary embeddings.
-         freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
- 
-     Returns:
-         tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-     Note:
-         This function adds .transpose(-2,-1) to match HF implementation. This method assumes that last
-         dimension is [real_0, real_1, ..., real_{N-1}, imag_0, imag_1, ..., imag_{N-1}] while Rope in Llama3
-         has [real_0, imag_0, real_1, imag_1, ..., real_{N-1}, imag_{N-1}]. This is the main difference
-         between Llama3 and Qwen3 Rope which is under investigation.
-     """
-     xk_complex = torch.view_as_complex(
-         xk.view(*xk.shape[:-1], 2, xk.shape[-1] // 2)
-         .transpose(-2, -1)
-         .contiguous()
-         .float()
-     )
-     xq_complex = torch.view_as_complex(
-         xq.view(*xq.shape[:-1], 2, xq.shape[-1] // 2)
-         .transpose(-2, -1)
-         .contiguous()
-         .float()
-     )
-     freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)
- 
-     xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
-     xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)
+     # [bsz, seq_len, 1, head_dim]
+     cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
+     sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device)

+     # xq: [bsz, seq_len, num_heads, head_dim]
+     # xk: [bsz, seq_len, num_kv_heads, head_dim]
+     xq_out = (xq * cos) + (rotate_half(xq) * sin)
+     xk_out = (xk * cos) + (rotate_half(xk) * sin)

    return xq_out.type_as(xq), xk_out.type_as(xk)

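Aside (not part of the patch): a minimal sanity check of the new RoPE path, assuming precompute_rope_cache, rotate_half, and apply_rotary_emb from this file; the batch, sequence, and head sizes below are illustrative only.

import torch

bsz, seq_len, n_heads, n_kv_heads, head_dim = 2, 16, 8, 4, 64

rope_cache = precompute_rope_cache(head_dim, max_seq_len=32)
assert rope_cache.shape == (32, head_dim * 2)  # cos and sin concatenated on the last dim

xq = torch.randn(bsz, seq_len, n_heads, head_dim)
xk = torch.randn(bsz, seq_len, n_kv_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, rope_cache)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape

# At position p the rotation acts on half-split pairs (x[j], x[j + head_dim // 2])
pos = 3
inv_freq = 1.0 / (1_000_000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = pos * inv_freq
x = xq[0, pos, 0]
x1, x2 = x[: head_dim // 2], x[head_dim // 2 :]
expected = torch.cat(
    [x1 * angles.cos() - x2 * angles.sin(), x2 * angles.cos() + x1 * angles.sin()]
)
torch.testing.assert_close(xq_rot[0, pos, 0], expected)

The pair-wise rotation is the same one the removed complex-number version computed; the new code keeps the result in the half-split layout instead of interleaving real and imaginary parts, matching the HF/torchtune convention referenced above.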
@@ -189,14 +167,13 @@ def init_weights(self, init_std: float):
    def forward(
        self,
        x: torch.Tensor,
-         freqs_cis: torch.Tensor,
+         rope_cache: torch.Tensor,
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
-             freqs_cis (torch.Tensor): Precomputed frequency tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
@@ -220,9 +197,10 @@ def forward(
        if self.k_norm:
            xk = self.k_norm(xk)

-         # repeat k/v heads if n_kv_heads < n_heads
-         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+         # Apply rotary embedding
+         xq, xk = apply_rotary_emb(xq, xk, rope_cache)

+         # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

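Aside (not part of the patch): repeat_kv is defined elsewhere in this file and is untouched by this change. For context, a grouped-query-attention expansion typically looks like the sketch below; this is an assumed implementation, shown only to make the shape comments above concrete.

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    if n_rep == 1:
        return x
    bs, slen, n_kv_heads, head_dim = x.shape
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )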
@@ -318,7 +296,7 @@ def __init__(self, layer_id: int, model_args: Qwen3ModelArgs):
    def forward(
        self,
        x: torch.Tensor,
-         freqs_cis: torch.Tensor,
+         rope_cache: torch.Tensor,
    ):
        """
        Perform a forward pass through the TransformerBlock.
@@ -331,7 +309,7 @@ def forward(
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
-         h = x + self.attention(self.attention_norm(x), freqs_cis)
+         h = x + self.attention(self.attention_norm(x), rope_cache)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

@@ -342,9 +320,9 @@ def init_weights(self):
        self.feed_forward.init_weights(self.weight_init_std)


- class Transformer(nn.Module, ModelProtocol):
+ class Qwen3Model(nn.Module, ModelProtocol):
    """
-     Transformer Module
+     Qwen3Model Module

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.
@@ -370,13 +348,18 @@ def __init__(self, model_args: Qwen3ModelArgs):
        self.head_dim = model_args.head_dim

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
-         self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+ 
+         self.register_buffer(
+             "rope_cache", self._precompute_rope_cache(), persistent=False
+         )

        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+ 
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+ 
        self.init_weights()

    def init_weights(
@@ -394,9 +377,9 @@ def init_weights(
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.
        """
-         buffer_device = buffer_device or self.freqs_cis.device
+         buffer_device = buffer_device or self.rope_cache.device
        with torch.device(buffer_device):
-             self.freqs_cis = self._precompute_freqs_cis()
+             self.rope_cache = self._precompute_rope_cache()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
@@ -406,6 +389,8 @@ def init_weights(
        self.norm.reset_parameters()
        final_out_std = self.model_args.dim**-0.5
        cutoff_factor = 3
+ 
+         # If weight tying is enabled, we don't need to initialize the output layer
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight,
@@ -415,12 +400,9 @@ def init_weights(
                b=cutoff_factor * final_out_std,
            )

-     def _precompute_freqs_cis(self) -> torch.Tensor:
-         return precompute_freqs_cis(
-             self.head_dim,
-             # Need to compute until at least the max token limit for generation
-             # TODO: explain in docs/composability.md why we removed the 2x
-             # relaxing in our CP enablement PR
+     def _precompute_rope_cache(self) -> torch.Tensor:
+         return precompute_rope_cache(
+             self.model_args.head_dim,
            self.model_args.max_seq_len,
            self.model_args.rope_theta,
        )
@@ -459,7 +441,7 @@ def forward(
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
-             h = layer(h, self.freqs_cis)
+             h = layer(h, self.rope_cache)

        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h
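Aside (not part of the patch): end to end, rope_cache flows from Qwen3Model.forward into each TransformerBlock and then into attention, where apply_rotary_emb consumes it. A rough usage sketch follows; the Qwen3ModelArgs fields are inferred from their usage in this diff, the values are arbitrary, and the real class may require additional fields (for example head counts), so treat this as an assumption rather than a working configuration.

args = Qwen3ModelArgs(
    dim=256,
    n_layers=2,
    head_dim=64,
    vocab_size=1024,
    max_seq_len=128,
    rope_theta=1_000_000.0,
)
model = Qwen3Model(args)
tokens = torch.randint(0, args.vocab_size, (2, 16))
logits = model(tokens)  # expected shape: (2, 16, args.vocab_size)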