38 | 38 | from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_
39 | 39 | from torch.nn.modules.activation import *
40 | 40 | from torch.nn.utils.parametrizations import weight_norm
41 |    | -from transformers.utils import logging
   | 41 | +
42 | 42 | from huggingface_hub import hf_hub_download
43 | 43 |
44 | 44 | from transformers import AutoProcessor, BertTokenizerFast, LlamaConfig, LlamaModel, PreTrainedModel, Qwen2ForCausalLM, Qwen2PreTrainedModel, TextIteratorStreamer
45 | 45 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
46 |    | -from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, logging, replace_return_docstrings
   | 46 | +from transformers.utils import ModelOutput, logging, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, replace_return_docstrings
47 | 47 | from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
48 | 48 | from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
49 | 49 | from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
50 | 50 | from transformers.models.siglip.modeling_siglip import SIGLIP_START_DOCSTRING, SIGLIP_VISION_INPUTS_DOCSTRING, SiglipEncoderLayer, SiglipPreTrainedModel
51 | 51 | from transformers.models.idefics2.modeling_idefics2 import Idefics2Encoder
52 |    | -from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder
53 |    | -from transformers.activations import ACT2FN
   | 52 | +from transformers.models.whisper.modeling_whisper import WhisperConfig, WhisperEncoder, WhisperEncoderLayer
54 | 53 | from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
55 | 54 | from transformers.integrations import is_deepspeed_zero3_enabled
56 | 55 |
63 | 62 | except:
64 | 63 |     _tts_deps = False
65 | 64 |
66 |    | -from .configuration_minicpm_o_2_6 import ConditionalChatTTSConfig, MiniCPM_o_2_6Config
   | 65 | +from .configuration_minicpm_o_2_6 import MiniCPMConditionalTTSConfig, MiniCPM_o_2_6Config
67 | 66 | from .processing_minicpm_o_2_6 import NumberToTextConverter, sentence_end, VoiceChecker, MiniCPM_o_2_6Processor
68 | 67 |
69 | 68 | logger = logging.get_logger(__name__)
70 | 69 |
71 | 70 | @dataclass
72 | 71 | class OmniOutput(ModelOutput):
   | 72 | +    """
   | 73 | +    Output class for the unified multimodal model (OmniOutput).
   | 74 | +    This class encapsulates the model output, which may include generated text, speaker embeddings, an audio waveform, and its sampling rate.
   | 75 | +
   | 76 | +    Attributes:
   | 77 | +        text (Optional[Union[str, List[str], Iterator]]):
   | 78 | +            The generated text output. It can be a single string, a list of strings (for batched output), or an iterator (for streaming output).
   | 79 | +        spk_embeds (Optional[torch.FloatTensor]):
   | 80 | +            The speaker embedding tensor, typically used for voice cloning or speaker adaptation. Shape: (num_spk_emb, hidden_dim).
   | 81 | +        audio_wav (Optional[np.ndarray]):
   | 82 | +            The generated audio waveform as a numpy array, i.e. the raw audio data (e.g., after vocoder decoding).
   | 83 | +        sampling_rate (Optional[int]):
   | 84 | +            The sampling rate (Hz) of the generated audio waveform, e.g., 24000 or 16000.
   | 85 | +    """
73 | 86 |     text: Optional[Union[str, List[str], Iterator]] = None
74 | 87 |     spk_embeds: Optional[torch.FloatTensor] = None
75 | 88 |     audio_wav: Optional[np.ndarray] = None
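
The new docstring above documents the `OmniOutput` fields. As a rough illustration of how a caller might consume such an output, here is a minimal sketch that writes the generated waveform to disk; the `save_omni_output` helper and the `soundfile` dependency are assumptions for illustration, only the field names (`text`, `audio_wav`, `sampling_rate`) come from the diff.

```python
# Hedged usage sketch: consuming an OmniOutput-like result.
# Only the field names (text, audio_wav, sampling_rate) come from the diff;
# the helper itself and the soundfile dependency are assumptions.
import numpy as np
import soundfile as sf  # assumed available for writing WAV files


def save_omni_output(output, wav_path: str = "out.wav") -> None:
    # text may be a str, a list of str (batched), or an iterator (streaming)
    if output.text is not None:
        text = output.text if isinstance(output.text, str) else list(output.text)
        print("generated text:", text)

    # the raw waveform plus its sampling rate is enough to write a WAV file
    if output.audio_wav is not None and output.sampling_rate is not None:
        wav = np.asarray(output.audio_wav, dtype=np.float32)
        sf.write(wav_path, wav, output.sampling_rate)
```
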
@@ -1867,25 +1880,7 @@ def decode_mel_to_audio(self, mel_spec, output_path=""):
1867 | 1880 |
1868 | 1881 |
1869 | 1882 | # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference
1870 |      | -class MiniCPMWhisperEncoderLayer(nn.Module):
1871 |      | -    def __init__(self, config: WhisperConfig, layer_idx: int = None):
1872 |      | -        super().__init__()
1873 |      | -        self.embed_dim = config.d_model
1874 |      | -        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
1875 |      | -            embed_dim=self.embed_dim,
1876 |      | -            num_heads=config.encoder_attention_heads,
1877 |      | -            dropout=config.attention_dropout,
1878 |      | -            config=config,
1879 |      | -            layer_idx=layer_idx,
1880 |      | -        )
1881 |      | -        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1882 |      | -        self.dropout = config.dropout
1883 |      | -        self.activation_fn = ACT2FN[config.activation_function]
1884 |      | -        self.activation_dropout = config.activation_dropout
1885 |      | -        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
1886 |      | -        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
1887 |      | -        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
1888 |      | -
     | 1883 | +class MiniCPMWhisperEncoderLayer(WhisperEncoderLayer):
1889 | 1884 |     def forward(
1890 | 1885 |         self,
1891 | 1886 |         hidden_states: torch.Tensor,
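
The refactor above drops the hand-copied `__init__` and inherits it from `WhisperEncoderLayer`, so only `forward` is overridden to thread a cache through for streaming inference. Below is a minimal, generic sketch of that pattern; it uses plain PyTorch stand-in classes rather than the real Whisper layer, and the cache handling is an illustrative assumption, not the actual `MiniCPMWhisperEncoderLayer.forward` body.

```python
# Generic sketch of the "inherit the layer, override only forward" pattern.
# Stand-in classes only; not the real WhisperEncoderLayer / MiniCPMWhisperEncoderLayer.
import torch
import torch.nn as nn


class BaseEncoderLayer(nn.Module):
    def __init__(self, d_model: int = 64, num_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        return self.norm(hidden_states + attn_out)


class StreamingEncoderLayer(BaseEncoderLayer):
    # __init__ is reused unchanged from the parent; only forward gains a cache.
    def forward(self, hidden_states, past_hidden=None, use_cache=False):
        if use_cache and past_hidden is not None:
            # Prepend cached frames so attention can see earlier audio chunks.
            hidden_states = torch.cat([past_hidden, hidden_states], dim=1)
        out = super().forward(hidden_states)
        return out, (hidden_states if use_cache else None)


# Example: process two chunks while carrying the cache forward.
layer = StreamingEncoderLayer()
out1, cache = layer(torch.randn(1, 10, 64), use_cache=True)
out2, _ = layer(torch.randn(1, 10, 64), past_hidden=cache, use_cache=True)
```

The design point mirrored from the diff is that subclassing avoids duplicating the parent layer's module definitions, so upstream fixes to the base layer are picked up automatically.
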
@@ -2660,10 +2655,10 @@ class ConditionalChatTTS(PreTrainedModel):
2660 | 2655 |     5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the guidelines discussed above.
2661 | 2656 |     """
2662 | 2657 |
2663 |      | -    config_class = ConditionalChatTTSConfig
     | 2658 | +    config_class = MiniCPMConditionalTTSConfig
2664 | 2659 |     _no_split_modules = []
2665 | 2660 |
2666 |      | -    def __init__(self, config: ConditionalChatTTSConfig):
     | 2661 | +    def __init__(self, config: MiniCPMConditionalTTSConfig):
2667 | 2662 |         super().__init__(config)
2668 | 2663 |
2669 | 2664 |         self.use_speaker_embedding = config.use_speaker_embedding
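
For context on the rename: `config_class` is what `PreTrainedModel.from_pretrained` uses to build the module's config when none is passed, so downstream code should now construct or reference `MiniCPMConditionalTTSConfig`. A tiny hedged check follows; the relative import assumes the package layout implied by the imports at the top of this file, and the config's default field values are assumptions.

```python
# Hedged sanity check for the rename; runs in the context of this module,
# where ConditionalChatTTS is defined. Default config values are assumptions.
from .configuration_minicpm_o_2_6 import MiniCPMConditionalTTSConfig

assert ConditionalChatTTS.config_class is MiniCPMConditionalTTSConfig
cfg = MiniCPMConditionalTTSConfig()   # field defaults assumed
print(cfg.use_speaker_embedding)      # field read in __init__ per the diff
```
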