 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Optional, Union
+
+from typing import Any, Callable, Dict, Optional, Union
 from abc import ABC, abstractmethod
 
 import torch
@@ -132,7 +133,7 @@ class TransformerArgs:
     ffn_dim_multiplier: Optional[int] = None
     use_tiktoken: bool = False
     max_seq_length: int = 8192
-    use_scaled_rope: bool = False
+    rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
@@ -418,8 +419,6 @@ def __init__(self, config: TransformerArgs) -> None:
         self.norm = None
         self.output = None
 
-        # self.freqs_cis: Optional[Tensor] = None
-        # self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
         # For supporting sequence parallel (default is off, thus value of 1)
@@ -444,7 +443,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
-            use_scaled=self.config.use_scaled_rope,
+            rope_scaling=self.config.rope_scaling,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -681,12 +680,16 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
+def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
+    # Check for the presence of the required keys
+    required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
+    if not required_keys.issubset(rope_scaling.keys()):
+        raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")
+
+    scale_factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
 
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
@@ -707,16 +710,16 @@ def apply_scaling(freqs: torch.Tensor):
 
 
 def precompute_freqs_cis(
-    n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
+    n_elem: int, seq_len: int, base: int = 10000, dtype=None, rope_scaling: Optional[Dict[str, Any]] = None
 ) -> Tensor:
     if not dtype:
         dtype = get_precision()
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
+    if rope_scaling is not None:
+        freqs = apply_scaling(freqs, rope_scaling)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
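The body of apply_scaling that consumes low_freq_wavelen and high_freq_wavelen is unchanged by this diff and falls outside the hunks shown above. As a self-contained sketch of the wavelength-based scaling that such a rope_scaling dict typically drives (an illustrative, vectorized version, not the file's own implementation):

import math

import torch


def scale_rope_freqs(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    # Illustrative Llama-3.1-style RoPE frequency scaling; assumes the same
    # four keys validated by apply_scaling in the diff above.
    scale_factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    # Short wavelengths (high frequencies) are kept as-is, long wavelengths
    # (low frequencies) are divided by the scale factor, and the band in
    # between is linearly interpolated between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        wavelen < high_freq_wavelen,
        freqs,
        torch.where(
            wavelen > low_freq_wavelen,
            freqs / scale_factor,
            (1 - smooth) * freqs / scale_factor + smooth * freqs,
        ),
    )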