Skip to content

Commit a8a008d

Browse files
authored
Fix for CB incosistency for qwen2_5_vl (#765)
For Qwen2_5_vl the `decode_inputs["position_ids"][decode_batch_id]` is of size (4,1) and the code was only updating the pos_ids of last index of last array. Therefore, changing it to update the last idx of all arrays of all the batches. --------- Signed-off-by: asmigosw <asmigosw@qti.qualcomm.com>
1 parent 4bd2239 commit a8a008d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
956956
else:
957957
# If the generated sequence is valid and within generation len prepare for next decode
958958
decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1]
959-
decode_inputs["position_ids"][decode_batch_id, -1] += 1
959+
decode_inputs["position_ids"][decode_batch_id][..., -1] += 1
960960
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
961961
next_token_id[decode_batch_id, -1]
962962
)

0 commit comments

Comments
 (0)