@@ -104,9 +104,8 @@ def __init__(self,
104
104
1 , self .total_num_key_value_heads // tp_size )
105
105
self .head_dim = self .hidden_size // self .total_num_heads
106
106
self .max_position_embeddings = config .max_position_embeddings
107
- rope_pct = getattr (config , "rope_pct" ,
108
- getattr (config , "partial_rotary_factor" , 1 ))
109
- self .rotary_ndims = int (self .head_dim * rope_pct )
107
+ self .partial_rotary_factor = getattr (
108
+ config , "rope_pct" , getattr (config , "partial_rotary_factor" , 1 ))
110
109
self .scaling = self .head_dim ** - 0.5
111
110
self .q_size = self .num_heads * self .head_dim
112
111
self .kv_size = self .num_key_value_heads * self .head_dim
@@ -130,9 +129,10 @@ def __init__(self,
130
129
prefix = f"{ prefix } .o_proj" )
131
130
self .rotary_emb = get_rope (
132
131
self .head_dim ,
133
- rotary_dim = self .rotary_ndims ,
132
+ rotary_dim = self .head_dim ,
134
133
max_position = self .config .max_position_embeddings ,
135
134
base = self .config .rope_theta ,
135
+ partial_rotary_factor = self .partial_rotary_factor ,
136
136
)
137
137
self .attn = Attention (self .num_heads ,
138
138
self .head_dim ,
0 commit comments