Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/src/inference/kokoro_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ async def generate(
]
):
continue
if not token.text or not token.text.strip():

# token.start_ts may be None
if not token.text or not token.text.strip() or token.start_ts is None or token.end_ts is None:
continue

start_time = float(token.start_ts) + current_offset
Expand Down
98 changes: 56 additions & 42 deletions api/src/services/text_processing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numpy import number
from torch import mul
from ...structures.schemas import NormalizationOptions
from misaki import en

from text_to_num import text2num

Expand Down Expand Up @@ -54,6 +55,7 @@
"uk",
"us",
"io",
"co"
]

VALID_UNITS = {
Expand Down Expand Up @@ -90,35 +92,47 @@

UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)

TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
TIME_PATTERN = re.compile(r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)

INFLECT_ENGINE=inflect.engine()

def split_num(num: re.Match[str]) -> str:
def sound_like(text: str, sound_like: str, lang_code: str) -> str:
    """Wrap *text* in Kokoro's inline-phoneme markup so it is pronounced like *sound_like*.

    Kokoro supports embedding phonemes in the text, and the token timestamps are
    based on the original text:
    - Original input text: '[Misaki](/misˈɑki/) is a G2P engine designed for [Kokoro](/kˈOkəɹO/) models.'
    - Text for timestamps: 'Misaki is a G2P engine designed for Kokoro models.'

    Args:
        text: Surface text that timestamps are computed against.
        sound_like: Replacement text that should actually be pronounced.
        lang_code: Language code forwarded to the phonemizer.

    Returns:
        The text in Kokoro's '[text](/phonemes/)' markup.
    """
    # NOTE(review): imported lazily — presumably to avoid a circular import with
    # .phonemizer at module load time; confirm.
    from .phonemizer import phonemize

    phonemes = phonemize(sound_like, language=lang_code, normalize=False)
    return f"[{text}](/{phonemes}/)"

def split_num(num: re.Match[str], lang_code: str) -> str:
    """Handle number splitting for various formats (times, years, decimals).

    Decimals and un-speakable years are returned unchanged; times and year-like
    numbers are wrapped via ``sound_like`` so the original text is kept for
    timestamp alignment while the spoken form drives pronunciation.

    Args:
        num: Regex match whose group(0) is the number text.
        lang_code: Language code forwarded to the phonemizer.
    """
    num = num.group()
    if "." in num:
        # Decimals are handled later by handle_decimal.
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            # BUGFIX: lang_code was missing here — sound_like requires three
            # arguments, so this branch raised TypeError at runtime.
            return sound_like(num, f"{h} o'clock", lang_code)
        elif m < 10:
            # BUGFIX: lang_code was missing here as well.
            return sound_like(num, f"{h} oh {m}", lang_code)
        return sound_like(num, f"{h} {m}", lang_code)
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        # Too early / round years read fine as plain numbers.
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return sound_like(num, f"{left} hundred{s}", lang_code)
        elif right < 10:
            return sound_like(num, f"{left} oh {right}{s}", lang_code)
    return sound_like(num, f"{left} {right}{s}", lang_code)

def handle_units(u: re.Match[str]) -> str:
def handle_units(u: re.Match[str], lang_code) -> str:
"""Converts units to their full form"""
unit_string=u.group(6).strip()
unit=unit_string
Expand All @@ -134,14 +148,14 @@ def handle_units(u: re.Match[str]) -> str:

number=u.group(1).strip()
unit[0]=INFLECT_ENGINE.no(unit[0],number)
return " ".join(unit)
return sound_like(u.group(), " ".join(unit), lang_code)

def conditional_int(number: float, threshold: float = 0.00001):
    """Collapse *number* to an int when it lies within *threshold* of a whole value."""
    nearest = round(number)
    if abs(nearest - number) < threshold:
        return int(nearest)
    return number

def handle_money(m: re.Match[str]) -> str:
def handle_money(m: re.Match[str], lang_code) -> str:
"""Convert money expressions to spoken form"""

bill = "dollar" if m.group(2) == "$" else "pound"
Expand All @@ -164,26 +178,26 @@ def handle_money(m: re.Match[str]) -> str:

text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"

return text_number
return sound_like(m.group(), text_number, lang_code)

def handle_decimal(num: re.Match[str]) -> str:
def handle_decimal(num: re.Match[str], lang_code: str) -> str:
    """Convert a decimal number to spoken form, reading fractional digits one by one."""
    whole, fraction = num.group().split(".")
    digits = " ".join(fraction)
    return sound_like(num.group(), f"{whole} point {digits}", lang_code=lang_code)


def handle_email(m: re.Match[str]) -> str:
def handle_email(m: re.Match[str], lang_code: str) -> str:
    """Convert an email address into a speakable 'user at domain dot tld' phrase."""
    email = m.group(0)
    pieces = email.split("@")
    if len(pieces) != 2:
        # Not a simple user@domain shape; leave it untouched.
        return email
    user, domain = pieces
    spoken = f"{user} at {domain.replace('.', ' dot ')}"
    return sound_like(email, spoken, lang_code)


def handle_url(u: re.Match[str]) -> str:
def handle_url(u: re.Match[str], lang_code: str) -> str:
"""Make URLs speakable by converting special characters to spoken words"""
if not u:
return ""
Expand Down Expand Up @@ -227,56 +241,56 @@ def handle_url(u: re.Match[str]) -> str:
url = url.replace("/", " slash ") # Handle any remaining slashes

# Clean up extra spaces
return re.sub(r"\s+", " ", url).strip()
return sound_like(u.group(), re.sub(r"\s+", " ", url).strip(), lang_code)

def handle_phone_number(p: re.Match[str]) -> str:
p=list(p.groups())
def handle_phone_number(p: re.Match[str], lang_code: str) -> str:
    """Spell out a matched phone number group by group, comma-separated for pauses."""
    groups = p.groups()

    spoken_parts = []

    # Optional country code, e.g. "+1" — spoken without digit grouping.
    country = groups[0]
    if country is not None:
        spoken_parts.append(INFLECT_ENGINE.number_to_words(country.replace("+", "")))
    else:
        spoken_parts.append("")

    # Area code may be wrapped in parentheses, e.g. "(555)".
    area = groups[2].replace("(", "").replace(")", "")
    spoken_parts.append(INFLECT_ENGINE.number_to_words(area, group=1, comma=""))

    # Exchange prefix and line number, digit by digit.
    spoken_parts.append(INFLECT_ENGINE.number_to_words(groups[3], group=1, comma=""))
    spoken_parts.append(INFLECT_ENGINE.number_to_words(groups[4], group=1, comma=""))

    return sound_like(p.group(), ",".join(spoken_parts), lang_code)

def handle_time(t: re.Match[str]) -> str:
t=t.groups()
def handle_time(t: re.Match[str], lang_code: str) -> str:
    """Convert a matched HH:MM(:SS) time, with optional am/pm, into spoken words."""
    groups = t.groups()

    spoken = " ".join(
        INFLECT_ENGINE.number_to_words(part.strip()) for part in groups[0].split(":")
    )

    suffix = groups[2].strip() if groups[2] is not None else ""

    return sound_like(t.group(), spoken + suffix, lang_code)

def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
def normalize_text(text: str,normalization_options: NormalizationOptions, lang_code = "a") -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
text = EMAIL_PATTERN.sub(lambda g: handle_email(g, lang_code = lang_code), text)

# Handle URLs if enabled
if normalization_options.url_normalization:
text = URL_PATTERN.sub(handle_url, text)
text = URL_PATTERN.sub(lambda g: handle_url(g, lang_code = lang_code), text)

# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
text=UNIT_PATTERN.sub(handle_units,text)
text=UNIT_PATTERN.sub(lambda g: handle_units(g, lang_code = lang_code),text)

# Replace optional pluralization
if normalization_options.optional_pluralization_normalization:
text = re.sub(r"\(s\)","s",text)

# Replace phone numbers:
if normalization_options.phone_normalization:
text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",lambda g: handle_phone_number(g, lang_code = lang_code),text)

# Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
Expand All @@ -288,7 +302,7 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
text = text.replace(a, b + " ")

# Handle simple time in the format of HH:MM:SS
text = TIME_PATTERN.sub(handle_time, text, )
text = TIME_PATTERN.sub(lambda g: handle_time(g, lang_code = lang_code), text, )

# Clean up whitespace
text = re.sub(r"[^\S \n]", " ", text)
Expand All @@ -310,15 +324,15 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str

text = re.sub(
r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
handle_money,
lambda g: handle_money(g, lang_code = lang_code),
text,
)

text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", lambda g: split_num(g, lang_code = lang_code), text
)

text = re.sub(r"\d*\.\d+", handle_decimal, text)
text = re.sub(r"\d*\.\d+", lambda g: handle_decimal(g, lang_code = lang_code), text)

# Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
Expand Down
11 changes: 5 additions & 6 deletions api/src/services/text_processing/text_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def process_text(text: str, language: str = "a") -> List[int]:
return process_text_chunk(text, language)


def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a") -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
Expand All @@ -109,7 +109,7 @@ def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[T
continue

full = sentence + punct
tokens = process_text_chunk(full)
tokens = process_text_chunk(full, language = lang_code)
results.append((full, tokens, len(tokens)))

return results
Expand All @@ -134,15 +134,14 @@ async def smart_split(

# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
print(lang_code)
if lang_code in ["a","b","en-us","en-gb"]:
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
text=normalize_text(text,normalization_options)
text = normalize_text(text,normalization_options, lang_code= lang_code)
else:
logger.info("Skipping text normalization as it is only supported for english")

# Process all sentences
sentences = get_sentence_info(text, custom_phoneme_list)
sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code)

current_chunk = []
current_tokens = []
Expand Down Expand Up @@ -178,7 +177,7 @@ async def smart_split(

full_clause = clause + comma

tokens = process_text_chunk(full_clause)
tokens = process_text_chunk(full_clause, language = lang_code)
count = len(tokens)

# If adding clause keeps us under max and not optimal yet
Expand Down
Loading