
Commit e67b4f2

WoosukKwon and imoneoi authored
Use FP32 in RoPE initialization (#1004)
Co-authored-by: One <[email protected]>
1 parent d6770d1 commit e67b4f2

File tree

2 files changed (+7, -6 lines)


tests/kernels/test_pos_encoding.py

Lines changed: 3 additions & 2 deletions
@@ -133,9 +133,10 @@ def test_rotary_embedding(
                         device="cuda")
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)

vllm/model_executor/layers/attention.py

Lines changed: 4 additions & 4 deletions
@@ -264,10 +264,10 @@ def __init__(
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, device="cuda").float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float, device="cuda")
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
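
Both hunks make the same change: the torch.arange calls that produce the inverse frequencies and the position indices now request dtype=torch.float explicitly, and the now-redundant .float() casts are dropped, so the cos/sin cache is computed in FP32 end to end rather than in whatever default dtype happens to be active. A minimal, self-contained sketch of the resulting initialization is shown below; the helper name build_rope_cache and its default argument values are illustrative, not taken from vLLM.

import torch

def build_rope_cache(rotary_dim: int = 128,
                     max_position: int = 2048,
                     base: int = 10000) -> torch.Tensor:
    # dtype=torch.float keeps the exponentiation and the outer product in FP32,
    # independent of any torch.set_default_dtype() setting.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # Concatenate cos and sin along the last dimension, as in the diffs above.
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

cache = build_rope_cache()
print(cache.shape)  # torch.Size([2048, 128])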

0 commit comments
