Commit ad7b1df

auto processor

Signed-off-by: guochenxu <[email protected]>
1 parent 9a433cd

4 files changed: +30 -31 lines
src/transformers/models/minicpm_o_2_6/feature_extractor_minicpm_o_2_6.py

Lines changed: 7 additions & 6 deletions
@@ -28,6 +28,7 @@ def __init__(self, *args, **kwargs):

     def __call__(
         self,
+        tokenizer: None,
         audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
         audio_parts: Optional[list] = None,
         chunk_input: Optional[bool] = False,
@@ -55,7 +56,7 @@ def __call__(
         # audio placeholder not dependent on audio_parts
         for audios in audios_list:
             if audios:
-                audio_ph_list.append([self.get_audio_placeholder(
+                audio_ph_list.append([self.get_audio_placeholder(tokenizer,
                     len(a), chunk_input, chunk_length) for a in audios])
             else:
                 audio_ph_list.append([])
@@ -122,7 +123,7 @@ def __call__(

         return audio_features, audio_feature_lens_list, audio_ph_list

-    def get_audio_placeholder(self, audio_lens, chunk_input, chunk_length):
+    def get_audio_placeholder(self, tokenizer, audio_lens, chunk_input, chunk_length):
         pool_step = 2
         feature_lens = math.ceil(
             audio_lens / self.hop_length)
@@ -143,13 +144,13 @@ def get_audio_placeholder(self, audio_lens, chunk_input, chunk_length):
             for _ in range(num_audio_chunks):
                 unk_len = min(audio_embeds_in_chunk,
                               output_lens - total_unk_len)
-                place_holders += self.tokenizer.audio_start + \
-                    self.tokenizer.unk_token * unk_len + self.tokenizer.audio_end
+                place_holders += tokenizer.audio_start + \
+                    tokenizer.unk_token * unk_len + tokenizer.audio_end
                 total_unk_len += unk_len
             audio_placeholder = place_holders
         else:
-            audio_placeholder = self.tokenizer.audio_start + \
-                self.tokenizer.unk_token * output_lens + self.tokenizer.audio_end
+            audio_placeholder = tokenizer.audio_start + \
+                tokenizer.unk_token * output_lens + tokenizer.audio_end

         return audio_placeholder

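After this change the feature extractor no longer reads a tokenizer stored on itself; both `__call__` and `get_audio_placeholder` take one explicitly as the first argument. A minimal sketch of the new calling convention (the checkpoint id, dummy audio, and import path are assumptions for illustration, not part of the commit):

import numpy as np
from transformers import AutoTokenizer
from transformers.models.minicpm_o_2_6.feature_extractor_minicpm_o_2_6 import (
    MiniCPM_o_2_6FeatureExtractor,
)

ckpt = "openbmb/MiniCPM-o-2_6"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(ckpt)
feature_extractor = MiniCPM_o_2_6FeatureExtractor.from_pretrained(ckpt)

audio = np.zeros(16000, dtype=np.float32)  # one second of dummy audio

# The tokenizer is now an explicit first argument instead of an
# attribute wired onto the feature extractor ahead of time.
audio_features, audio_feature_lens, audio_phs = feature_extractor(
    tokenizer,
    [[audio]],  # one conversation containing one audio clip
    chunk_input=True,
)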
src/transformers/models/minicpm_o_2_6/image_processing_minicpm.py

Lines changed: 10 additions & 11 deletions
@@ -135,12 +135,12 @@ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=

         return source_image, patches, best_grid

-    def get_grid_placeholder(self, grid):
+    def get_grid_placeholder(self, tokenizer, grid):
         if grid is None:
             return ""
         slice_image_placeholder = (
-            self.tokenizer.slice_start + self.tokenizer.unk_token *
-            self.image_feature_size + self.tokenizer.slice_end
+            tokenizer.slice_start + tokenizer.unk_token *
+            self.image_feature_size + tokenizer.slice_end
         )

         cols = grid[0]
@@ -155,8 +155,8 @@ def get_grid_placeholder(self, grid):
         slice_placeholder = "\n".join(slices)
         return slice_placeholder

-    def get_image_id_placeholder(self, idx=0):
-        return f"{self.tokenizer.im_id_start}{idx}{self.tokenizer.im_id_end}"
+    # def get_image_id_placeholder(self, idx=0):
+    #     return f"{self.tokenizer.im_id_start}{idx}{self.tokenizer.im_id_end}"

     def get_sliced_images(self, image, max_slice_nums=None):
         slice_images = []
@@ -211,26 +211,25 @@ def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):

         return best_grid

-    def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
+    def get_slice_image_placeholder(self, tokenizer, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
         max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(
             max_slice_nums)
         assert max_slice_nums > 0
         grid = self.get_sliced_grid(
             image_size=image_size, max_slice_nums=max_slice_nums)

-        image_placeholder = self.tokenizer.im_start + self.tokenizer.unk_token * \
-            self.image_feature_size + self.tokenizer.im_end
+        image_placeholder = tokenizer.im_start + tokenizer.unk_token * \
+            self.image_feature_size + tokenizer.im_end
         use_image_id = self.use_image_id if use_image_id is None else bool(
             use_image_id)
         if use_image_id:
-            final_placeholder = self.get_image_id_placeholder(
-                image_idx) + image_placeholder
+            final_placeholder = f"{tokenizer.im_id_start}{image_idx}{tokenizer.im_id_end}" + image_placeholder
         else:
             final_placeholder = image_placeholder

         if self.slice_mode:
             final_placeholder = final_placeholder + \
-                self.get_grid_placeholder(grid=grid)
+                self.get_grid_placeholder(tokenizer, grid=grid)
         return final_placeholder

     def reshape_by_patch(self, image):
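As with the feature extractor, placeholder construction now receives the tokenizer per call instead of relying on a previously attached `image_processor.tokenizer`. A hedged sketch of the updated call (checkpoint id and image size are illustrative):

from transformers import AutoImageProcessor, AutoTokenizer

ckpt = "openbmb/MiniCPM-o-2_6"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(ckpt)
image_processor = AutoImageProcessor.from_pretrained(ckpt)

# No image_processor.tokenizer assignment is needed anymore;
# the tokenizer travels with the call.
placeholder = image_processor.get_slice_image_placeholder(
    tokenizer,
    (1024, 768),  # (width, height) of the input image
    image_idx=0,
)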

src/transformers/models/minicpm_o_2_6/modeling_minicpm_o_2_6.py

Lines changed: 7 additions & 9 deletions
@@ -542,14 +542,12 @@ def __init__(self, config):
         assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed."
         self.tts = self.init_tts_module()

-        # self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+        self.processor = AutoProcessor.from_pretrained(self.config._name_or_path)

-        tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
-        image_processor = AutoImageProcessor.from_pretrained(config._name_or_path)
-        image_processor.tokenizer = tokenizer
-        feature_extractor = MiniCPM_o_2_6FeatureExtractor.from_pretrained(config._name_or_path)
-        feature_extractor.tokenizer = tokenizer
-        self.processor = MiniCPM_o_2_6Processor(image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer)
+        # tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
+        # image_processor = AutoImageProcessor.from_pretrained(config._name_or_path)
+        # feature_extractor = MiniCPM_o_2_6FeatureExtractor.from_pretrained(config._name_or_path)
+        # self.processor = MiniCPM_o_2_6Processor(image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer)

         self.terminators = ["<|im_end|>", "<|endoftext|>"]

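The model's `__init__` now delegates assembly to `AutoProcessor`, equivalent to the single call sketched below (checkpoint id is an assumption). Dropping `trust_remote_code=True` is consistent with the model code now living inside `transformers` itself.

from transformers import AutoProcessor

# One lookup replaces the hand-built tokenizer / image processor /
# feature extractor wiring that is commented out above.
processor = AutoProcessor.from_pretrained("openbmb/MiniCPM-o-2_6")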
@@ -3182,8 +3180,8 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # ! in transformers=4.53.1, this is create_causal_mask, but it is wrong in our case
-        # so copy _update_causal_mask from LlamaModel which transformers=4.44.2
+        # ! in transformers>=4.53.1, this is `create_causal_mask`, but it will be wrong in our case
+        # so copy `_update_causal_mask` from LlamaModel which transformers=4.44.2
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )

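For context on the comment above: the model keeps its own copy of the older masking helper rather than calling the newer library-level function. A minimal sketch of that vendoring pattern (the class and elided body are illustrative, not this commit's code):

import torch.nn as nn

class MiniCPMOBackbone(nn.Module):  # hypothetical stand-in class
    def _update_causal_mask(self, attention_mask, input_tensor,
                            cache_position, past_key_values, output_attentions):
        # Body copied verbatim from LlamaModel in transformers==4.44.2,
        # rather than using create_causal_mask from transformers>=4.53.1.
        ...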
src/transformers/models/minicpm_o_2_6/processing_minicpm_o_2_6.py

Lines changed: 6 additions & 5 deletions
@@ -150,13 +150,13 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         The tokenizer is a required input.
     """

-    attributes = ["image_processor", "feature_extractor", "tokenizer"]
+    attributes = ["tokenizer", "image_processor", "feature_extractor"]
+    tokenizer_class = "AutoTokenizer"
     image_processor_class = "AutoImageProcessor"
     feature_extractor_class = "MiniCPM_o_2_6FeatureExtractor"
-    tokenizer_class = "AutoTokenizer"

-    def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None):
-        super().__init__(image_processor, feature_extractor, tokenizer)
+    def __init__(self, tokenizer=None, image_processor=None, feature_extractor=None):
+        super().__init__(tokenizer, image_processor, feature_extractor)
         self.version = image_processor.version
         self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
         self.image_tag = "(<image>./</image>)"
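`ProcessorMixin` maps positional arguments onto `attributes` in order, so the reordered `attributes` list and the reordered `__init__` signature must stay in sync, as they do here. A sketch of constructing the processor under the new order (checkpoint id and import paths are assumptions):

from transformers import AutoImageProcessor, AutoTokenizer
from transformers.models.minicpm_o_2_6.feature_extractor_minicpm_o_2_6 import (
    MiniCPM_o_2_6FeatureExtractor,
)
from transformers.models.minicpm_o_2_6.processing_minicpm_o_2_6 import (
    MiniCPM_o_2_6Processor,
)

ckpt = "openbmb/MiniCPM-o-2_6"  # assumed checkpoint id
processor = MiniCPM_o_2_6Processor(
    tokenizer=AutoTokenizer.from_pretrained(ckpt),
    image_processor=AutoImageProcessor.from_pretrained(ckpt),
    feature_extractor=MiniCPM_o_2_6FeatureExtractor.from_pretrained(ckpt),
)

Keyword construction works regardless of order; it is positional construction and attribute resolution that the reordered `attributes` list governs.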
@@ -189,6 +189,7 @@ def __call__(

         if audios:
             audio_features, audio_feature_lens, audio_phs = self.feature_extractor(
+                self.tokenizer,
                 audios,
                 audio_parts=audio_kwargs["audio_parts"],
                 chunk_input=audio_kwargs["chunk_input"],
@@ -437,7 +438,7 @@ def _convert_omni_to_inputs(
         for i, chunk in enumerate(text_chunks):
             if chunk == self.image_tag:
                 image_placeholder = self.image_processor.get_slice_image_placeholder(
-                    image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
+                    self.tokenizer, image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
                 )
                 image_id += 1
                 text_chunks[i] = image_placeholder
