We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c0e46a7 commit 309c76aCopy full SHA for 309c76a
specforge/modeling/draft/llama3_eagle.py
@@ -688,7 +688,7 @@ def forward(
688
cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
689
cos, sin = cos.to(query_states.device), sin.to(query_states.device)
690
query_states, key_states = apply_rotary_pos_emb(
691
- query_states, key_states, cos, sin, position_ids
+ query_states, key_states, cos, sin, position_ids + lck
692
)
693
694
key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -1228,7 +1228,7 @@ def forward(
1228
1229
1230
1231
- query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2
+ query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2
1232
1233
1234
if cache_hidden is not None:
0 commit comments