Commit 7c4e7f8

Fix vision model prompt truncation bug in DPOTrainer (huggingface#5023)
1 parent a68c82a · commit 7c4e7f8

2 files changed: +30, −5 lines

tests/test_dpo_trainer.py

Lines changed: 12 additions & 2 deletions
@@ -1283,8 +1283,18 @@ class TestDPOVisionTrainer(TrlTestCase):
         "model_id",
         [
             # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now
-            "trl-internal-testing/tiny-LlavaForConditionalGeneration",
-            "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
+            pytest.param(
+                "trl-internal-testing/tiny-LlavaForConditionalGeneration",
+                marks=pytest.mark.filterwarnings(
+                    "ignore:max_prompt_length is not supported for vision models:UserWarning"
+                ),  # See #5023
+            ),
+            pytest.param(
+                "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
+                marks=pytest.mark.filterwarnings(
+                    "ignore:max_prompt_length is not supported for vision models:UserWarning"
+                ),  # See #5023
+            ),
             "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
         ],
     )
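
The filterwarnings marks above keep the test suite quiet about the UserWarning this commit introduces. A user who wants the same behavior in their own training script could suppress it with the standard library's warnings filter. A minimal sketch (not part of this commit; the "message" argument is a regex matched against the start of the warning text):

import warnings

# Suppress the UserWarning that DPOTrainer now emits when max_prompt_length
# is set for a vision model (see huggingface#5023).
warnings.filterwarnings(
    "ignore",
    message="max_prompt_length is not supported for vision models",
    category=UserWarning,
)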

trl/trainer/dpo_trainer.py

Lines changed: 18 additions & 3 deletions
@@ -15,6 +15,7 @@
 import inspect
 import random
 import textwrap
+import warnings
 from collections import defaultdict
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
@@ -775,7 +776,19 @@ def process_row(
     ) -> dict[str, list[int]]:
         """
         Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
+
+        Note: Unlike `tokenize_row`, this method does not truncate prompts even if `max_prompt_length` is set. For
+        vision models, prompts contain image tokens that must exactly match the image features (pixel_values).
+        Truncating these tokens would cause a mismatch, leading to errors during the forward pass, like "Image
+        features and image tokens do not match". Users should filter their datasets to ensure prompts are an
+        appropriate length before training.
         """
+        if max_prompt_length is not None:
+            warnings.warn(
+                "max_prompt_length is not supported for vision models and will be ignored. "
+                "Truncating prompts would cause image token/feature mismatch errors.",
+                stacklevel=2,
+            )
         processor, tokenizer = processing_class, processing_class.tokenizer  # the processing class is a processor
         processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)
 
@@ -794,9 +807,11 @@ def process_row(
         chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
         rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
 
-        # Truncate prompt and completion sequences
-        if max_prompt_length is not None:
-            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
+        # Truncate completion sequences only.
+        # Note: We do not truncate prompt_input_ids for vision models because the prompts contain image tokens
+        # that must exactly match the image features (pixel_values). Truncating would cause errors like
+        # "Image features and image tokens do not match: tokens: X, features: Y". Users should filter overlong
+        # prompts from their dataset before training (the recommended approach for the deprecated max_prompt_length).
         if max_completion_length is not None:
             chosen_input_ids = chosen_input_ids[:max_completion_length]
             rejected_input_ids = rejected_input_ids[:max_completion_length]
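
Since max_prompt_length is now ignored for vision models, the docstring recommends filtering overlong prompts from the dataset before training. A minimal sketch of such a filter, assuming a datasets.Dataset with "images" and "prompt" columns in the format process_row consumes; the dataset name, processor checkpoint, and 512-token budget are illustrative, not values from this commit:

from datasets import load_dataset
from transformers import AutoProcessor

# Illustrative checkpoint and token budget; substitute your own.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
MAX_PROMPT_TOKENS = 512

# Hypothetical dataset with "images" and "prompt" columns, as process_row expects.
train_dataset = load_dataset("my_org/my_vision_dpo_dataset", split="train")

def prompt_fits(example):
    # Tokenize the prompt the same way process_row does (images + text, no
    # special tokens), so image placeholder tokens count toward the budget.
    processed = processor(images=example["images"], text=example["prompt"], add_special_tokens=False)
    return len(processed["input_ids"][0]) <= MAX_PROMPT_TOKENS

# Drop examples whose prompt exceeds the budget before handing off to DPOTrainer.
train_dataset = train_dataset.filter(prompt_fits)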
