@@ -540,7 +540,7 @@ def patched__compute_dynamic_ntk_parameters(
     seq_len: Optional[int] = None,
     **rope_kwargs,
 ) -> Tuple["torch.Tensor", float]:
-    """
+    """manual patch:
     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
 
     Computes the inverse frequencies with NTK scaling.
@@ -594,8 +594,9 @@ def patched__compute_dynamic_ntk_parameters(
         seq_len = max_position_embeddings
     else:
         torch._check(isinstance(seq_len, torch.Tensor))
-        seq_len = torch.max(
-            seq_len, torch.Tensor(max_position_embeddings, dtype=seq_len.dtype)
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
         )
 
     # Compute the inverse frequencies
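# Illustrative sketch (not part of the diff above): why the patch switches from
# torch.Tensor/torch.max to torch.tensor/torch.maximum. The values below are
# made up; only the torch calls matter.
import torch

max_position_embeddings = 4096                   # hypothetical config value
seq_len = torch.tensor(2048, dtype=torch.int64)  # seq_len arrives as a tensor at export time

# The old line could not run: the legacy torch.Tensor constructor does not
# accept a dtype keyword, and torch.Tensor(n) would allocate an uninitialized
# tensor of shape (n,) rather than wrap the value n.
#   torch.Tensor(max_position_embeddings, dtype=seq_len.dtype)  # raises TypeError

# torch.tensor(n) wraps the scalar into a 0-d tensor on the right device, and
# torch.maximum takes the element-wise maximum, which clamps seq_len from below.
seq_len = torch.maximum(
    seq_len,
    torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
)
print(seq_len)  # tensor(4096)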
@@ -676,13 +677,23 @@ def wrapper(self, x, position_ids):
     """
 
     def longrope_frequency_update(self, position_ids, device):
+        # Patching the function after the model is created has no effect:
+        # rope_init_fn is an attribute bound to the original function when
+        # the model is created, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        # At export time, seq_len is unknown.
-        long_inv_freq, _ = self.rope_init_fn(
+        long_inv_freq, _ = rope_init_fn(
            self.config, device, seq_len=original_max_position_embeddings + 1
        )
        original_inv_freq = self.original_inv_freq.to(device)
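# Illustrative sketch (not part of the diff): why rope_init_fn is re-selected
# inside the update functions instead of patching the module attribute later.
# _ToyRope is a hypothetical stand-in for the rotary-embedding module;
# patched__compute_dynamic_ntk_parameters is assumed importable from the
# patched module shown above.
import transformers


class _ToyRope:
    def __init__(self):
        # bound once at construction time, before any patch could be applied
        self.rope_init_fn = (
            transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
        )


rope = _ToyRope()

# Re-assigning transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
# afterwards would not rebind rope.rope_init_fn, so the substitution happens at
# call time: swap in the patched function only when the stored attribute is
# still the stock implementation, and leave any custom rope_init_fn untouched.
rope_init_fn = (
    patched__compute_dynamic_ntk_parameters
    if rope.rope_init_fn
    is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
    else rope.rope_init_fn
)
assert rope_init_fn is patched__compute_dynamic_ntk_parameters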
@@ -706,6 +717,17 @@ def dynamic_frequency_update(self, position_ids, device):
        # - self.original_max_seq_len = config.max_position_embeddings
        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
 
+        # Patching the function after the model is created has no effect:
+        # rope_init_fn is an attribute bound to the original function when
+        # the model is created, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
        # This behaviour is difficult to translate.
        # The sequence always grows.
        # The test should always be True.
@@ -729,7 +751,7 @@ def dynamic_frequency_update(self, position_ids, device):
        # )
 
        seq_len = torch.max(position_ids) + 1
-        long_inv_freq, self.attention_scaling = self.rope_init_fn(
+        long_inv_freq, self.attention_scaling = rope_init_fn(
            self.config, device, seq_len=seq_len
        )
 