Commit 1a5101b

Merge pull request #4 from Guo-Chenxu/minicpm_o_2_6

resolve comments

2 parents d98d2ad + 2678d25 commit 1a5101b

File tree

6 files changed: +1672 −1292 lines

src/transformers/models/minicpm_o_2_6/configuration_minicpm_o_2_6.py

Lines changed: 83 additions & 15 deletions
@@ -16,12 +16,15 @@
 import os
 from typing import Union

-from ...configuration_utils import PretrainedConfig
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
 from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 from transformers import Qwen2Config, WhisperConfig
 from ...utils import logging

 logger = logging.get_logger(__name__)
+
+
 class MiniCPMVSliceConfig(PretrainedConfig):
     model_type = "minicpmv"

@@ -39,9 +42,8 @@ def __init__(

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs)

         if config_dict.get("model_type") == "minicpmv":
             config_dict = config_dict["slice_config"]
@@ -84,10 +86,6 @@ def __init__(
         attn_implementation: str = "sdpa",
         use_mlp: bool = True,
         aug_loss_weight: bool = True,
-        do_sample: bool = True,
-        top_p: float = 0.7,
-        top_k: int = 20,
-        repetition_penalty: float = 1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -116,13 +114,9 @@ def __init__(
         self.attn_implementation = attn_implementation
         self.use_mlp = use_mlp
         self.aug_loss_weight = aug_loss_weight
-        self.do_sample = do_sample
-        self.top_p = top_p
-        self.top_k = top_k
-        self.repetition_penalty = repetition_penalty


-class MiniCPM_o_2_6Config(Qwen2Config):
+class MiniCPM_o_2_6Config(PretrainedConfig):
     model_type = "minicpmo"
     keys_to_ignore_at_inference = ["past_key_values"]

@@ -136,6 +130,21 @@ class MiniCPM_o_2_6Config(Qwen2Config):
         "patch_size": 14,
     }

+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
     def __init__(
         self,
         use_cache=True,
@@ -155,6 +164,24 @@ def __init__(
         init_vision=True,
         init_audio=True,
         init_tts=True,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        layer_types=None,
+        attention_dropout=0.0,
         **kwargs,
     ):
         self.use_cache = use_cache
@@ -179,7 +206,8 @@ def __init__(

         # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
         if vision_config is None:
-            self.vision_config = SiglipVisionConfig(**self.default_vision_config)
+            self.vision_config = SiglipVisionConfig(
+                **self.default_vision_config)
             logger.info("vision_config is None, using default vision config")
         elif isinstance(vision_config, dict):
             self.vision_config = SiglipVisionConfig(**vision_config)
@@ -203,7 +231,47 @@ def __init__(

         self.patch_size = self.vision_config.patch_size

-        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types)
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )


 __all__ = ["MiniCPM_o_2_6Config"]
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright 2025 The OpenBMB Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Union
+
+from transformers import WhisperFeatureExtractor, AutoFeatureExtractor, AutoTokenizer
+import numpy as np
+import torch
+
+
+class MiniCPM_o_2_6FeatureExtractor(WhisperFeatureExtractor):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def __call__(
+        self,
+        tokenizer: None,
+        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
+        audio_parts: Optional[list] = None,
+        chunk_input: Optional[bool] = False,
+        sampling_rate: Optional[int] = None,
+        chunk_length: Optional[int] = 1,
+        **kwargs,
+    ):
+        if isinstance(audios, np.ndarray):
+            audios_list = [[audios]]
+        elif isinstance(audios[0], np.ndarray):
+            audios_list = [audios]
+        else:
+            audios_list = audios
+
+        if audio_parts is not None:
+            assert len(audio_parts) == len(audios_list)
+            for parts, audios in zip(audio_parts, audios_list):
+                assert len(parts) == len(audios)
+
+        audio_feature_lens_list = []
+        audio_ph_list = []
+
+        audio_features_all = []
+
+        # audio placeholder not dependent on audio_parts
+        for audios in audios_list:
+            if audios:
+                audio_ph_list.append([self.get_audio_placeholder(tokenizer,
+                                      len(a), chunk_input, chunk_length) for a in audios])
+            else:
+                audio_ph_list.append([])
+
+        for idx, audios in enumerate(audios_list):
+            if audio_parts is not None:
+                # same audio part merge
+                audio_part = audio_parts[idx]
+                merge_audio = []
+                cur_audio = []
+                for aid, (part, audio) in enumerate(zip(audio_part, audios)):
+                    if aid == 0 or audio_part[aid] == audio_part[aid - 1]:
+                        cur_audio.append(audio)
+                    else:
+                        merge_audio.append(np.hstack(cur_audio))
+                        cur_audio = [audio]
+                if cur_audio:
+                    merge_audio.append(np.hstack(cur_audio))
+
+            else:
+                merge_audio = audios
+
+            audio_feature_lens = []
+
+            # If the audio exceeds 30 seconds, split it into chunks every 30 seconds.
+            final_merge_audio = []
+            max_audio_inp_len = 30 * sampling_rate
+            for audio in merge_audio:
+                if len(audio) <= max_audio_inp_len:
+                    final_merge_audio.append(audio)
+                else:
+                    for i in range(math.ceil(len(audio) / max_audio_inp_len)):
+                        final_merge_audio.append(
+                            audio[i * max_audio_inp_len: (i + 1) * max_audio_inp_len])
+
+            if audios:
+                audio_inputs = super().__call__(
+                    final_merge_audio,
+                    sampling_rate=sampling_rate,
+                    return_attention_mask=True,
+                    padding="max_length",
+                    return_tensors="pt",
+                    **kwargs,
+                )
+                audio_feature = audio_inputs["input_features"]
+                actual_lens = audio_inputs["attention_mask"].sum(dim=1)
+
+                for feat, lens in zip(audio_feature, actual_lens):
+                    audio_features_all.append(feat[:, :lens])
+                    audio_feature_lens.append(lens)
+
+                audio_feature_lens = torch.hstack(audio_feature_lens)
+                audio_feature_lens_list.append(audio_feature_lens)
+            else:
+                audio_feature_lens_list.append([])
+
+        if audio_features_all:
+            audio_features = [i.permute(1, 0) for i in audio_features_all]
+            audio_features = torch.nn.utils.rnn.pad_sequence(
+                audio_features, batch_first=True, padding_value=0.0
+            ).permute(0, 2, 1)
+        else:
+            audio_features = []
+
+        return audio_features, audio_feature_lens_list, audio_ph_list
+
+    def get_audio_placeholder(self, tokenizer, audio_lens, chunk_input, chunk_length):
+        pool_step = 2
+        feature_lens = math.ceil(
+            audio_lens / self.hop_length)
+
+        feature_lens = (feature_lens - 1) // 2 + 1
+        output_lens = (feature_lens - pool_step) // pool_step + 1
+
+        if chunk_input:
+            fbank_feat_in_chunk = int(chunk_length * 100)
+            cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
+            audio_embeds_in_chunk = (
+                cnn_feat_in_chunk - pool_step) // pool_step + 1
+            num_audio_chunks = (
+                output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk
+
+            place_holders = ""
+            total_unk_len = 0
+            for _ in range(num_audio_chunks):
+                unk_len = min(audio_embeds_in_chunk,
+                              output_lens - total_unk_len)
+                place_holders += tokenizer.audio_start + \
+                    tokenizer.unk_token * unk_len + tokenizer.audio_end
+                total_unk_len += unk_len
+            audio_placeholder = place_holders
+        else:
+            audio_placeholder = tokenizer.audio_start + \
+                tokenizer.unk_token * output_lens + tokenizer.audio_end
+
+        return audio_placeholder
+
+
+AutoFeatureExtractor.register(
+    "MiniCPM_o_2_6FeatureExtractor", MiniCPM_o_2_6FeatureExtractor)
+
+__all__ = ["MiniCPM_o_2_6FeatureExtractor"]
