Skip to content

Commit 3f8eebc

Browse files
committed
fix batch & add tests
1. add tests
2. fix batch inference
3. modify the tokenizer
4. add some comments

Signed-off-by: guochenxu <[email protected]>
1 parent 0288f55 commit 3f8eebc

File tree

8 files changed

+394
-519
lines changed

8 files changed

+394
-519
lines changed

src/transformers/models/minicpm_o_2_6/feature_extractor_minicpm_o_2_6.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __call__(
3636
chunk_length: Optional[int] = 1,
3737
**kwargs,
3838
):
39+
# in batch inference, it may be [[]]
3940
if isinstance(audios, np.ndarray):
4041
audios_list = [[audios]]
4142
elif isinstance(audios[0], np.ndarray):

src/transformers/models/minicpm_o_2_6/image_processing_minicpm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,13 @@ def preprocess(
259259
return_tensors: Optional[Union[str, TensorType]] = None,
260260
**kwargs,
261261
) -> MiniCPMOBatchFeature:
262-
images_list = make_nested_list_of_images(images)
262+
# in batch inference, it may be [[]], so we can't use `make_nested_list_of_images`
263+
if isinstance(images, Image.Image):
264+
images_list = [[images]]
265+
elif isinstance(images[0], Image.Image):
266+
images_list = [images]
267+
else:
268+
images_list = images
263269

264270
to_tensor = transforms.ToTensor()
265271
normalize_transform = transforms.Normalize(
@@ -308,7 +314,9 @@ def preprocess(
308314
(slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
309315
)
310316

311-
tgt_sizes = np.vstack(tgt_sizes)
317+
# in batch inference, it may be []
318+
if tgt_sizes:
319+
tgt_sizes = np.vstack(tgt_sizes)
312320

313321
new_images_list.append(new_images)
314322
image_sizes_list.append(image_sizes)

src/transformers/models/minicpm_o_2_6/modeling_minicpm_o_2_6.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ def __init__(self, config):
551551
# feature_extractor = MiniCPM_o_2_6FeatureExtractor.from_pretrained(config._name_or_path)
552552
# self.processor = MiniCPM_o_2_6Processor(image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer)
553553

554-
self.terminators = ["<|im_end|>", "<|endoftext|>"]
554+
# self.terminators = ["<|im_end|>", "<|endoftext|>"]
555555

556556
self.force_no_stop = False
557557

@@ -1094,26 +1094,12 @@ def forward(
10941094
attentions=outputs.attentions,
10951095
)
10961096

1097-
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
1098-
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
1099-
outputs = super().generate(
1100-
inputs_embeds=inputs_embeds,
1101-
pad_token_id=0,
1102-
eos_token_id=terminators,
1103-
attention_mask=attention_mask,
1104-
output_hidden_states=True,
1105-
return_dict_in_generate=True,
1106-
**kwargs,
1107-
)
1108-
return outputs
1109-
11101097
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
1111-
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
11121098
streamer = TextIteratorStreamer(tokenizer=tokenizer)
11131099
generation_kwargs = {
11141100
"inputs_embeds": inputs_embeds,
11151101
"pad_token_id": 0,
1116-
"eos_token_id": terminators,
1102+
"eos_token_id": tokenizer.terminator_ids,
11171103
"streamer": streamer,
11181104
}
11191105
generation_kwargs.update(kwargs)
@@ -1199,11 +1185,10 @@ def generate(
11991185
if stream:
12001186
result = self._decode_stream(model_inputs["inputs_embeds"], processor.tokenizer, **generation_config)
12011187
else:
1202-
terminators = [processor.tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
12031188
outputs = super().generate(
12041189
inputs_embeds=model_inputs["inputs_embeds"],
12051190
pad_token_id=0,
1206-
eos_token_id=terminators,
1191+
eos_token_id=processor.tokenizer.terminator_ids,
12071192
attention_mask=attention_mask,
12081193
output_hidden_states=True,
12091194
return_dict_in_generate=True,
@@ -1213,7 +1198,7 @@ def generate(
12131198
if stream:
12141199
def stream_gen():
12151200
for text in result:
1216-
for term in self.terminators:
1201+
for term in processor.tokenizer.terminators:
12171202
text = text.replace(term, "")
12181203
yield text
12191204

@@ -1226,8 +1211,7 @@ def stream_gen():
12261211
spk_embeds = wav_numpy = sr = None
12271212

12281213
if not batched and use_tts_template and generate_audio:
1229-
# todo 这个地方怎么处理,必须得decode一次
1230-
result = processor.decode_text(outputs.sequences, processor.tokenizer, self.terminators)
1214+
result = processor.decode_text(outputs.sequences, processor.tokenizer)
12311215
mel_spec = self._generate_mel_spec(model_inputs, outputs, result[0], tts_config={'top_p': 0.7, 'top_k': 20, 'repetition_penalty': 1.0}, force_no_stop=force_no_stop)
12321216
wav_numpy, sr = self.decode_mel_to_audio(mel_spec, kwargs.get('output_audio_path', None))
12331217

@@ -1486,7 +1470,6 @@ def streaming_generate(
14861470
self.llm_generate_completed = False
14871471
self.audio_past_key_values = None # apm kv cache
14881472

1489-
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
14901473
generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
14911474
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
14921475

@@ -1500,7 +1483,7 @@ def streaming_generate(
15001483
attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)
15011484

15021485
generation_config["max_new_tokens"] = max_new_tokens
1503-
streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config)
1486+
streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, tokenizer.terminator_ids, generation_config)
15041487

15051488
if generate_audio:
15061489
result = self._generate_mel_spec_audio_streaming(
@@ -1552,7 +1535,7 @@ def check_uncompleted_token(ids):
15521535
end = check_uncompleted_token(cur_ids[0])
15531536
left_ids = cur_ids[:, end:]
15541537
cur_ids = cur_ids[:, :end]
1555-
text = self.processor.decode_text(cur_ids, tokenizer, self.terminators)[0] if end > 0 else ""
1538+
text = self.processor.decode_text(cur_ids, tokenizer)[0] if end > 0 else ""
15561539

15571540
self.llm_past_key_values = outputs.past_key_values
15581541
input_ids = outputs.sequences[:, -1:]

src/transformers/models/minicpm_o_2_6/processing_minicpm_o_2_6.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,11 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
158158
image_processor_class = "AutoImageProcessor"
159159
feature_extractor_class = "MiniCPM_o_2_6FeatureExtractor"
160160

161-
def __init__(self, tokenizer=None, image_processor=None, feature_extractor=None):
162-
super().__init__(tokenizer, image_processor, feature_extractor)
161+
def __init__(self, tokenizer=None, image_processor=None, feature_extractor=None, chat_template=None):
162+
super().__init__(tokenizer, image_processor,
163+
feature_extractor, chat_template=chat_template)
163164
self.version = image_processor.version
164165
self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
165-
self.image_tag = "(<image>./</image>)"
166-
self.image_pattern = "\(<image>./</image>\)"
167-
self.audio_tag = "(<audio>./</audio>)"
168-
self.audio_pattern = "\(<audio>./</audio>\)"
169-
self.terminators = ["<|im_end|>", "<|endoftext|>"]
170-
self.split_pattern = f"({self.image_pattern}|{self.audio_pattern})"
171166

172167
def __call__(
173168
self,
@@ -221,18 +216,19 @@ def apply_chat_template(
221216
msgs,
222217
chunk_input=True,
223218
max_slice_nums=None,
224-
max_inp_length=32768,
219+
max_length=32768,
225220
omni_input=False,
226221
use_image_id=None,
227222
use_tts_template=False,
223+
**kwargs,
228224
):
229225
"""
230226
Unified chat function
231227
232228
Args:
233229
msgs: the input chat msgs, support text: (string) / image: (PIL.Image) / audio (numpy.ndarray)
234230
chunk_input: whether to split audio into 1s chunks
235-
max_inp_length: the maximum length of input
231+
max_length: the maximum length of input
236232
max_slice_nums: control the maximum number of image slices
237233
omni_input: determine whether it is omni mode
238234
use_image_id: for video understanding or omni understanding, use_image_id should be False
@@ -295,11 +291,9 @@ def apply_chat_template(
295291
chat_template=self.default_tts_chat_template if use_tts_template else None,
296292
)
297293
)
298-
if images:
299-
input_images_list.append(images)
300-
if audios:
301-
input_audios_list.append(audios)
302-
audio_parts_list.append(audio_parts)
294+
input_images_list.append(images)
295+
input_audios_list.append(audios)
296+
audio_parts_list.append(audio_parts)
303297

304298
inputs = self.__call__(
305299
prompts_lists,
@@ -310,13 +304,13 @@ def apply_chat_template(
310304
use_image_id=use_image_id,
311305
chunk_input=chunk_input,
312306
return_tensors="pt",
313-
max_length=max_inp_length,
307+
max_length=max_length,
314308
)
315309
return inputs
316310

317311
def decode(self, outputs, batched=False):
318312
result = self.decode_text(
319-
outputs.sequences, self.tokenizer, self.terminators)
313+
outputs.sequences, self.tokenizer)
320314
if not batched:
321315
result = result[0]
322316
if isinstance(result, list):
@@ -325,15 +319,22 @@ def decode(self, outputs, batched=False):
325319
result = result.replace(self.tokenizer.tts_end, "")
326320
return result
327321

328-
def decode_text(self, result_ids, tokenizer, terminators):
329-
terminators = [tokenizer.convert_tokens_to_ids(i) for i in terminators]
322+
def decode_text(self, result_ids, tokenizer):
330323
result_text = []
331324
for result in result_ids:
332325
result = result[result != 0]
333-
if result[0] == tokenizer.bos_id:
334-
result = result[1:]
335-
if result[-1] in terminators:
336-
result = result[:-1]
326+
start, end = 0, len(result)
327+
for i, tok in enumerate(result):
328+
if tok == tokenizer.bos_id:
329+
start = i+1
330+
else:
331+
break
332+
for i in range(len(result)-1, -1, -1):
333+
if result[i] in tokenizer.terminator_ids:
334+
end = i
335+
else:
336+
break
337+
result = result[start:end]
337338
result_text.append(tokenizer.decode(result))
338339
return result_text
339340

@@ -509,10 +510,10 @@ def _convert_omni_to_inputs(
509510
spk_bounds_list = []
510511

511512
for index, text in enumerate(texts):
512-
text_chunks = re.split(self.split_pattern, text)
513+
text_chunks = re.split(self.tokenizer.split_pattern, text)
513514

514-
image_tags = re.findall(self.image_pattern, text)
515-
audio_tags = re.findall(self.audio_pattern, text)
515+
image_tags = re.findall(self.tokenizer.image_pattern, text)
516+
audio_tags = re.findall(self.tokenizer.audio_pattern, text)
516517

517518
if image_tags:
518519
assert images is not None
@@ -524,13 +525,13 @@ def _convert_omni_to_inputs(
524525
image_id = 0
525526
audio_id = 0
526527
for i, chunk in enumerate(text_chunks):
527-
if chunk == self.image_tag:
528+
if chunk == self.tokenizer.image_tag:
528529
image_placeholder = self.image_processor.get_slice_image_placeholder(
529530
self.tokenizer, image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
530531
)
531532
image_id += 1
532533
text_chunks[i] = image_placeholder
533-
elif chunk == self.audio_tag:
534+
elif chunk == self.tokenizer.audio_tag:
534535
audio_placeholder = audio_phs[index][audio_id]
535536
audio_id += 1
536537
text_chunks[i] = audio_placeholder

src/transformers/models/minicpm_o_2_6/tokenization_minicpm_o_2_6_fast.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from transformers import Qwen2TokenizerFast
1717

18+
1819
class MiniCPM_o_2_6TokenizerFast(Qwen2TokenizerFast):
1920
def __init__(self, **kwargs):
2021
super().__init__(**kwargs)
@@ -31,6 +32,8 @@ def __init__(self, **kwargs):
3132
self.slice_end = "</slice>"
3233
self.im_id_start = "<image_id>"
3334
self.im_id_end = "</image_id>"
35+
self.image_tag = f"({self.im_start}./{self.im_end})"
36+
self.image_pattern = "\(<image>./</image>\)"
3437

3538
# audio
3639
self.audio_start = "<|audio_start|>"
@@ -40,6 +43,12 @@ def __init__(self, **kwargs):
4043
self.tts_start = "<|tts_bos|>"
4144
self.tts_end = "<|tts_eos|>"
4245
self.unk_token = "<unk>"
46+
self.audio_tag = "(<audio>./</audio>)"
47+
self.audio_pattern = "\(<audio>./</audio>\)"
48+
49+
self.split_pattern = f"({self.image_pattern}|{self.audio_pattern})"
50+
51+
self.terminator_tokens = ["<|im_end|>", "<|endoftext|>", self.tts_end]
4352

4453
@property
4554
def eos_id(self):
@@ -53,6 +62,10 @@ def bos_id(self):
5362
def unk_id(self):
5463
return self.unk_token_id
5564

65+
@property
66+
def terminators(self):
67+
return self.terminator_tokens
68+
5669
@property
5770
def im_start_id(self):
5871
return self.convert_tokens_to_ids(self.im_start)
@@ -101,6 +114,10 @@ def tts_start_id(self):
101114
def tts_end_id(self):
102115
return self.convert_tokens_to_ids(self.tts_end)
103116

117+
@property
118+
def terminator_ids(self):
119+
return [self.convert_tokens_to_ids(t) for t in self.terminator_tokens]
120+
104121
@staticmethod
105122
def escape(text: str) -> str:
106123
return text

tests/models/minicpm_o_2_6/test.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments (0)