Commit 2c110a8

[Tests] Increase max seq length for tracing tests (#1478)
## Purpose ##

* Support `transformers>=4.52`

## Background ##

* After `transformers>=4.50`, many multimodal processors check that image tokens have not been truncated during tokenization. Previously we were tracing with samples whose image tokens had been truncated, yielding technically invalid samples. This doesn't matter for tracing, but it is a reasonable check for transformers to perform.

```
ValueError: Mismatch in `image` token count between text and `input_ids`. Got ids=[354] and text=[2197]. Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`.
```

## Changes ##

* Increase the `max_seq_length` for tracing tests

Signed-off-by: Kyle Sayers <[email protected]>
1 parent aca81d6 commit 2c110a8
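
For context, the newer processors' check boils down to comparing the number of image placeholder tokens in the (expanded) text against the number that survive truncation of `input_ids`. The sketch below is a plain-Python illustration of that failure mode, not the actual `transformers` code; the placeholder string, token id, and counts are made up to mirror the error above:

```python
# Illustrative sketch of the consistency check described in the error above.
# IMAGE_TOKEN, IMAGE_TOKEN_ID, and the token counts are hypothetical.
IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_ID = 32000


def check_image_tokens(text: str, input_ids: list[int], max_seq_length: int) -> None:
    truncated = input_ids[:max_seq_length]           # truncation="max_length"
    ids_count = truncated.count(IMAGE_TOKEN_ID)      # placeholders that survived truncation
    text_count = text.count(IMAGE_TOKEN)             # placeholders the expanded prompt contains
    if ids_count != text_count:
        raise ValueError(
            f"Mismatch in `image` token count between text and `input_ids`. "
            f"Got ids=[{ids_count}] and text=[{text_count}]."
        )


num_patches = 2197                                    # one placeholder per image patch
prompt = IMAGE_TOKEN * num_patches + " Describe this photo."
ids = [IMAGE_TOKEN_ID] * num_patches + [101, 102, 103]

check_image_tokens(prompt, ids, max_seq_length=4096)  # fine: every image token fits
try:
    check_image_tokens(prompt, ids, max_seq_length=354)
except ValueError as err:
    print(err)  # Got ids=[354] and text=[2197] -- image tokens were truncated away
```

Raising `max_seq_length` to 4096 keeps the single-sample calibration prompts below the truncation limit, so no placeholder tokens are dropped and the check passes.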

File tree

src/llmcompressor/transformers/tracing/debug.py
tests/llmcompressor/observers/test_min_max.py

2 files changed: +6 -3 lines

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -121,14 +121,17 @@ def get_dataset_kwargs(modality: str) -> Dict[str, str]:
         "text": {
             "dataset": "ultrachat-200k",
             "splits": {"calibration": "test_sft[:1]"},
+            "max_seq_length": 4096,
         },
         "vision": {
             "dataset": "flickr",
             "splits": {"calibration": "test[:1]"},
+            "max_seq_length": 4096,
         },
         "audio": {
             "dataset": "peoples_speech",
             "splits": {"calibration": "test[:1]"},
+            "max_seq_length": 4096,
         },
     }
 
```

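Read as plain Python, the patched mapping looks roughly like the sketch below; the keys and values come straight from the diff, while the surrounding function body and the final lookup are assumptions about `debug.py`:

```python
from typing import Dict


def get_dataset_kwargs(modality: str) -> Dict[str, str]:
    # Keys and values mirror the diff above; the lookup at the end is an assumption.
    dataset_kwargs = {
        "text": {
            "dataset": "ultrachat-200k",
            "splits": {"calibration": "test_sft[:1]"},
            "max_seq_length": 4096,  # large enough that no placeholder tokens get truncated
        },
        "vision": {
            "dataset": "flickr",
            "splits": {"calibration": "test[:1]"},
            "max_seq_length": 4096,
        },
        "audio": {
            "dataset": "peoples_speech",
            "splits": {"calibration": "test[:1]"},
            "max_seq_length": 4096,
        },
    }
    return dataset_kwargs[modality]
```

A limit of 4096 comfortably exceeds the roughly 2197 expanded image tokens reported in the error above, which is what silences the `ValueError` for these single-sample calibration splits.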
tests/llmcompressor/observers/test_min_max.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -36,9 +36,9 @@ def test_min_max_observer(symmetric, expected_scale, expected_zero_point):
     tensor = torch.tensor([1, 1, 1, 1, 1])
     num_bits = 8
 
-    weights = QuantizationArgs(num_bits=num_bits,
-                               symmetric=symmetric,
-                               observer="minmax")
+    weights = QuantizationArgs(
+        num_bits=num_bits, symmetric=symmetric, observer="minmax"
+    )
 
     observer = weights.observer
     observer = Observer.load_from_registry(observer, quantization_args=weights)
```

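For reference, the observer being constructed here derives a quantization scale and zero point from the tensor's min/max. The snippet below is a hand-rolled sketch of that textbook calculation (not the compressed-tensors `Observer` implementation; the clamping and rounding details are assumptions), run on the same all-ones tensor the test uses:

```python
import torch


def min_max_qparams(x: torch.Tensor, num_bits: int = 8, symmetric: bool = True):
    # Include zero in the observed range, as min-max observers typically do.
    x_min = torch.clamp(x.min(), max=0.0)
    x_max = torch.clamp(x.max(), min=0.0)
    if symmetric:
        q_max = 2 ** (num_bits - 1) - 1                  # e.g. 127 for int8
        scale = torch.max(x_min.abs(), x_max.abs()) / q_max
        zero_point = torch.zeros_like(scale, dtype=torch.int64)
    else:
        q_min, q_max = 0, 2 ** num_bits - 1              # e.g. [0, 255] for uint8
        scale = (x_max - x_min) / (q_max - q_min)
        zero_point = torch.round(q_min - x_min / scale).to(torch.int64)
    return scale, zero_point


scale, zp = min_max_qparams(torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]), symmetric=True)
print(scale, zp)  # tensor(0.0079) tensor(0)  -> 1 / 127 with a zero point of 0
```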