@@ -264,6 +264,15 @@ def __init__(
         self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
             0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
         t = torch.arange(max_position, dtype=torch.float, device="cuda")
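For reference, a minimal standalone sketch of the cache construction these NOTE comments describe. The `freqs` intermediate and the cos/sin concatenation are assumptions (those lines fall outside this hunk), and the dimensions are illustrative only.

import torch

rotary_dim = 64
max_position = 2048
base = 10000
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inverse frequencies computed directly in float32, matching the NOTE about
# avoiding precision loss with large `base` values.
inv_freq = 1.0 / (base ** (torch.arange(
    0, rotary_dim, 2, dtype=torch.float, device=device) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device=device)

# Outer product of positions and inverse frequencies, then cos/sin,
# giving a cache of shape [max_position, rotary_dim].
freqs = torch.einsum("i,j->ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)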
@@ -274,7 +283,6 @@ def __init__(
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
         # initializing the model.
-        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
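To illustrate the FIXME above: the cache is built in float32 for accuracy and then cast to whatever default dtype the model was initialized under. A hedged sketch, assuming half precision as the model dtype (this diff does not set it):

import torch

# Assumed example: the model was initialized with float16 as the default dtype.
torch.set_default_dtype(torch.half)

cache = torch.zeros(2048, 64, dtype=torch.float)  # cache computed in float32
torch_dtype = torch.get_default_dtype()           # torch.float16 in this example
cache = cache.to(torch_dtype)                     # cast to the model's dtype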