@@ -154,15 +154,6 @@ def __init__(
154154 self .hf_text_config = get_hf_text_config (self .hf_config )
155155 self .dtype = _get_and_verify_dtype (self .hf_text_config , dtype )
156156
157- if (getattr (self .hf_config , "max_position_embeddings" , 0 ) == 131072
158- and getattr (self .hf_config , "rope_scaling" , None ) is None ):
159- # Note(simon): this is a special case for a model that doesn't
160- # supply rope_scaling. We should remove this once the model is
161- # updated.
162- self .hf_config .update ({"rope_scaling" : {
163- "type" : "extended" ,
164- }})
165-
166157 if (not self .disable_sliding_window
167158 and self .hf_text_config .model_type == "gemma2"
168159 and self .hf_text_config .sliding_window is not None ):
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
14921483 derived_max_model_len = default_max_len
14931484
14941485 rope_scaling = getattr (hf_config , "rope_scaling" , None )
1495- # The correct one should be "longrope", kept "su" here
1496- # to be backward compatible
1497- if rope_scaling is not None and rope_scaling ["type" ] not in {
1498- "su" , "longrope" , "extended"
1499- }:
1500- if disable_sliding_window :
1501- # TODO(robertgshaw): Find a model that supports rope_scaling
1502- # with sliding window to see if this case should be allowed.
1503- raise NotImplementedError (
1504- "Disabling sliding window is not supported for models "
1505- "with rope_scaling. Please raise an issue so we can "
1506- "investigate." )
1507- assert "factor" in rope_scaling
1508- scaling_factor = rope_scaling ["factor" ]
1509- if rope_scaling ["type" ] == "yarn" :
1510- derived_max_model_len = rope_scaling [
1511- "original_max_position_embeddings" ]
1512- derived_max_model_len *= scaling_factor
1486+ if rope_scaling is not None :
1487+ if "type" in rope_scaling :
1488+ rope_type = rope_scaling ["type" ]
1489+ elif "rope_type" in rope_scaling :
1490+ rope_type = rope_scaling ["rope_type" ]
1491+ else :
1492+ raise ValueError (
1493+ "rope_scaling must have a 'type' or 'rope_type' key." )
1494+
1495+ # The correct one should be "longrope", kept "su" here
1496+ # to be backward compatible
1497+ if rope_type not in ("su" , "longrope" , "llama3" ):
1498+ if disable_sliding_window :
1499+ # TODO(robertgshaw): Find a model that supports rope_scaling
1500+ # with sliding window to see if this case should be allowed.
1501+ raise NotImplementedError (
1502+ "Disabling sliding window is not supported for models "
1503+ "with rope_scaling. Please raise an issue so we can "
1504+ "investigate." )
1505+
1506+ assert "factor" in rope_scaling
1507+ scaling_factor = rope_scaling ["factor" ]
1508+ if rope_type == "yarn" :
1509+ derived_max_model_len = rope_scaling [
1510+ "original_max_position_embeddings" ]
1511+ derived_max_model_len *= scaling_factor
15131512
15141513 # If the user specified a max length, make sure it is smaller than the
15151514 # derived length from the HF model config.
0 commit comments