
Commit 49782fc

[Misc] Some minor simplifications to detokenization logic (#3670)
Some simplifications made for clarity. Also moves detokenization-related functions from tokenizer.py to detokenizer.py.
1 parent f03cc66 commit 49782fc
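
As the diffs below show, callers that previously imported the incremental-detokenization helpers from vllm.transformers_utils.tokenizer now import them from vllm.transformers_utils.detokenizer. A minimal sketch of the import change a downstream caller would make after this commit (the caller itself is assumed, not part of the diff):

# Before this commit (hypothetical downstream caller):
# from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
#                                                detokenize_incrementally)

# After this commit, the same names live in detokenizer.py:
from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                 convert_prompt_ids_to_tokens,
                                                 detokenize_incrementally)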

File tree

3 files changed: +159 -165 lines changed


tests/tokenization/test_detokenize.py

Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@
 from transformers import AutoTokenizer
 
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer import (Detokenizer,
+                                                 detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 
 TRUTH = [

vllm/transformers_utils/detokenizer.py

Lines changed: 156 additions & 8 deletions

@@ -1,10 +1,8 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
-                                               detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
     BaseTokenizerGroup)
 
@@ -148,10 +146,160 @@ def decode_sequence_inplace(self, seq: Sequence,
             )
             sample_logprob.decoded_token = new_text
 
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
+        seq.tokens.extend(new_tokens)
         seq.prefix_offset = prefix_offset
         seq.read_offset = read_offset
         seq.output_text += new_decoded_token_text
+
+
+def _convert_tokens_to_string_with_added_encoders(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    output_tokens: List[str],
+    skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
+) -> str:
+    # Adapted from
+    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    all_special_tokens = set(tokenizer.all_special_tokens)
+    for token in output_tokens:
+        if skip_special_tokens and token in all_special_tokens:
+            continue
+        if token in tokenizer.get_added_vocab():
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
+
+
+# 5 is an arbitrary value that should work for all
+# tokenizers (bigger = more conservative).
+INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+def convert_prompt_ids_to_tokens(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt_ids: List[int],
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], int, int]:
+    """Converts the prompt ids to tokens and returns the tokens and offsets
+    for incremental detokenization.
+
+    Note that not all tokens are converted to strings. Only the tokens that
+    are necessary for incremental detokenization are converted to strings.
+    """
+    # We do not need to convert the whole prompt to tokens.
+    # Offset a little more in case we have special tokens.
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
+        skip_special_tokens=skip_special_tokens)
+    read_offset = len(new_tokens)
+    prefix_offset = max(
+        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    return new_tokens, prefix_offset, read_offset
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int,
+    read_offset: int,
+    skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
+) -> Tuple[List[str], str, int, int]:
+    """Detokenizes the input ids incrementally and returns the new tokens
+    and the new text.
+
+    If `prev_tokens` is None, this function will convert the input ids to
+    tokens and return the tokens and the new text. Otherwise, it will return the
+    new tokens and the new text.
+
+    This function will also return the new prefix offset and the new read
+    offset to be used in the next iteration.
+
+    The offsets are necessary to defeat cleanup algorithms in the decode which
+    decide to add a space or not depending on the surrounding ids.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        all_input_ids: The input ids. The last id is the new token id.
+        prev_tokens: The previous tokens. If None, this function will convert
+            the input ids to tokens and return the tokens and the new text.
+        prefix_offset: The prefix offset.
+        read_offset: The read offset.
+        skip_special_tokens: Whether to skip special tokens.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens.
+    """
+    new_token_id = all_input_ids[-1]
+    # This is the first iteration for this sequence
+    is_first_iter = prev_tokens is None
+    if is_first_iter:
+        (prev_tokens, prefix_offset,
+         read_offset) = convert_prompt_ids_to_tokens(
+             tokenizer,
+             all_input_ids[:-1],
+             skip_special_tokens=skip_special_tokens)
+
+    # If the new token id is out of bounds, return an empty string.
+    if new_token_id >= len(tokenizer):
+        new_tokens = [""]
+    else:
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_tokens + new_tokens
+
+    # If this is the first iteration, return all tokens.
+    if is_first_iter:
+        new_tokens = output_tokens
+
+    # The prefix text is necessary only to defeat cleanup algorithms in
+    # the decode which decide to add a space or not depending on the
+    # surrounding ids.
+    if tokenizer.is_fast or not tokenizer.get_added_vocab():
+        prefix_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:read_offset])
+        new_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:])
+    else:
+        prefix_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+        new_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+
+    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
+        # utf-8 char at the end means it's a potential unfinished byte sequence
+        # from byte fallback tokenization.
+        # If it's in the middle, it's probably a real invalid id generated
+        # by the model
+        return new_tokens, "", prefix_offset, read_offset
+
+    new_text = new_text[len(prefix_text):]
+    return new_tokens, new_text, read_offset, len(output_tokens)
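
To make the offset bookkeeping described in the docstring above concrete, here is a hedged usage sketch of detokenize_incrementally driving an incremental decode loop outside of vLLM's Detokenizer class. The tokenizer name ("gpt2") and the stand-in stream of "generated" token ids are illustrative assumptions, not part of this commit.

# A minimal sketch, assuming an installed HF tokenizer ("gpt2" is an
# arbitrary choice) and a stand-in stream of "generated" token ids.
# prev_tokens and the two offsets start unset and are carried forward
# between calls, mirroring how Detokenizer.decode_sequence_inplace threads
# seq.tokens / seq.prefix_offset / seq.read_offset across steps.
from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt_ids = tokenizer("Hello, my name is").input_ids
generated_ids = tokenizer(" Alice and I like tea.").input_ids  # stand-in for model output

all_ids = list(prompt_ids)
tokens, prefix_offset, read_offset = None, 0, 0
text = ""
for new_id in generated_ids:
    all_ids.append(new_id)
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer,
        all_input_ids=all_ids,
        prev_tokens=tokens,
        prefix_offset=prefix_offset,
        read_offset=read_offset,
        skip_special_tokens=True,
    )
    # First call: new_tokens covers the prompt tail plus the new token
    # (prev_tokens was None); later calls return only the new token(s).
    tokens = new_tokens if tokens is None else tokens + new_tokens
    text += new_text

print(text)  # incrementally reassembled text of the "generated" ids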
vllm/transformers_utils/tokenizer.py

Lines changed: 1 addition & 155 deletions
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
@@ -126,157 +126,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
 
 
 get_lora_tokenizer_async = make_async(get_lora_tokenizer)
-
-
-def _convert_tokens_to_string_with_added_encoders(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    output_tokens: List[str],
-    skip_special_tokens: bool,
-    spaces_between_special_tokens: bool,
-) -> str:
-    # Adapted from
-    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
-    # NOTE(woosuk): The following code is slow because it runs a for loop over
-    # the output_tokens. In Python, running a for loop over a list can be slow
-    # even when the loop body is very simple.
-    sub_texts = []
-    current_sub_text = []
-    all_special_tokens = set(tokenizer.all_special_tokens)
-    for token in output_tokens:
-        if skip_special_tokens and token in all_special_tokens:
-            continue
-        if token in tokenizer.get_added_vocab():
-            if current_sub_text:
-                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-                sub_texts.append(sub_text)
-                current_sub_text = []
-            sub_texts.append(token)
-        else:
-            current_sub_text.append(token)
-    if current_sub_text:
-        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-        sub_texts.append(sub_text)
-    if spaces_between_special_tokens:
-        return " ".join(sub_texts)
-    else:
-        return "".join(sub_texts)
-
-
-# 5 is an arbitrary value that should work for all
-# tokenizers (bigger = more conservative).
-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-
-
-def convert_prompt_ids_to_tokens(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prompt_ids: List[int],
-    skip_special_tokens: bool = False,
-) -> Tuple[List[str], int, int]:
-    """Converts the prompt ids to tokens and returns the tokens and offsets
-    for incremental detokenization.
-
-    Note that not all tokens are converted to strings. Only the tokens that
-    are necessary for incremental detokenization are converted to strings.
-    """
-    # Offset a little more in case we have special tokens.
-    prefix_offset = max(
-        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
-    # We do not need to convert the whole prompt to tokens.
-    new_tokens = tokenizer.convert_ids_to_tokens(
-        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
-    prefix_offset = max(
-        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
-    read_offset = len(new_tokens)
-    return new_tokens, prefix_offset, read_offset
-
-
-# Based on
-# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
-# under Apache 2.0 license
-def detokenize_incrementally(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    all_input_ids: List[int],
-    prev_tokens: Optional[List[str]],
-    prefix_offset: int,
-    read_offset: int,
-    skip_special_tokens: bool = False,
-    spaces_between_special_tokens: bool = True,
-) -> Tuple[List[str], str, int, int]:
-    """Detokenizes the input ids incrementally and returns the new tokens
-    and the new text.
-
-    If `prev_tokens` is None, this function will convert the input ids to
-    tokens and return the tokens and the new text. Otherwise, it will return the
-    new tokens and the new text.
-
-    This function will also return the new prefix offset and the new read
-    offset to be used in the next iteration.
-
-    The offsets are necessary to defeat cleanup algorithms in the decode which
-    decide to add a space or not depending on the surrounding ids.
-
-    Args:
-        tokenizer: The tokenizer to use.
-        all_input_ids: The input ids. The last id is the new token id.
-        prev_tokens: The previous tokens. If None, this function will convert
-            the input ids to tokens and return the tokens and the new text.
-        prefix_offset: The prefix offset.
-        read_offset: The read offset.
-        skip_special_tokens: Whether to skip special tokens.
-        spaces_between_special_tokens: Whether to add spaces between special
-            tokens.
-    """
-    new_token_id = all_input_ids[-1]
-    # This is the first iteration for this sequence
-    is_first_iter = prev_tokens is None
-    if is_first_iter:
-        (prev_tokens, prefix_offset,
-         read_offset) = convert_prompt_ids_to_tokens(
-             tokenizer,
-             all_input_ids[:-1],
-             skip_special_tokens=skip_special_tokens)
-
-    # If the new token id is out of bounds, return an empty string.
-    if new_token_id >= len(tokenizer):
-        new_tokens = [""]
-    else:
-        # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            [new_token_id], skip_special_tokens=skip_special_tokens)
-    output_tokens = prev_tokens + new_tokens
-
-    # If this is the first iteration, return all tokens.
-    if is_first_iter:
-        new_tokens = output_tokens
-
-    # The prefix text is necessary only to defeat cleanup algorithms in
-    # the decode which decide to add a space or not depending on the
-    # surrounding ids.
-    if tokenizer.is_fast or not tokenizer.get_added_vocab():
-        prefix_text = tokenizer.convert_tokens_to_string(
-            output_tokens[prefix_offset:read_offset])
-        new_text = tokenizer.convert_tokens_to_string(
-            output_tokens[prefix_offset:])
-    else:
-        prefix_text = _convert_tokens_to_string_with_added_encoders(
-            tokenizer,
-            output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens,
-        )
-        new_text = _convert_tokens_to_string_with_added_encoders(
-            tokenizer,
-            output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens,
-        )
-
-    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
-        # utf-8 char at the end means it's a potential unfinished byte sequence
-        # from byte fallback tokenization.
-        # If it's in the middle, it's probably a real invalid id generated
-        # by the model
-        new_text = new_text[len(prefix_text):]
-        return new_tokens, new_text, read_offset, len(output_tokens)
-    else:
-        return new_tokens, "", prefix_offset, read_offset
