Skip to content

Commit 309c76a

Browse files
committed
fix position id
1 parent c0e46a7 commit 309c76a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

specforge/modeling/draft/llama3_eagle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def forward(
688688
cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
689689
cos, sin = cos.to(query_states.device), sin.to(query_states.device)
690690
query_states, key_states = apply_rotary_pos_emb(
691-
query_states, key_states, cos, sin, position_ids
691+
query_states, key_states, cos, sin, position_ids + lck
692692
)
693693

694694
key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -1228,7 +1228,7 @@ def forward(
12281228
cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
12291229
cos, sin = cos.to(query_states.device), sin.to(query_states.device)
12301230
query_states, key_states = apply_rotary_pos_emb(
1231-
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2
1231+
query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2
12321232
)
12331233

12341234
if cache_hidden is not None:

0 commit comments

Comments
 (0)