
Commit ae12902

fix code step2
1 parent ad61082 commit ae12902


3 files changed: +28 -33 lines changed


src/transformers/models/minicpm_o_2_6/configuration_minicpm_o_2_6.py

Lines changed: 4 additions & 4 deletions
@@ -55,7 +55,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
         return cls.from_dict(config_dict, **kwargs)
 
 
-class ConditionalChatTTSConfig(PretrainedConfig):
+class MiniCPMConditionalTTSConfig(PretrainedConfig):
     model_type = "conditional_chattts"
 
     def __init__(
@@ -195,10 +195,10 @@ def __init__(
         self.audio_config = audio_config
 
         if tts_config is None:
-            self.tts_config = ConditionalChatTTSConfig()
+            self.tts_config = MiniCPMConditionalTTSConfig()
         elif isinstance(tts_config, dict):
-            self.tts_config = ConditionalChatTTSConfig(**tts_config)
-        elif isinstance(tts_config, ConditionalChatTTSConfig):
+            self.tts_config = MiniCPMConditionalTTSConfig(**tts_config)
+        elif isinstance(tts_config, MiniCPMConditionalTTSConfig):
             self.tts_config = tts_config
 
         self.patch_size = self.vision_config.patch_size
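Downstream code only sees the class rename: the parent config still accepts None, a plain dict, or an instance for tts_config, exactly as in the branch above. A minimal sketch of that behaviour, assuming the import path mirrors the file path in this commit, that the other sub-configs (vision, audio) fall back to defaults the same way, and with the use_speaker_embedding key chosen purely for illustration:

from transformers.models.minicpm_o_2_6.configuration_minicpm_o_2_6 import (
    MiniCPM_o_2_6Config,
    MiniCPMConditionalTTSConfig,
)

# Passing a dict: the parent __init__ wraps it in the renamed config class.
config = MiniCPM_o_2_6Config(tts_config={"use_speaker_embedding": True})
assert isinstance(config.tts_config, MiniCPMConditionalTTSConfig)

# Passing nothing: tts_config falls back to MiniCPMConditionalTTSConfig() defaults.
default_config = MiniCPM_o_2_6Config()
print(default_config.tts_config.model_type)  # "conditional_chattts"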

src/transformers/models/minicpm_o_2_6/modeling_minicpm_o_2_6.py

Lines changed: 3 additions & 3 deletions
@@ -63,7 +63,7 @@
 except:
     _tts_deps = False
 
-from .configuration_minicpm_o_2_6 import ConditionalChatTTSConfig, MiniCPM_o_2_6Config
+from .configuration_minicpm_o_2_6 import MiniCPMConditionalTTSConfig, MiniCPM_o_2_6Config
 from .processing_minicpm_o_2_6 import NumberToTextConverter, sentence_end, VoiceChecker, MiniCPM_o_2_6Processor
 
 logger = logging.get_logger(__name__)
@@ -2661,10 +2661,10 @@ class ConditionalChatTTS(PreTrainedModel):
     5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the following guidelines discussed above.
     """
 
-    config_class = ConditionalChatTTSConfig
+    config_class = MiniCPMConditionalTTSConfig
     _no_split_modules = []
 
-    def __init__(self, config: ConditionalChatTTSConfig):
+    def __init__(self, config: MiniCPMConditionalTTSConfig):
         super().__init__(config)
 
         self.use_speaker_embedding = config.use_speaker_embedding
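Only the config class is renamed here; the model class ConditionalChatTTS keeps its name, and its config_class now points at the new one. For code that imported the old name, the migration is a one-line import change; the module path below is assumed from the file locations in this commit:

# Before this commit:
# from transformers.models.minicpm_o_2_6.configuration_minicpm_o_2_6 import ConditionalChatTTSConfig

# After this commit:
from transformers.models.minicpm_o_2_6.configuration_minicpm_o_2_6 import MiniCPMConditionalTTSConfig

print(MiniCPMConditionalTTSConfig.model_type)  # still "conditional_chattts"; only the Python class name changed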

src/transformers/models/minicpm_o_2_6/modular_minicpm_o_2_6.py

Lines changed: 21 additions & 26 deletions
@@ -38,19 +38,18 @@
 from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_
 from torch.nn.modules.activation import *
 from torch.nn.utils.parametrizations import weight_norm
-from transformers.utils import logging
+
 from huggingface_hub import hf_hub_download
 
 from transformers import AutoProcessor, BertTokenizerFast, LlamaConfig, LlamaModel, PreTrainedModel, Qwen2ForCausalLM, Qwen2PreTrainedModel, TextIteratorStreamer
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
-from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, logging, replace_return_docstrings
+from transformers.utils import ModelOutput, logging, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, replace_return_docstrings
 from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
 from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SIGLIP_START_DOCSTRING, SIGLIP_VISION_INPUTS_DOCSTRING, SiglipEncoderLayer, SiglipPreTrainedModel
 from transformers.models.idefics2.modeling_idefics2 import Idefics2Encoder
-from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder
-from transformers.activations import ACT2FN
+from transformers.models.whisper.modeling_whisper import WhisperConfig, WhisperEncoder, WhisperEncoderLayer
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.integrations import is_deepspeed_zero3_enabled

@@ -63,13 +62,27 @@
 except:
     _tts_deps = False
 
-from .configuration_minicpm_o_2_6 import ConditionalChatTTSConfig, MiniCPM_o_2_6Config
+from .configuration_minicpm_o_2_6 import MiniCPMConditionalTTSConfig, MiniCPM_o_2_6Config
 from .processing_minicpm_o_2_6 import NumberToTextConverter, sentence_end, VoiceChecker, MiniCPM_o_2_6Processor
 
 logger = logging.get_logger(__name__)
 
 @dataclass
 class OmniOutput(ModelOutput):
+    """
+    Output class for the unified multimodal model (OmniOutput).
+    This class encapsulates the output of the model, which may include text, speaker embeddings, an audio waveform, and its sampling rate.
+
+    Attributes:
+        text (Optional[Union[str, List[str], Iterator]]):
+            The generated text output. It can be a single string, a list of strings (for batch output), or an iterator (for streaming output).
+        spk_embeds (Optional[torch.FloatTensor]):
+            The speaker embedding tensor, typically used for voice cloning or speaker adaptation. Shape: (num_spk_emb, hidden_dim).
+        audio_wav (Optional[np.ndarray]):
+            The generated audio waveform as a numpy array. This is the raw audio data (e.g., after vocoder decoding).
+        sampling_rate (Optional[int]):
+            The sampling rate (Hz) of the generated audio waveform. For example, 24000 or 16000.
+    """
     text: Optional[Union[str, List[str], Iterator]] = None
     spk_embeds: Optional[torch.FloatTensor] = None
     audio_wav: Optional[np.ndarray] = None
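The added docstring documents the four fields of OmniOutput. The snippet below sketches how a caller might populate and read one; it assumes OmniOutput (as defined above) is in scope, that the sampling_rate field exists as documented, and all shapes, values, and rates are chosen purely for illustration:

import numpy as np
import torch

out = OmniOutput(
    text="hello from MiniCPM-o",                  # str, List[str] for batched output, or an Iterator when streaming
    spk_embeds=torch.randn(1, 1024),              # (num_spk_emb, hidden_dim); dimensions are illustrative
    audio_wav=np.zeros(24000, dtype=np.float32),  # raw waveform after vocoder decoding
    sampling_rate=24000,                          # Hz
)
if out.audio_wav is not None:
    duration_s = len(out.audio_wav) / out.sampling_rate  # 1.0 second for this toy waveform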
@@ -1867,25 +1880,7 @@ def decode_mel_to_audio(self, mel_spec, output_path=""):
 
 
 # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference
-class MiniCPMWhisperEncoderLayer(nn.Module):
-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
-        super().__init__()
-        self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
-            dropout=config.attention_dropout,
-            config=config,
-            layer_idx=layer_idx,
-        )
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-
+class MiniCPMWhisperEncoderLayer(WhisperEncoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
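This refactor deletes the hand-copied __init__ and inherits it from WhisperEncoderLayer, so MiniCPMWhisperEncoderLayer only overrides forward() (extended with use_cache for streaming inference). A small sketch of what the inherited constructor now provides, using a tiny illustrative WhisperConfig rather than the model's real audio settings, and assuming a transformers version where WhisperEncoderLayer takes just a config:

from transformers.models.whisper.modeling_whisper import WhisperConfig, WhisperEncoderLayer

cfg = WhisperConfig(d_model=384, encoder_attention_heads=6, encoder_ffn_dim=1536)
layer = WhisperEncoderLayer(cfg)

# The attributes the deleted __init__ built by hand now come from the parent class:
print(layer.embed_dim)                            # 384, taken from d_model
print(layer.fc1.out_features)                     # 1536, taken from encoder_ffn_dim
print(type(layer.self_attn_layer_norm).__name__)  # LayerNorm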
@@ -2660,10 +2655,10 @@ class ConditionalChatTTS(PreTrainedModel):
     5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the following guidelines discussed above.
     """
 
-    config_class = ConditionalChatTTSConfig
+    config_class = MiniCPMConditionalTTSConfig
     _no_split_modules = []
 
-    def __init__(self, config: ConditionalChatTTSConfig):
+    def __init__(self, config: MiniCPMConditionalTTSConfig):
         super().__init__(config)
 
         self.use_speaker_embedding = config.use_speaker_embedding
