Skip to content

Commit cbfcaa1

Browse files
committed
add normaliser
1 parent f6726a7 commit cbfcaa1

File tree

7 files changed

+426
-52
lines changed

7 files changed

+426
-52
lines changed

src/f5_tts/infer/infer_gradio.py

Lines changed: 69 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# ruff: noqa: E402
2-
# Above allows ruff to ignore E402: module level import not at top of file
3-
42
import json
53
import re
64
import tempfile
5+
import os # szükséges a normaliser mappák beolvasásához
76
from collections import OrderedDict
87
from importlib.resources import files
98

@@ -17,7 +16,6 @@
1716

1817
try:
1918
import spaces
20-
2119
USING_SPACES = True
2220
except ImportError:
2321
USING_SPACES = False
@@ -40,7 +38,6 @@ def gpu_decorator(func):
4038
save_spectrogram,
4139
)
4240

43-
4441
DEFAULT_TTS_MODEL = "F5-TTS_v1"
4542
tts_model_choice = DEFAULT_TTS_MODEL
4643

@@ -51,8 +48,24 @@ def gpu_decorator(func):
5148
]
5249

5350

54-
# load models
51+
# Függvény a ./normalisers mappában lévő almappák beolvasásához,
52+
# csak azokat veszi figyelembe, amelyek tartalmazzák a normaliser.py fájlt.
53+
def get_normaliser_choices(normaliser_path="./normalisers"):
    """Collect the names of usable text-normaliser packages.

    Scans *normaliser_path* for immediate subdirectories and keeps only
    those containing a ``normaliser.py`` file, since that module is what
    gets dynamically imported later to normalise the text to synthesise.

    Args:
        normaliser_path: Directory to scan. Defaults to ``./normalisers``
            (relative to the current working directory), matching the
            previously hard-coded location, so existing callers are
            unaffected.

    Returns:
        list[str]: Names of qualifying subdirectories; empty when the
        directory is missing, is not a directory, or holds no valid
        normaliser.
    """
    choices = []
    # os.path.isdir() already implies existence, so one check suffices
    # (the original also called os.path.exists() redundantly).
    if os.path.isdir(normaliser_path):
        for item in os.listdir(normaliser_path):
            subdir = os.path.join(normaliser_path, item)
            # A subdirectory qualifies only if it ships a normaliser.py.
            if os.path.isdir(subdir) and os.path.exists(os.path.join(subdir, "normaliser.py")):
                choices.append(item)
    return choices
62+
63+
normaliser_choices = get_normaliser_choices()
5564

65+
# A TTS modell választásánál csak az alap modellek szerepelnek
66+
all_tts_choices = [DEFAULT_TTS_MODEL, "E2-TTS", "Custom"]
67+
68+
# load models
5669
vocoder = load_vocoder()
5770

5871

@@ -186,11 +199,23 @@ def infer(
186199
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
187200
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
188201
""")
202+
203+
189204
with gr.Blocks() as app_tts:
190205
gr.Markdown("# Batched TTS")
191206
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
192207
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
208+
209+
# Új: Normaliser választás komponens (None esetén nem történik normalizálás)
210+
with gr.Row():
211+
choose_normaliser = gr.Radio(
212+
choices=["None"] + get_normaliser_choices(),
213+
label="Choose Normaliser",
214+
value="None",
215+
)
216+
193217
generate_btn = gr.Button("Synthesize", variant="primary")
218+
194219
with gr.Accordion("Advanced Settings", open=False):
195220
ref_text_input = gr.Textbox(
196221
label="Reference Text",
@@ -239,12 +264,34 @@ def basic_tts(
239264
cross_fade_duration_slider,
240265
nfe_slider,
241266
speed_slider,
267+
normaliser_choice_input, # A normaliser választás értéke
242268
):
269+
# Ha a normaliser választás nem "None", a normaliser.py-ból a normalize függvényt hívjuk meg
270+
if normaliser_choice_input != "None":
271+
normaliser_file = os.path.join("normalisers", normaliser_choice_input, "normaliser.py")
272+
if os.path.exists(normaliser_file):
273+
import importlib.util
274+
spec = importlib.util.spec_from_file_location("normaliser", normaliser_file)
275+
normaliser_module = importlib.util.module_from_spec(spec)
276+
spec.loader.exec_module(normaliser_module)
277+
if hasattr(normaliser_module, "normalize"):
278+
processed_text = normaliser_module.normalize(gen_text_input)
279+
else:
280+
print("A normaliser.py nem tartalmazza a 'normalize' függvényt, ezért az eredeti szöveg kerül használatra.")
281+
processed_text = gen_text_input
282+
else:
283+
print("A megadott normaliser.py fájl nem található, ezért az eredeti szöveg kerül használatra.")
284+
processed_text = gen_text_input
285+
else:
286+
processed_text = gen_text_input
287+
288+
# A TTS modell választása a globális tts_model_choice változóból történik
289+
actual_model = tts_model_choice
243290
audio_out, spectrogram_path, ref_text_out = infer(
244291
ref_audio_input,
245292
ref_text_input,
246-
gen_text_input,
247-
tts_model_choice,
293+
processed_text,
294+
actual_model,
248295
remove_silence,
249296
cross_fade_duration=cross_fade_duration_slider,
250297
nfe_step=nfe_slider,
@@ -262,6 +309,7 @@ def basic_tts(
262309
cross_fade_duration_slider,
263310
nfe_slider,
264311
speed_slider,
312+
choose_normaliser # itt adjuk át a normaliser választás értékét
265313
],
266314
outputs=[audio_output, spectrogram_output, ref_text_input],
267315
)
@@ -477,7 +525,7 @@ def generate_multistyle_speech(
477525
# Generate speech for this segment
478526
audio_out, _, ref_text_out = infer(
479527
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
480-
) # show_info=print no pull to top when generating
528+
)
481529
sr, audio_data = audio_out
482530

483531
generated_audio_segments.append(audio_data)
@@ -525,10 +573,8 @@ def validate_speech_types(gen_text, regular_name, *args):
525573
missing_speech_types = speech_types_in_text - speech_types_available
526574

527575
if missing_speech_types:
528-
# Disable the generate button
529576
return gr.update(interactive=False)
530577
else:
531-
# Enable the generate button
532578
return gr.update(interactive=True)
533579

534580
gen_text_input_multistyle.change(
@@ -552,7 +598,6 @@ def validate_speech_types(gen_text, regular_name, *args):
552598

553599
if not USING_SPACES:
554600
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
555-
556601
chat_interface_container = gr.Column(visible=False)
557602

558603
@gpu_decorator
@@ -567,14 +612,11 @@ def load_chat_model():
567612
)
568613
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
569614
show_info("Chat model loaded.")
570-
571615
return gr.update(visible=False), gr.update(visible=True)
572616

573617
load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
574-
575618
else:
576619
chat_interface_container = gr.Column()
577-
578620
if chat_model_state is None:
579621
model_name = "Qwen/Qwen2.5-3B-Instruct"
580622
chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
@@ -600,9 +642,7 @@ def load_chat_model():
600642
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
601643
lines=2,
602644
)
603-
604645
chatbot_interface = gr.Chatbot(label="Conversation")
605-
606646
with gr.Row():
607647
with gr.Column():
608648
audio_input_chat = gr.Microphone(
@@ -617,7 +657,6 @@ def load_chat_model():
617657
)
618658
send_btn_chat = gr.Button("Send Message")
619659
clear_btn_chat = gr.Button("Clear Conversation")
620-
621660
conversation_state = gr.State(
622661
value=[
623662
{
@@ -627,40 +666,28 @@ def load_chat_model():
627666
]
628667
)
629668

630-
# Modify process_audio_input to use model and tokenizer from state
631669
@gpu_decorator
632670
def process_audio_input(audio_path, text, history, conv_state):
633-
"""Handle audio or text input from user"""
634-
635671
if not audio_path and not text.strip():
636672
return history, conv_state, ""
637-
638673
if audio_path:
639674
text = preprocess_ref_audio_text(audio_path, text)[1]
640-
641675
if not text.strip():
642676
return history, conv_state, ""
643-
644677
conv_state.append({"role": "user", "content": text})
645678
history.append((text, None))
646-
647679
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
648-
649680
conv_state.append({"role": "assistant", "content": response})
650681
history[-1] = (text, response)
651-
652682
return history, conv_state, ""
653683

654684
@gpu_decorator
655685
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
656-
"""Generate TTS audio for AI response"""
657686
if not history or not ref_audio:
658687
return None
659-
660688
last_user_message, last_ai_response = history[-1]
661689
if not last_ai_response:
662690
return None
663-
664691
audio_result, _, ref_text_out = infer(
665692
ref_audio,
666693
ref_text,
@@ -669,12 +696,11 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
669696
remove_silence,
670697
cross_fade_duration=0.15,
671698
speed=1.0,
672-
show_info=print, # show_info=print no pull to top when generating
699+
show_info=print,
673700
)
674701
return audio_result, ref_text_out
675702

676703
def clear_conversation():
677-
"""Reset the conversation"""
678704
return [], [
679705
{
680706
"role": "system",
@@ -683,11 +709,9 @@ def clear_conversation():
683709
]
684710

685711
def update_system_prompt(new_prompt):
686-
"""Update the system prompt and reset the conversation"""
687712
new_conv_state = [{"role": "system", "content": new_prompt}]
688713
return [], new_conv_state
689714

690-
# Handle audio input
691715
audio_input_chat.stop_recording(
692716
process_audio_input,
693717
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
@@ -701,8 +725,6 @@ def update_system_prompt(new_prompt):
701725
None,
702726
audio_input_chat,
703727
)
704-
705-
# Handle text input
706728
text_input_chat.submit(
707729
process_audio_input,
708730
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
@@ -716,8 +738,6 @@ def update_system_prompt(new_prompt):
716738
None,
717739
text_input_chat,
718740
)
719-
720-
# Handle send button
721741
send_btn_chat.click(
722742
process_audio_input,
723743
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
@@ -731,14 +751,10 @@ def update_system_prompt(new_prompt):
731751
None,
732752
text_input_chat,
733753
)
734-
735-
# Handle clear button
736754
clear_btn_chat.click(
737755
clear_conversation,
738756
outputs=[chatbot_interface, conversation_state],
739757
)
740-
741-
# Handle system prompt change and reset conversation
742758
system_prompt_chat.change(
743759
update_system_prompt,
744760
inputs=system_prompt_chat,
@@ -779,14 +795,12 @@ def load_last_used_custom():
779795

780796
def switch_tts_model(new_choice):
    """Switch the globally selected TTS model and update the custom-model widgets.

    Args:
        new_choice: A base model name (e.g. ``DEFAULT_TTS_MODEL``, ``"E2-TTS"``)
            or ``"Custom"``. Choosing ``"Custom"`` restores the last-used custom
            model settings (checkpoint path, vocab path, model config) — this
            override exists so a page refresh does not lose the custom setup.

    Returns:
        Three ``gr.update()`` objects for the custom checkpoint / vocab /
        config inputs: visible with restored values for ``"Custom"``, hidden
        otherwise.
    """
    global tts_model_choice
    if new_choice == "Custom":
        # Load the persisted settings exactly once and reuse them, instead of
        # calling load_last_used_custom() separately for every field (which
        # re-reads the backing store six times and could yield inconsistent
        # values if it changed mid-call).
        custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
        return (
            gr.update(visible=True, value=custom_ckpt_path),
            gr.update(visible=True, value=custom_vocab_path),
            gr.update(visible=True, value=custom_model_cfg),
        )
    else:
        tts_model_choice = new_choice
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
@@ -800,11 +814,15 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
800814
with gr.Row():
801815
if not USING_SPACES:
802816
choose_tts_model = gr.Radio(
803-
choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
817+
choices=all_tts_choices,
818+
label="Choose TTS Model",
819+
value=DEFAULT_TTS_MODEL,
804820
)
805821
else:
806822
choose_tts_model = gr.Radio(
807-
choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
823+
choices=[DEFAULT_TTS_MODEL, "E2-TTS"],
824+
label="Choose TTS Model",
825+
value=DEFAULT_TTS_MODEL,
808826
)
809827
custom_ckpt_path = gr.Dropdown(
810828
choices=[DEFAULT_TTS_MODEL_CFG[0]],
@@ -853,7 +871,6 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
853871
label="Config: in a dictionary form",
854872
visible=False,
855873
)
856-
857874
choose_tts_model.change(
858875
switch_tts_model,
859876
inputs=[choose_tts_model],
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
Kyle, kájl
2+
Stan, sten
3+
Chartmen, kártmen
4+
Keny, keni
5+
dr, doktor
6+
dr., doktor
7+
mr, miszter
8+
mr., miszter
9+
riose, riosz
10+
DNS, déenes
11+
Hari, henri
12+
gaal, geel
13+
USS, júeszesz
14+
Cerritos, Kerritosz
15+
Sylvia, Szilvia
16+
FNN, efenen
17+
Planet, plenet
18+
Backett, Beket
19+
Gate, gét
20+
Alonso, Alonzó
21+
Carol, Kerol
22+
Bradward, bredvörd
23+
geneieve, dzsenív
24+
boimler, bojler
25+
one, egy
26+
two, kettő
27+
gpu, gépéú
28+
openai, openéáj
29+
deepseek, dípszík
30+
deepsek, dípszík
31+
v3, vé három
32+
LLM, elelem
33+
r1, er egy
34+
like, lájk
35+
likeot, lájkot
36+
Anthropic, antropik
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
%, százalék
2+
ninjutsu, nindzsucu
3+
tweet, tvít

0 commit comments

Comments
 (0)