 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

-from vllm.inputs import token_inputs
-from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast):
 
     assert decoded_text == ''
     assert out_ids == [len(tokenizer)]
-
-
-@pytest.fixture
-def detokenizer(tokenizer_name: str) -> Detokenizer:
-    tokenizer = get_tokenizer(
-        tokenizer_name,
-        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
-        trust_remote_code=False,
-        revision=None,
-    )
-
-    return Detokenizer(tokenizer)
-
-
-@pytest.fixture(name="complete_sequence_token_ids")
-def create_complete_sequence_token_ids(complete_sequence: str,
-                                       tokenizer) -> list[int]:
-    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
-
-
-def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or []
-    return Sequence(
-        seq_id=0,
-        inputs=token_inputs(prompt_token_ids),
-        block_size=16,
-    )
-
-
-def create_dummy_logprobs(
-        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
-    return [{
-        token_id: Logprob(logprob=0.0),
-        token_id + 1: Logprob(logprob=0.1)
-    } for token_id in complete_sequence_token_ids]
-
-
-def create_dummy_prompt_logprobs(
-        complete_sequence_token_ids: list[int]
-) -> list[Optional[dict[int, Any]]]:
-    # logprob for the first prompt token is None.
-    logprobs: list[Optional[dict[int, Any]]] = [None]
-    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
-    return logprobs
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
-def test_decode_sequence_logprobs(complete_sequence: str,
-                                  complete_sequence_token_ids: list[int],
-                                  detokenizer: Detokenizer,
-                                  skip_special_tokens: bool):
-    """Verify Detokenizer decodes logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     logprobs=2)
-
-    # Run sequentially.
-    seq = create_sequence()
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    sequential_logprobs_text_chosen_token: list[str] = []
-    sequential_logprobs_text_other_token: list[str] = []
-    for new_token, logprobs in zip(complete_sequence_token_ids,
-                                   dummy_logprobs):
-        seq.append_token_id(new_token, logprobs)
-        detokenizer.decode_sequence_inplace(seq, sampling_params)
-        sequential_logprobs_text_chosen_token.append(
-            seq.output_logprobs[-1][new_token].decoded_token)
-        sequential_logprobs_text_other_token.append(
-            seq.output_logprobs[-1][new_token + 1].decoded_token)
-    sequential_result = seq.output_text
-
-    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
-    assert sequential_result != "".join(sequential_logprobs_text_other_token)
-
-    if not skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # generated text. Note that this will only be true if we skip
-        # special tokens.
-        assert sequential_result == complete_sequence
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: list[int],
-                                detokenizer: Detokenizer):
-
-    # We want to use skip_special_tokens=False here but Mistral tokenizers
-    # don't support that.
-    if complete_sequence not in SPECIAL_TOKS_TRUTH:
-        skip_special_tokens = True
-    elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
-        skip_special_tokens = False
-    else:
-        pytest.skip("MistralTokenizers don't support "
-                    "skip_special_tokens=False")
-        return
-    """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     prompt_logprobs=1)
-
-    # Run sequentially.
-    seq = create_sequence(complete_sequence_token_ids)
-    seq_group = SequenceGroup(request_id="1",
-                              seqs=[seq],
-                              sampling_params=sampling_params,
-                              arrival_time=0.0)
-    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group,
-                                               dummy_logprobs,
-                                               position_offset=0)
-    # First logprob is None.
-    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
-        1:]  # type: ignore
-
-    # decoded_prompt_logprobs doesn't contain the first token.
-    token_ids = complete_sequence_token_ids
-    tokenizer = detokenizer.tokenizer
-    text_full = tokenizer.decode(token_ids,
-                                 skip_special_tokens=skip_special_tokens)
-    text_first = tokenizer.decode(token_ids[0],
-                                  skip_special_tokens=skip_special_tokens)
-    text = text_full[len(text_first):]
-
-    # Text for logprobs for the chosen token should be the same as the
-    # prompt text. Note that the first logprob is None.
-    assert text == "".join([
-        logprobs[token_id].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
-    assert text != "".join([
-        logprobs[token_id + 1].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
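
The removed v0 tests asserted one core invariant of incremental detokenization: the per-step text attached to the chosen token's logprob, concatenated over all steps, must reproduce the final decoded output. As a rough standalone illustration of that invariant (not vLLM's implementation; the model name and the simple prefix-diff scheme below are assumptions for the sketch, and vLLM's FastIncrementalDetokenizer is considerably more involved), incremental detokenization can be approximated by decoding a growing token prefix and emitting only the newly produced suffix at each step:

from transformers import AutoTokenizer

# Model choice is an arbitrary assumption for the sketch.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

text = "Hello there, how are you today?"
token_ids = tokenizer(text, add_special_tokens=False).input_ids

# Naive incremental detokenization by prefix diffing: decode the growing
# prefix and keep only the delta relative to the previous decode.
pieces, prev = [], ""
for i in range(1, len(token_ids) + 1):
    full = tokenizer.decode(token_ids[:i], skip_special_tokens=True)
    pieces.append(full[len(prev):])
    prev = full

# The concatenated per-step pieces should equal the one-shot decode, which is
# the same property the deleted logprob tests checked for the chosen tokens.
assert "".join(pieces) == tokenizer.decode(token_ids, skip_special_tokens=True)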