diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
new file mode 100644
index 00000000..d2c606ba
--- /dev/null
+++ b/.github/workflows/build-push.yml
@@ -0,0 +1,87 @@
+name: Docker Build and Push
+
+on:
+ push:
+ tags: [ 'v*.*.*' ]
+ paths-ignore:
+ - '**.md'
+ - 'docs/**'
+ workflow_dispatch:
+ inputs:
+ version:
+ description: 'Version to build and publish (e.g. v0.2.0)'
+ required: true
+ type: string
+
+jobs:
+ prepare-release:
+ runs-on: ubuntu-latest
+ outputs:
+ version: ${{ steps.get-version.outputs.version }}
+ is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Get version
+ id: get-version
+ run: |
+ if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
+ echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
+ else
+ echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
+ fi
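+      # NOTE: on tag pushes the version comes from the VERSION file; this assumes the
+      # file contents already include the "v" prefix so they match the pushed tag.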
+
+ - name: Check if prerelease
+ id: check-prerelease
+ run: |
+ echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
+
+ build-images:
+ needs: prepare-release
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ env:
+ DOCKER_BUILDKIT: 1
+ BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
+      # This environment variable overrides the VERSION variable in docker-bake.hcl.
+ VERSION: ${{ needs.prepare-release.outputs.version }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Free disk space
+ run: |
+ echo "Listing current disk space"
+ df -h
+ echo "Cleaning up disk space..."
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf /opt/ghc
+ sudo rm -rf /opt/hostedtoolcache
+ docker system prune -af
+ echo "Disk space after cleanup"
+ df -h
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v2
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+ with:
+ driver-opts: |
+ image=moby/buildkit:latest
+ network=host
+
+ - name: Log in to GitHub Container Registry
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
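+        # The default GITHUB_TOKEN can push to ghcr.io because this job requests the
+        # "packages: write" permission above.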
+
+ - name: Build and push images
+ run: |
+ # No need to override VERSION via --set; the env var does the job.
+ docker buildx bake --push
\ No newline at end of file
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
new file mode 100644
index 00000000..d3e880e0
--- /dev/null
+++ b/.github/workflows/docker-publish.yml
@@ -0,0 +1,102 @@
+name: Docker Build and Publish
+
+on:
+ push:
+ tags: [ 'v*.*.*' ]
+ paths-ignore:
+ - '**.md'
+ - 'docs/**'
+ workflow_dispatch:
+ inputs:
+ version:
+ description: 'Version to release (e.g. v0.2.0)'
+ required: true
+ type: string
+
+jobs:
+ prepare-release:
+ runs-on: ubuntu-latest
+ outputs:
+ version: ${{ steps.get-version.outputs.version }}
+ is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Get version
+ id: get-version
+ run: |
+ if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
+ echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
+ else
+ echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Check if prerelease
+ id: check-prerelease
+ run: echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
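+      # Any version containing "-pre" (e.g. v0.2.1-pre1) is treated as a prerelease
+      # when the draft GitHub release is created below.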
+
+ build-images:
+ needs: prepare-release
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Free disk space
+ run: |
+ echo "Listing current disk space"
+ df -h
+ echo "Cleaning up disk space..."
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf /opt/ghc
+ sudo rm -rf /opt/hostedtoolcache
+ docker system prune -af
+ echo "Disk space after cleanup"
+ df -h
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v2
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+ with:
+ driver-opts: |
+ image=moby/buildkit:latest
+ network=host
+
+ - name: Log in to GitHub Container Registry
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Build and push images
+ env:
+ DOCKER_BUILDKIT: 1
+ BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
+ VERSION: ${{ needs.prepare-release.outputs.version }}
+ run: docker buildx bake --push
+
+ create-release:
+ needs: [prepare-release, build-images]
+ runs-on: ubuntu-latest
+ permissions:
+ contents: write
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Create Release
+ uses: softprops/action-gh-release@v1
+ with:
+ tag_name: ${{ needs.prepare-release.outputs.version }}
+ generate_release_notes: true
+ draft: true
+ prerelease: ${{ needs.prepare-release.outputs.is_prerelease }}
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
deleted file mode 100644
index 395d218e..00000000
--- a/.github/workflows/release.yml
+++ /dev/null
@@ -1,110 +0,0 @@
-name: Create Release and Publish Docker Images
-
-on:
- push:
- branches:
- - release # Trigger when commits are pushed to the release branch (e.g., after merging master)
- paths-ignore:
- - '**.md'
- - 'docs/**'
-
-jobs:
- prepare-release:
- runs-on: ubuntu-latest
- outputs:
- version: ${{ steps.get-version.outputs.version }}
- version_tag: ${{ steps.get-version.outputs.version_tag }}
- steps:
- - name: Checkout repository
- uses: actions/checkout@v4
-
- - name: Get version from VERSION file
- id: get-version
- run: |
- VERSION_PLAIN=$(cat VERSION)
- echo "version=${VERSION_PLAIN}" >> $GITHUB_OUTPUT
- echo "version_tag=v${VERSION_PLAIN}" >> $GITHUB_OUTPUT # Add 'v' prefix for tag
-
- build-images:
- needs: prepare-release
- runs-on: ubuntu-latest
- permissions:
- packages: write # Needed to push images to GHCR
- env:
- DOCKER_BUILDKIT: 1
- BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
- # This environment variable will override the VERSION variable in docker-bake.hcl
- VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
- steps:
- - name: Checkout repository
- uses: actions/checkout@v4
- with:
- fetch-depth: 0 # Needed to check for existing tags
-
- - name: Check if tag already exists
- run: |
- TAG_NAME="${{ needs.prepare-release.outputs.version_tag }}"
- echo "Checking for existing tag: $TAG_NAME"
- # Fetch tags explicitly just in case checkout didn't get them all
- git fetch --tags
- if git rev-parse "$TAG_NAME" >/dev/null 2>&1; then
- echo "::error::Tag $TAG_NAME already exists. Please increment the version in the VERSION file."
- exit 1
- else
- echo "Tag $TAG_NAME does not exist. Proceeding with release."
- fi
-
- - name: Free disk space # Optional: Keep as needed for large builds
- run: |
- echo "Listing current disk space"
- df -h
- echo "Cleaning up disk space..."
- sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
- docker system prune -af
- echo "Disk space after cleanup"
- df -h
-
- - name: Set up QEMU
- uses: docker/setup-qemu-action@v3 # Use v3
-
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v3 # Use v3
- with:
- driver-opts: |
- image=moby/buildkit:latest
- network=host
-
- - name: Log in to GitHub Container Registry
- uses: docker/login-action@v3 # Use v3
- with:
- registry: ghcr.io
- username: ${{ github.actor }}
- password: ${{ secrets.GITHUB_TOKEN }}
-
- - name: Build and push images using Docker Bake
- run: |
- echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
- # The VERSION env var above sets the tag for the bake file targets
- docker buildx bake --push
-
- create-release:
- needs: [prepare-release, build-images]
- runs-on: ubuntu-latest
- permissions:
- contents: write # Needed to create releases
- steps:
- - name: Checkout repository
- uses: actions/checkout@v4
- with:
- fetch-depth: 0 # Fetch all history for release notes generation
-
- - name: Create GitHub Release
- uses: softprops/action-gh-release@v2 # Use v2
- with:
- tag_name: ${{ needs.prepare-release.outputs.version_tag }} # Use vX.Y.Z tag
- name: Release ${{ needs.prepare-release.outputs.version_tag }}
- generate_release_notes: true # Auto-generate release notes
- draft: false # Publish immediately
- prerelease: false # Mark as a stable release
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/sync-develop.yml b/.github/workflows/sync-develop.yml
new file mode 100644
index 00000000..56b881f0
--- /dev/null
+++ b/.github/workflows/sync-develop.yml
@@ -0,0 +1,55 @@
+# name: Sync develop with master
+
+# on:
+# push:
+# branches:
+# - master
+
+# jobs:
+# sync-develop:
+# runs-on: ubuntu-latest
+# permissions:
+# contents: write
+# issues: write
+# steps:
+# - name: Checkout repository
+# uses: actions/checkout@v4
+# with:
+# fetch-depth: 0
+# ref: develop
+
+# - name: Configure Git
+# run: |
+# git config user.name "GitHub Actions"
+# git config user.email "actions@github.com"
+
+# - name: Merge master into develop
+# run: |
+# git fetch origin master:master
+# git merge --no-ff origin/master -m "chore: Merge master into develop branch"
+
+# - name: Push changes
+# run: |
+# if ! git push origin develop; then
+# echo "Failed to push to develop branch"
+# exit 1
+# fi
+
+# - name: Handle Failure
+# if: failure()
+# uses: actions/github-script@v7
+# with:
+# script: |
+# const issueBody = `Automatic merge from master to develop failed.
+
+# Please resolve this manually
+
+# Workflow run: ${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`;
+
+# await github.rest.issues.create({
+# owner: context.repo.owner,
+# repo: context.repo.repo,
+# title: '🔄 Automatic master to develop merge failed',
+# body: issueBody,
+# labels: ['merge-failed', 'automation']
+# });
diff --git a/.gitignore b/.gitignore
index 35fa9fb8..3a439dbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,6 +70,3 @@ examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock
-
-# Mac MPS virtualenv for dual testing
-.venv-mps
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c076c282..74d60942 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,37 +2,6 @@
Notable changes to this project will be documented in this file.
-## [v0.3.0] - 2025-04-04
-### Added
-- Apple Silicon (MPS) acceleration support for macOS users.
-- Voice subtraction capability for creating unique voice effects.
-- Windows PowerShell start scripts (`start-cpu.ps1`, `start-gpu.ps1`).
-- Automatic model downloading integrated into all start scripts.
-- Example Helm chart values for Azure AKS and Nvidia GPU Operator deployments.
-- `CONTRIBUTING.md` guidelines for developers.
-
-### Changed
-- Version bump of underlying Kokoro and Misaki libraries
-- Default API port reverted to 8880.
-- Docker containers now run as a non-root user for enhanced security.
-- Improved text normalization for numbers, currency, and time formats.
-- Updated and improved Helm chart configurations and documentation.
-- Enhanced temporary file management with better error tracking.
-- Web UI dependencies (Siriwave) are now served locally.
-- Standardized environment variable handling across shell/PowerShell scripts.
-
-### Fixed
-- Corrected an issue preventing download links from being returned when `streaming=false`.
-- Resolved errors in Windows PowerShell scripts related to virtual environment activation order.
-- Addressed potential segfaults during inference.
-- Fixed various Helm chart issues related to health checks, ingress, and default values.
-- Corrected audio quality degradation caused by incorrect bitrate settings in some cases.
-- Ensured custom phonemes provided in input text are preserved.
-- Fixed a 'MediaSource' error affecting playback stability in the web player.
-
-### Removed
-- Obsolete GitHub Actions build workflow, build and publish now occurs on merge to `Release` branch
-
## [v0.2.0post1] - 2025-02-07
- Fix: Building Kokoro from source with adjustments, to avoid CUDA lock
- Fixed ARM64 compatibility on Spacy dep to avoid emulation slowdown
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
deleted file mode 100644
index e6bac74d..00000000
--- a/CONTRIBUTING.md
+++ /dev/null
@@ -1,86 +0,0 @@
-# Contributing to Kokoro-FastAPI
-
-Always appreciate community involvement in making this project better.
-
-## Development Setup
-
-We use `uv` for managing Python environments and dependencies, and `ruff` for linting and formatting.
-
-1. **Clone the repository:**
- ```bash
- git clone https://github.com/remsky/Kokoro-FastAPI.git
- cd Kokoro-FastAPI
- ```
-
-2. **Install `uv`:**
- Follow the instructions on the [official `uv` documentation](https://docs.astral.sh/uv/install/).
-
-3. **Create a virtual environment and install dependencies:**
- It's recommended to use a virtual environment. `uv` can create one for you. Install the base dependencies along with the `test` and `cpu` extras (needed for running tests locally).
- ```bash
- # Create and activate a virtual environment (e.g., named .venv)
- uv venv
- source .venv/bin/activate # On Linux/macOS
- # .venv\Scripts\activate # On Windows
-
- # Install dependencies including test requirements
- uv pip install -e ".[test,cpu]"
- ```
- *Note: If you have an NVIDIA GPU and want to test GPU-specific features locally, you can install `.[test,gpu]` instead, ensuring you have the correct CUDA toolkit installed.*
-
- *Note: If running via uv locally, you will have to install espeak and handle any pathing issues that arise. The Docker images handle this automatically*
-
-4. **Install `ruff` (if not already installed globally):**
- While `ruff` might be included via dependencies, installing it explicitly ensures you have it available.
- ```bash
- uv pip install ruff
- ```
-
-## Running Tests
-
-Before submitting changes, please ensure all tests pass as this is a automated requirement. The tests are run using `pytest`.
-```bash
-# Make sure your virtual environment is activated
-uv run pytest
-```
-*Note: The CI workflow runs tests using `uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing`. Running `uv run pytest` locally should cover the essential checks.*
-
-## Testing with Docker Compose
-
-In addition to local `pytest` runs, test your changes using Docker Compose to ensure they work correctly within the containerized environment. If you aren't able to test on CUDA hardware, make note so it can be tested by another maintainer
-
-```bash
-
-docker compose -f docker/cpu/docker-compose.yml up --build
-+
-docker compose -f docker/gpu/docker-compose.yml up --build
-```
-This command will build the Docker images (if they've changed) and start the services defined in the respective compose file. Verify the application starts correctly and test the relevant functionality.
-
-## Code Formatting and Linting
-
-We use `ruff` to maintain code quality and consistency. Please format and lint your code before committing.
-
-1. **Format the code:**
- ```bash
- # Make sure your virtual environment is activated
- ruff format .
- ```
-
-2. **Lint the code (and apply automatic fixes):**
- ```bash
- # Make sure your virtual environment is activated
- ruff check . --fix
- ```
- Review any changes made by `--fix` and address any remaining linting errors manually.
-
-## Submitting Changes
-
-0. Clone the repo
-1. Create a new branch for your feature or bug fix.
-2. Make your changes, following setup, testing, and formatting guidelines above.
-3. Please try to keep your changes inline with the current design, and modular. Large-scale changes will take longer to review and integrate, and have less chance of being approved outright.
-4. Push your branch to your fork.
-5. Open a Pull Request against the `master` branch of the main repository.
-
-Thank you for contributing!
diff --git a/README.md b/README.md
index 5e9da1d9..9b6d3f7c 100644
--- a/README.md
+++ b/README.md
@@ -3,17 +3,17 @@
# _`FastKoko`_
-[]()
-[]()
+[]()
+[]()
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
-[](https://github.com/hexgrad/kokoro)
-[](https://github.com/hexgrad/misaki)
+[](https://github.com/hexgrad/kokoro)
+[](https://github.com/hexgrad/misaki)
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
-- Multi-language support (English, Japanese, Chinese, _Vietnamese soon_)
+- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_)
- OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
- ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
- Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web
@@ -24,6 +24,10 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
### Integration Guides
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Setup-Kubernetes) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-DigitalOcean) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-SillyTavern)
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-OpenWebUi)
+
+
+
+
## Get Started
@@ -34,12 +38,11 @@ Pre built images are available to run, with arm/multi-arch support, and baked in
Refer to the core/config.py file for a full list of variables which can be managed via the environment
```bash
-# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
- Named versions should be pinned for your regular usage.
- Feedback/testing is always welcome
+# the `latest` tag can be used, but should not be considered stable as it may include `nightly` branch builds
+# it may have some bonus features, however, and feedback/testing is always welcome
-docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
-docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
+docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.2 # CPU, or:
+docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.2 # NVIDIA GPU
```
@@ -60,11 +63,6 @@ docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NV
# or cd docker/cpu # For CPU support
docker compose up --build
- # *Note for Apple Silicon (M1/M2) users:
- # The current GPU build relies on CUDA, which is not supported on Apple Silicon.
- # If you are on an M1/M2/M3 Mac, please use the `docker/cpu` setup.
- # MPS (Apple's GPU acceleration) support is planned but not yet available.
-
# Models will auto-download, but if needed you can manually download:
python docker/scripts/download_model.py --output api/src/models/v1_0
@@ -88,19 +86,11 @@ docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NV
Run the [model download script](https://github.com/remsky/Kokoro-FastAPI/blob/master/docker/scripts/download_model.py) if you haven't already
Start directly via UV (with hot-reload)
-
- Linux and macOS
```bash
./start-cpu.sh OR
./start-gpu.sh
```
- Windows
- ```powershell
- .\start-cpu.ps1 OR
- .\start-gpu.ps1
- ```
-
@@ -136,8 +126,8 @@ with client.audio.speech.with_streaming_response.create(
-## Features
+## Features
OpenAI-Compatible Speech Endpoint
@@ -513,49 +503,18 @@ Monitor system state and resource usage with these endpoints:
Useful for debugging resource exhaustion or performance issues.
-## Known Issues & Troubleshooting
-
-
-Missing words & Missing some timestamps
-
-The api will automaticly do text normalization on input text which may incorrectly remove or change some phrases. This can be disabled by adding `"normalization_options":{"normalize": false}` to your request json:
-```python
-import requests
-
-response = requests.post(
- "http://localhost:8880/v1/audio/speech",
- json={
- "input": "Hello world!",
- "voice": "af_heart",
- "response_format": "pcm",
- "normalization_options":
- {
- "normalize": False
- }
- },
- stream=True
-)
-
-for chunk in response.iter_content(chunk_size=1024):
- if chunk:
- # Process streaming chunks
- pass
-```
-
-
+## Known Issues
Versioning & Development
-**Branching Strategy:**
-* **`release` branch:** Contains the latest stable build, recommended for production use. Docker images tagged with specific versions (e.g., `v0.3.0`) are built from this branch.
-* **`master` branch:** Used for active development. It may contain experimental features, ongoing changes, or fixes not yet in a stable release. Use this branch if you want the absolute latest code, but be aware it might be less stable. The `latest` Docker tag often points to builds from this branch.
-
-Note: This is a *development* focused project at its core.
+I'm doing what I can to keep things stable, but this project is still in early, rapid build cycles.
+If you run into trouble, you may need to roll back to an earlier release tag, build from source, or troubleshoot and submit a PR. The last known stable points are tagged below:
-If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR.
+`v0.1.4`
+`v0.0.5post1`
-Free and open source is a community effort, and there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
+Free and open source is a community effort, and I love working on this project, though there are only so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
diff --git a/api/src/core/config.py b/api/src/core/config.py
--- a/api/src/core/config.py
+++ b/api/src/core/config.py
-    def get_device(self) -> str:
- """Get the appropriate device based on settings and availability"""
- if not self.use_gpu:
- return "cpu"
-
- if self.device_type:
- return self.device_type
-
- # Auto-detect device
- if torch.backends.mps.is_available():
- return "mps"
- elif torch.cuda.is_available():
- return "cuda"
- return "cpu"
-
settings = Settings()
diff --git a/api/src/core/paths.py b/api/src/core/paths.py
index 771b70c3..0e605284 100644
--- a/api/src/core/paths.py
+++ b/api/src/core/paths.py
@@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str:
)
# Construct web directory path relative to project root
- web_dir = os.path.join(root_dir, settings.web_player_path)
+ web_dir = os.path.join("/app", settings.web_player_path)
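+    # NOTE: hardcoding /app assumes the Docker image layout; running outside the
+    # container will not resolve the web player files from this path.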
# Search in web directory
search_paths = [web_dir]
diff --git a/api/src/inference/base.py b/api/src/inference/base.py
index e25c2b51..6b59fa98 100644
--- a/api/src/inference/base.py
+++ b/api/src/inference/base.py
@@ -1,41 +1,34 @@
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Optional, Tuple, Union
+from typing import AsyncGenerator, Optional, Tuple, Union, List
import numpy as np
import torch
-
class AudioChunk:
"""Class for audio chunks returned by model backends"""
-
- def __init__(
- self,
- audio: np.ndarray,
- word_timestamps: Optional[List] = [],
- output: Optional[Union[bytes, np.ndarray]] = b"",
- ):
- self.audio = audio
- self.word_timestamps = word_timestamps
- self.output = output
-
+
+ def __init__(self,
+ audio: np.ndarray,
+ word_timestamps: Optional[List]=[],
+ output: Optional[Union[bytes,np.ndarray]]=b""
+ ):
+ self.audio=audio
+ self.word_timestamps=word_timestamps
+ self.output=output
+
@staticmethod
def combine(audio_chunk_list: List):
- output = AudioChunk(
- audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
- )
-
+ output=AudioChunk(audio_chunk_list[0].audio,audio_chunk_list[0].word_timestamps)
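+        # Seed the combined chunk from the first entry; the loop below appends the
+        # remaining audio (and timestamps, when present) onto it.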
+
for audio_chunk in audio_chunk_list[1:]:
- output.audio = np.concatenate(
- (output.audio, audio_chunk.audio), dtype=np.int16
- )
+ output.audio=np.concatenate((output.audio,audio_chunk.audio),dtype=np.int16)
if output.word_timestamps is not None:
- output.word_timestamps += audio_chunk.word_timestamps
-
+ output.word_timestamps+=audio_chunk.word_timestamps
+
return output
-
-
+
class ModelBackend(ABC):
"""Abstract base class for model inference backend."""
diff --git a/api/src/inference/kokoro_v1.py b/api/src/inference/kokoro_v1.py
index a627dbb3..419ade7e 100644
--- a/api/src/inference/kokoro_v1.py
+++ b/api/src/inference/kokoro_v1.py
@@ -11,10 +11,9 @@
from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
+from .base import BaseModelBackend
+from .base import AudioChunk
from ..structures.schemas import WordTimestamp
-from .base import AudioChunk, BaseModelBackend
-
-
class KokoroV1(BaseModelBackend):
"""Kokoro backend with controlled resource management."""
@@ -22,7 +21,7 @@ def __init__(self):
"""Initialize backend with environment-based configuration."""
super().__init__()
# Strictly respect settings.use_gpu
- self._device = settings.get_device()
+ self._device = "cuda" if settings.use_gpu else "cpu"
self._model: Optional[KModel] = None
self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code
@@ -49,16 +48,9 @@ async def load_model(self, path: str) -> None:
# Load model and let KModel handle device mapping
self._model = KModel(config=config_path, model=model_path).eval()
- # For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
- if self._device == "mps":
- logger.info(
- "Moving model to MPS device with CPU fallback for unsupported operations"
- )
- self._model = self._model.to(torch.device("mps"))
- elif self._device == "cuda":
+ # Move to CUDA if needed
+ if self._device == "cuda":
self._model = self._model.cuda()
- else:
- self._model = self._model.cpu()
except FileNotFoundError as e:
raise e
@@ -148,11 +140,11 @@ async def generate_from_tokens(
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
- if lang_code: # api is given priority
+ if lang_code: # api is given priority
pipeline_lang_code = lang_code
- elif settings.default_voice_code: # settings is next priority
+ elif settings.default_voice_code: # settings is next priority
pipeline_lang_code = settings.default_voice_code
- else: # voice name is default/fallback
+ else: # voice name is default/fallback
pipeline_lang_code = voice_name[0].lower()
pipeline = self._get_pipeline(pipeline_lang_code)
@@ -247,15 +239,7 @@ async def generate(
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
- pipeline_lang_code = (
- lang_code
- if lang_code
- else (
- settings.default_voice_code
- if settings.default_voice_code
- else voice_name[0].lower()
- )
- )
+ pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else voice_name[0].lower())
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(
@@ -266,19 +250,20 @@ async def generate(
):
if result.audio is not None:
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
- word_timestamps = None
- if (
- return_timestamps
- and hasattr(result, "tokens")
- and result.tokens
- ):
- word_timestamps = []
- current_offset = 0.0
+ word_timestamps=None
+ if return_timestamps and hasattr(result, "tokens") and result.tokens:
+ word_timestamps=[]
+ current_offset=0.0
logger.debug(
- f"Processing chunk timestamps with {len(result.tokens)} tokens"
- )
+ f"Processing chunk timestamps with {len(result.tokens)} tokens"
+ )
if result.pred_dur is not None:
try:
+ # Join timestamps for this chunk's tokens
+ KPipeline.join_timestamps(
+ result.tokens, result.pred_dur
+ )
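+                            # join_timestamps is expected to populate each token's
+                            # start_ts / end_ts, which are read below.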
+
# Add timestamps with offset
for token in result.tokens:
if not all(
@@ -292,14 +277,14 @@ async def generate(
continue
if not token.text or not token.text.strip():
continue
-
+
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
WordTimestamp(
word=str(token.text).strip(),
start_time=start_time,
- end_time=end_time,
+ end_time=end_time
)
)
logger.debug(
@@ -310,10 +295,9 @@ async def generate(
logger.error(
f"Failed to process timestamps for chunk: {e}"
)
-
- yield AudioChunk(
- result.audio.numpy(), word_timestamps=word_timestamps
- )
+
+
+ yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
else:
logger.warning("No audio in chunk")
@@ -334,7 +318,6 @@ def _check_memory(self) -> bool:
if self._device == "cuda":
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > model_config.pytorch_gpu.memory_threshold
- # MPS doesn't provide memory management APIs
return False
def _clear_memory(self) -> None:
@@ -342,10 +325,6 @@ def _clear_memory(self) -> None:
if self._device == "cuda":
torch.cuda.empty_cache()
torch.cuda.synchronize()
- elif self._device == "mps":
- # Empty cache if available (future-proofing)
- if hasattr(torch.mps, "empty_cache"):
- torch.mps.empty_cache()
def unload(self) -> None:
"""Unload model and free resources."""
diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py
index eb817ecb..9cef95ff 100644
--- a/api/src/inference/model_manager.py
+++ b/api/src/inference/model_manager.py
@@ -141,8 +141,6 @@ async def generate(self, *args, **kwargs):
try:
async for chunk in self._backend.generate(*args, **kwargs):
- if settings.default_volume_multiplier != 1.0:
- chunk.audio *= settings.default_volume_multiplier
yield chunk
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
diff --git a/api/src/inference/voice_manager.py b/api/src/inference/voice_manager.py
index 0d82c4f7..5466fa95 100644
--- a/api/src/inference/voice_manager.py
+++ b/api/src/inference/voice_manager.py
@@ -19,7 +19,7 @@ class VoiceManager:
def __init__(self):
"""Initialize voice manager."""
# Strictly respect settings.use_gpu
- self._device = settings.get_device()
+ self._device = "cuda" if settings.use_gpu else "cpu"
self._voices: Dict[str, torch.Tensor] = {}
async def get_voice_path(self, voice_name: str) -> str:
diff --git a/api/src/main.py b/api/src/main.py
index 23299cf8..5aba8551 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -9,8 +9,9 @@
import torch
import uvicorn
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
from loguru import logger
from .core.config import settings
@@ -43,6 +44,27 @@ def setup_logger():
# Configure logger
setup_logger()
+security = HTTPBasic()
+
+def get_http_credentials(credentials: HTTPBasicCredentials = Depends(security)):
+ """Conditionally verify HTTP Basic Auth credentials"""
+ username = os.getenv("HTTP_USERNAME")
+ password = os.getenv("HTTP_PASSWORD")
+
+ # Skip authentication if credentials not configured
+ if not username or not password:
+ return
+
+ # Perform authentication check if credentials are configured
+ if (credentials.username != username or credentials.password != password):
+ raise HTTPException(
+ status_code=401,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Basic"},
+ )
+ return credentials.username
+
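+# Example (assumed deployment): set HTTP_USERNAME and HTTP_PASSWORD in the environment
+# to require Basic Auth on the routers below; leave them unset to disable auth entirely.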
+
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -85,12 +107,7 @@ async def lifespan(app: FastAPI):
{boundary}
"""
startup_msg += f"\nModel warmed up on {device}: {model}"
- if device == "mps":
- startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
- elif device == "cuda":
- startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
- else:
- startup_msg += "\nRunning on CPU"
+    startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
startup_msg += f"\n{voicepack_count} voice packs loaded"
# Add web player info if enabled
@@ -128,11 +145,11 @@ async def lifespan(app: FastAPI):
)
# Include routers
-app.include_router(openai_router, prefix="/v1")
-app.include_router(dev_router) # Development endpoints
-app.include_router(debug_router) # Debug endpoints
+app.include_router(openai_router, prefix="/v1", dependencies=[Depends(get_http_credentials)])
+app.include_router(dev_router, dependencies=[Depends(get_http_credentials)]) # Development endpoints
+app.include_router(debug_router, dependencies=[Depends(get_http_credentials)]) # Debug endpoints
if settings.enable_web_player:
- app.include_router(web_router, prefix="/web") # Web player static files
+ app.include_router(web_router, prefix="/web", dependencies=[Depends(get_http_credentials)]) # Web player static files
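+# NOTE: the health check endpoint below is registered without the Basic Auth
+# dependency, so liveness probes keep working even when credentials are configured.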
# Health check endpoint
diff --git a/api/src/routers/debug.py b/api/src/routers/debug.py
index 8acb9fd7..6a65362c 100644
--- a/api/src/routers/debug.py
+++ b/api/src/routers/debug.py
@@ -3,7 +3,6 @@
from datetime import datetime
import psutil
-import torch
from fastapi import APIRouter
try:
@@ -114,14 +113,7 @@ async def get_system_info():
# GPU Info if available
gpu_info = None
- if torch.backends.mps.is_available():
- gpu_info = {
- "type": "MPS",
- "available": True,
- "device": "Apple Silicon",
- "backend": "Metal",
- }
- elif GPU_AVAILABLE:
+ if GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
gpu_info = [
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index 8c8ed7e1..7fbcc566 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -1,24 +1,20 @@
-import base64
-import json
-import os
import re
-from pathlib import Path
-from typing import AsyncGenerator, List, Tuple, Union
+from typing import List, Union, AsyncGenerator, Tuple
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from kokoro import KPipeline
from loguru import logger
-from ..core.config import settings
from ..inference.base import AudioChunk
+from ..core.config import settings
from ..services.audio import AudioNormalizer, AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
-from ..services.temp_manager import TempFileWriter
from ..services.text_processing import smart_split
from ..services.tts_service import TTSService
+from ..services.temp_manager import TempFileWriter
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.custom_responses import JSONStreamingResponse
from ..structures.text_schemas import (
@@ -26,7 +22,12 @@
PhonemeRequest,
PhonemeResponse,
)
-from .openai_compatible import process_and_validate_voices, stream_audio_chunks
+from .openai_compatible import process_voices, stream_audio_chunks
+import json
+import os
+import base64
+from pathlib import Path
+
router = APIRouter(tags=["text processing"])
@@ -104,7 +105,7 @@ async def generate_chunks():
if chunk_audio is not None:
# Normalize audio before writing
- normalized_audio = normalizer.normalize(chunk_audio)
+ normalized_audio = await normalizer.normalize(chunk_audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
@@ -114,14 +115,13 @@ async def generate_chunks():
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
- writer.close()
else:
raise ValueError("Failed to generate audio data")
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
# Clean up writer on error
- writer.close()
+ writer.write_chunk(finalize=True)
# Re-raise the original exception
raise
@@ -157,7 +157,6 @@ async def generate_chunks():
},
)
-
@router.post("/dev/captioned_speech")
async def create_captioned_speech(
request: CaptionedSpeechRequest,
@@ -170,7 +169,7 @@ async def create_captioned_speech(
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
- voice_name = await process_and_validate_voices(request.voice, tts_service)
+ voice_name = await process_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@@ -182,13 +181,10 @@ async def create_captioned_speech(
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
- writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
- generator = stream_audio_chunks(
- tts_service, request, client_request, writer
- )
+ generator = stream_audio_chunks(tts_service, request, client_request)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@@ -215,35 +211,21 @@ async def dual_output():
# Write chunks to temp file and stream
async for chunk_data in generator:
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
- timestamp_acumulator = []
-
+ timestamp_acumulator=[]
+
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
- base64_chunk = base64.b64encode(
- chunk_data.output
- ).decode("utf-8")
-
+ base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
+
# Add any chunks that may be in the acumulator into the return word_timestamps
- if chunk_data.word_timestamps is not None:
- chunk_data.word_timestamps = (
- timestamp_acumulator + chunk_data.word_timestamps
- )
- timestamp_acumulator = []
- else:
- chunk_data.word_timestamps = []
-
- yield CaptionedSpeechResponse(
- audio=base64_chunk,
- audio_format=content_type,
- timestamps=chunk_data.word_timestamps,
- )
+ chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
+ timestamp_acumulator=[]
+
+ yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
else:
- if (
- chunk_data.word_timestamps is not None
- and len(chunk_data.word_timestamps) > 0
- ):
- timestamp_acumulator += chunk_data.word_timestamps
-
+ if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
+ timestamp_acumulator+=chunk_data.word_timestamps
+
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
@@ -254,7 +236,6 @@ async def dual_output():
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
- writer.close()
# Stream with temp file writing
return JSONStreamingResponse(
@@ -264,40 +245,25 @@ async def dual_output():
async def single_output():
try:
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
- timestamp_acumulator = []
-
+ timestamp_acumulator=[]
+
# Stream chunks
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
# Encode the chunk bytes into base 64
- base64_chunk = base64.b64encode(chunk_data.output).decode(
- "utf-8"
- )
-
+ base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
+
# Add any chunks that may be in the acumulator into the return word_timestamps
- if chunk_data.word_timestamps is not None:
- chunk_data.word_timestamps = (
- timestamp_acumulator + chunk_data.word_timestamps
- )
- else:
- chunk_data.word_timestamps = []
- timestamp_acumulator = []
-
- yield CaptionedSpeechResponse(
- audio=base64_chunk,
- audio_format=content_type,
- timestamps=chunk_data.word_timestamps,
- )
+ chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
+ timestamp_acumulator=[]
+
+ yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
else:
- if (
- chunk_data.word_timestamps is not None
- and len(chunk_data.word_timestamps) > 0
- ):
- timestamp_acumulator += chunk_data.word_timestamps
-
+ if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
+ timestamp_acumulator+=chunk_data.word_timestamps
+
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
- writer.close()
raise
# Standard streaming without download link
@@ -316,41 +282,34 @@ async def single_output():
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
- writer=writer,
speed=request.speed,
return_timestamps=request.return_timestamps,
- volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
-
+
audio_data = await AudioService.convert_audio(
audio_data,
+ 24000,
request.response_format,
- writer,
+ is_first_chunk=True,
is_last_chunk=False,
trim_audio=False,
)
-
+
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
+ 24000,
request.response_format,
- writer,
+ is_first_chunk=False,
is_last_chunk=True,
)
- output = audio_data.output + final.output
-
- base64_output = base64.b64encode(output).decode("utf-8")
-
- content = CaptionedSpeechResponse(
- audio=base64_output,
- audio_format=content_type,
- timestamps=audio_data.word_timestamps,
- ).model_dump()
-
- writer.close()
-
+ output=audio_data.output + final.output
+
+ base64_output= base64.b64encode(output).decode("utf-8")
+
+ content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
return JSONResponse(
content=content,
media_type="application/json",
@@ -363,12 +322,6 @@ async def single_output():
except ValueError as e:
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=400,
detail={
@@ -380,12 +333,6 @@ async def single_output():
except RuntimeError as e:
# Handle runtime/processing errors
logger.error(f"Processing error: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=500,
detail={
@@ -397,12 +344,6 @@ async def single_output():
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=500,
detail={
@@ -410,4 +351,4 @@ async def single_output():
"message": str(e),
"type": "server_error",
},
- )
+ )
\ No newline at end of file
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index c3252217..c4036d67 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -5,23 +5,23 @@
import os
import re
import tempfile
-from typing import AsyncGenerator, Dict, List, Tuple, Union
+from typing import AsyncGenerator, Dict, List, Union, Tuple
from urllib import response
+import numpy as np
import aiofiles
-import numpy as np
+
+from ..structures.schemas import CaptionedSpeechRequest
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
-from ..core.config import settings
from ..inference.base import AudioChunk
+from ..core.config import settings
from ..services.audio import AudioService
-from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest
-from ..structures.schemas import CaptionedSpeechRequest
# Load OpenAI mappings
@@ -80,7 +80,7 @@ def get_model_name(model: str) -> str:
return base_name + ".pth"
-async def process_and_validate_voices(
+async def process_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input, handling both string and list formats
@@ -88,74 +88,72 @@ async def process_and_validate_voices(
Returns:
Voice name to use (with weights if specified)
"""
- voices = []
# Convert input to list of voices
if isinstance(voice_input, str):
- voice_input = voice_input.replace(" ", "").strip()
-
- if voice_input[-1] in "+-" or voice_input[0] in "+-":
- raise ValueError(f"Voice combination contains empty combine items")
-
- if re.search(r"[+-]{2,}", voice_input) is not None:
- raise ValueError(f"Voice combination contains empty combine items")
- voices = re.split(r"([-+])", voice_input)
+ # Check if it's an OpenAI voice name
+ mapped_voice = _openai_mappings["voices"].get(voice_input)
+ if mapped_voice:
+ voice_input = mapped_voice
+ # Split on + but preserve any parentheses
+ voices = []
+ for part in voice_input.split("+"):
+ part = part.strip()
+ if not part:
+ continue
+ # Extract voice name without weight
+ voice_name = part.split("(")[0].strip()
+ # Check if it's a valid voice
+ available_voices = await tts_service.list_voices()
+ if voice_name not in available_voices:
+ raise ValueError(
+ f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
+ )
+ voices.append(part)
else:
- voices = [[item, "+"] for item in voice_input][:-1]
-
- available_voices = await tts_service.list_voices()
-
- for voice_index in range(0, len(voices), 2):
- mapped_voice = voices[voice_index].split("(")
- mapped_voice = list(map(str.strip, mapped_voice))
-
- if len(mapped_voice) > 2:
- raise ValueError(
- f"Voice '{voices[voice_index]}' contains too many weight items"
- )
-
- if mapped_voice.count(")") > 1:
- raise ValueError(
- f"Voice '{voices[voice_index]}' contains too many weight items"
- )
-
- mapped_voice[0] = _openai_mappings["voices"].get(
- mapped_voice[0], mapped_voice[0]
- )
-
- if mapped_voice[0] not in available_voices:
- raise ValueError(
- f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
- )
+ # For list input, map each voice if it's an OpenAI voice name
+ voices = []
+ for v in voice_input:
+ mapped = _openai_mappings["voices"].get(v, v)
+ voice_name = mapped.split("(")[0].strip()
+ # Check if it's a valid voice
+ available_voices = await tts_service.list_voices()
+ if voice_name not in available_voices:
+ raise ValueError(
+ f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
+ )
+ voices.append(mapped)
- voices[voice_index] = "(".join(mapped_voice)
+ if not voices:
+ raise ValueError("No voices provided")
- return "".join(voices)
+ # For multiple voices, combine them with +
+ return "+".join(voices)
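+    # Example (assumed weight syntax): "af_heart+af_sky(0.5)" passes through with its
+    # weight suffix intact, while OpenAI names like "alloy" are mapped via _openai_mappings.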
async def stream_audio_chunks(
- tts_service: TTSService,
- request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
- client_request: Request,
- writer: StreamingAudioWriter,
+ tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
- voice_name = await process_and_validate_voices(request.voice, tts_service)
- unique_properties = {"return_timestamps": False}
- if hasattr(request, "return_timestamps"):
- unique_properties["return_timestamps"] = request.return_timestamps
+ voice_name = await process_voices(request.voice, tts_service)
+ unique_properties={
+ "return_timestamps":False
+ }
+ if hasattr(request, "return_timestamps"):
+ unique_properties["return_timestamps"]=request.return_timestamps
+
try:
+ logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk_data in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
- writer=writer,
speed=request.speed,
output_format=request.response_format,
- lang_code=request.lang_code,
- volume_multiplier=request.volume_multiplier,
+ lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"],
):
+
# Check if client is still connected
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
@@ -173,6 +171,7 @@ async def stream_audio_chunks(
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
x_raw_response: str = Header(None, alias="x-raw-response"),
@@ -192,7 +191,7 @@ async def create_speech(
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
- voice_name = await process_and_validate_voices(request.voice, tts_service)
+ voice_name = await process_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@@ -204,14 +203,10 @@ async def create_speech(
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
- writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
-
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
- generator = stream_audio_chunks(
- tts_service, request, client_request, writer
- )
+ generator = stream_audio_chunks(tts_service, request, client_request)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@@ -234,10 +229,6 @@ async def create_speech(
"X-Download-Path": download_path,
}
- # Add header to indicate if temp file writing is available
- if temp_writer._write_error:
- headers["X-Download-Status"] = "unavailable"
-
# Create async generator for streaming
async def dual_output():
try:
@@ -245,9 +236,9 @@ async def dual_output():
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
- # if return_json:
+ #if return_json:
# yield chunk, chunk_data
- # else:
+ #else:
yield chunk_data.output
# Finalize the temp file
@@ -260,7 +251,6 @@ async def dual_output():
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
- writer.close()
# Stream with temp file writing
return StreamingResponse(
@@ -275,9 +265,8 @@ async def single_output():
yield chunk_data.output
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
- writer.close()
raise
-
+
# Standard streaming without download link
return StreamingResponse(
single_output(),
@@ -290,83 +279,45 @@ async def single_output():
},
)
else:
- headers = {
- "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
- "Cache-Control": "no-cache", # Prevent caching
- }
-
# Generate complete audio using public interface
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
- writer=writer,
speed=request.speed,
- volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(
audio_data,
+ 24000,
request.response_format,
- writer,
+ is_first_chunk=True,
is_last_chunk=False,
- trim_audio=False,
+ trim_audio=False
)
-
+
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
+ 24000,
request.response_format,
- writer,
+ is_first_chunk=False,
is_last_chunk=True,
)
- output = audio_data.output + final.output
-
- if request.return_download_link:
- from ..services.temp_manager import TempFileWriter
-
- # Use download_format if specified, otherwise use response_format
- output_format = request.download_format or request.response_format
- temp_writer = TempFileWriter(output_format)
- await temp_writer.__aenter__() # Initialize temp file
-
- # Get download path immediately after temp file creation
- download_path = temp_writer.download_path
- headers["X-Download-Path"] = download_path
-
- try:
- # Write chunks to temp file
- logger.info("Writing chunks to tempory file for download")
- await temp_writer.write(output)
- # Finalize the temp file
- await temp_writer.finalize()
-
- except Exception as e:
- logger.error(f"Error in dual output: {e}")
- await temp_writer.__aexit__(type(e), e, e.__traceback__)
- raise
- finally:
- # Ensure temp writer is closed
- if not temp_writer._finalized:
- await temp_writer.__aexit__(None, None, None)
- writer.close()
-
+ output=audio_data.output + final.output
return Response(
content=output,
media_type=content_type,
- headers=headers,
+ headers={
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+ "Cache-Control": "no-cache", # Prevent caching
+ },
)
except ValueError as e:
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=400,
detail={
@@ -378,12 +329,6 @@ async def single_output():
except RuntimeError as e:
# Handle runtime/processing errors
logger.error(f"Processing error: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=500,
detail={
@@ -395,12 +340,6 @@ async def single_output():
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error in speech generation: {str(e)}")
-
- try:
- writer.close()
- except:
- pass
-
raise HTTPException(
status_code=500,
detail={
@@ -457,23 +396,26 @@ async def list_models():
"id": "tts-1",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
+ "owned_by": "kokoro"
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
+ "owned_by": "kokoro"
},
{
"id": "kokoro",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
- },
+ "owned_by": "kokoro"
+ }
]
-
- return {"object": "list", "data": models}
+
+ return {
+ "object": "list",
+ "data": models
+ }
except Exception as e:
logger.error(f"Error listing models: {str(e)}")
raise HTTPException(
@@ -485,7 +427,6 @@ async def list_models():
},
)
-
@router.get("/models/{model}")
async def retrieve_model(model: str):
"""Retrieve a specific model"""
@@ -496,22 +437,22 @@ async def retrieve_model(model: str):
"id": "tts-1",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
+ "owned_by": "kokoro"
},
"tts-1-hd": {
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
+ "owned_by": "kokoro"
},
"kokoro": {
"id": "kokoro",
"object": "model",
"created": 1686935002,
- "owned_by": "kokoro",
- },
+ "owned_by": "kokoro"
+ }
}
-
+
# Check if requested model exists
if model not in models:
raise HTTPException(
@@ -519,10 +460,10 @@ async def retrieve_model(model: str):
detail={
"error": "model_not_found",
"message": f"Model '{model}' not found",
- "type": "invalid_request_error",
- },
+ "type": "invalid_request_error"
+ }
)
-
+
# Return the specific model
return models[model]
except HTTPException:
@@ -538,7 +479,6 @@ async def retrieve_model(model: str):
},
)
-
@router.get("/audio/voices")
async def list_voices():
"""List all available voices for text-to-speech"""
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index 6ae6d791..0c752247 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -1,12 +1,12 @@
"""Audio conversion service"""
-import math
import struct
import time
-from io import BytesIO
from typing import Tuple
+from io import BytesIO
import numpy as np
+import math
import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger
@@ -14,9 +14,8 @@
from torch import norm
from ..core.config import settings
-from ..inference.base import AudioChunk
from .streaming_audio_writer import StreamingAudioWriter
-
+from ..inference.base import AudioChunk
class AudioNormalizer:
"""Handles audio normalization state for a single stream"""
@@ -25,78 +24,53 @@ def __init__(self):
self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
- self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
-
- def find_first_last_non_silent(
- self,
- audio_data: np.ndarray,
- chunk_text: str,
- speed: float,
- silence_threshold_db: int = -45,
- is_last_chunk: bool = False,
- ) -> tuple[int, int]:
- """Finds the indices of the first and last non-silent samples in audio data.
+ self.samples_to_pad_start= int(50 * self.sample_rate / 1000)
+ def find_first_last_non_silent(self,audio_data: np.ndarray, chunk_text: str, speed: float, silence_threshold_db: int = -45, is_last_chunk: bool = False) -> tuple[int, int]:
+ """Finds the indices of the first and last non-silent samples in audio data.
+
Args:
audio_data: Input audio data as numpy array
chunk_text: The text sent to the model to generate the resulting speech
speed: The speaking speed of the voice
silence_threshold_db: How quiet audio has to be to be conssidered silent
is_last_chunk: Whether this is the last chunk
-
+
Returns:
A tuple with the start of the non silent portion and with the end of the non silent portion
"""
- pad_multiplier = 1
- split_character = chunk_text.strip()
+ pad_multiplier=1
+ split_character=chunk_text.strip()
if len(split_character) > 0:
- split_character = split_character[-1]
+ split_character=split_character[-1]
if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
- pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
- split_character
- ]
+ pad_multiplier=settings.dynamic_gap_trim_padding_char_multiplier[split_character]
if not is_last_chunk:
- samples_to_pad_end = max(
- int(
- (
- settings.dynamic_gap_trim_padding_ms
- * self.sample_rate
- * pad_multiplier
- )
- / 1000
- )
- - self.samples_to_pad_start,
- 0,
- )
+ samples_to_pad_end= max(int((settings.dynamic_gap_trim_padding_ms * self.sample_rate * pad_multiplier) / 1000) - self.samples_to_pad_start, 0)
else:
- samples_to_pad_end = self.samples_to_pad_start
+ samples_to_pad_end=self.samples_to_pad_start
# Convert dBFS threshold to amplitude
- amplitude_threshold = np.iinfo(audio_data.dtype).max * (
- 10 ** (silence_threshold_db / 20)
- )
+ amplitude_threshold = np.iinfo(audio_data.dtype).max * (10 ** (silence_threshold_db / 20))
# Find the first samples above the silence threshold at the start and end of the audio
- non_silent_index_start, non_silent_index_end = None, None
+ non_silent_index_start, non_silent_index_end = None,None
- for X in range(0, len(audio_data)):
- if abs(audio_data[X]) > amplitude_threshold:
- non_silent_index_start = X
+ for X in range(0,len(audio_data)):
+ if abs(audio_data[X]) > amplitude_threshold:
+ non_silent_index_start=X
break
-
+
for X in range(len(audio_data) - 1, -1, -1):
- if abs(audio_data[X]) > amplitude_threshold:
- non_silent_index_end = X
+ if abs(audio_data[X]) > amplitude_threshold:
+ non_silent_index_end=X
break
# Handle the case where the entire audio is silent
if non_silent_index_start == None or non_silent_index_end == None:
return 0, len(audio_data)
- return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
- non_silent_index_end + math.ceil(samples_to_pad_end / speed),
- len(audio_data),
- )
+ return max(non_silent_index_start - self.samples_to_pad_start,0), min(non_silent_index_end + math.ceil(samples_to_pad_end / speed),len(audio_data))
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
"""Convert audio data to int16 range
@@ -111,7 +85,6 @@ def normalize(self, audio_data: np.ndarray) -> np.ndarray:
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
return audio_data
-
class AudioService:
"""Service for audio format conversions with streaming support"""
@@ -135,27 +108,30 @@ class AudioService:
},
}
+ _writers = {}
+
@staticmethod
async def convert_audio(
audio_chunk: AudioChunk,
+ sample_rate: int,
output_format: str,
- writer: StreamingAudioWriter,
speed: float = 1,
chunk_text: str = "",
+ is_first_chunk: bool = True,
is_last_chunk: bool = False,
trim_audio: bool = True,
normalizer: AudioNormalizer = None,
- ) -> AudioChunk:
+ ) -> AudioChunk:
"""Convert audio data to specified format with streaming support
Args:
audio_data: Numpy array of audio samples
+ sample_rate: Sample rate of the audio
output_format: Target format (wav, mp3, ogg, pcm)
- writer: The StreamingAudioWriter to use
speed: The speaking speed of the voice
chunk_text: The text sent to the model to generate the resulting speech
+ is_first_chunk: Whether this is the first chunk
is_last_chunk: Whether this is the last chunk
- trim_audio: Whether audio should be trimmed
normalizer: Optional AudioNormalizer instance for consistent normalization
Returns:
@@ -170,14 +146,21 @@ async def convert_audio(
# Always normalize audio to ensure proper amplitude scaling
if normalizer is None:
normalizer = AudioNormalizer()
-
+
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
-
+
if trim_audio == True:
- audio_chunk = AudioService.trim_audio(
- audio_chunk, chunk_text, speed, is_last_chunk, normalizer
+ audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
+
+ # Get or create format-specific writer
+ writer_key = f"{output_format}_{sample_rate}"
+ if is_first_chunk or writer_key not in AudioService._writers:
+ AudioService._writers[writer_key] = StreamingAudioWriter(
+ output_format, sample_rate
)
-
+
+ writer = AudioService._writers[writer_key]
+
# Write audio data first
if len(audio_chunk.audio) > 0:
chunk_data = writer.write_chunk(audio_chunk.audio)
@@ -185,13 +168,13 @@ async def convert_audio(
# Then finalize if this is the last chunk
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
-
+ del AudioService._writers[writer_key]
if final_data:
- audio_chunk.output = final_data
+ audio_chunk.output=final_data
return audio_chunk
-
+
if chunk_data:
- audio_chunk.output = chunk_data
+ audio_chunk.output=chunk_data
return audio_chunk
except Exception as e:
@@ -199,15 +182,8 @@ async def convert_audio(
raise ValueError(
f"Failed to convert audio stream to {output_format}: {str(e)}"
)
-
@staticmethod
- def trim_audio(
- audio_chunk: AudioChunk,
- chunk_text: str = "",
- speed: float = 1,
- is_last_chunk: bool = False,
- normalizer: AudioNormalizer = None,
- ) -> AudioChunk:
+ def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
"""Trim silence from start and end
Args:
@@ -216,33 +192,30 @@ def trim_audio(
speed: The speaking speed of the voice
is_last_chunk: Whether this is the last chunk
normalizer: Optional AudioNormalizer instance for consistent normalization
-
+
Returns:
Trimmed audio data
"""
if normalizer is None:
normalizer = AudioNormalizer()
-
- audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
-
- trimed_samples = 0
+
+ audio_chunk.audio=normalizer.normalize(audio_chunk.audio)
+
+ trimmed_samples=0
# Trim start and end if enough samples
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
- audio_chunk.audio = audio_chunk.audio[
- normalizer.samples_to_trim : -normalizer.samples_to_trim
- ]
- trimed_samples += normalizer.samples_to_trim
-
- # Find non silent portion and trim
- start_index, end_index = normalizer.find_first_last_non_silent(
- audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
- )
-
- audio_chunk.audio = audio_chunk.audio[start_index:end_index]
- trimed_samples += start_index
-
+ audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+ trimmed_samples+=normalizer.samples_to_trim
+
+ # Find non silent portion and trim
+ start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
+
+ audio_chunk.audio=audio_chunk.audio[start_index:end_index]
+ trimmed_samples+=start_index
+
if audio_chunk.word_timestamps is not None:
for timestamp in audio_chunk.word_timestamps:
- timestamp.start_time -= trimed_samples / 24000
- timestamp.end_time -= trimed_samples / 24000
+ timestamp.start_time-=trimmed_samples / 24000
+ timestamp.end_time-=trimmed_samples / 24000
return audio_chunk
+
\ No newline at end of file
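
A self-contained sketch of the silence scan that find_first_last_non_silent performs, assuming int16 audio at the service's 24 kHz sample rate (the padding logic is omitted; the -45 dBFS default matches the code above):

    import numpy as np

    def first_last_above_threshold(audio: np.ndarray, silence_threshold_db: int = -45):
        # Convert the dBFS threshold to a linear amplitude for this dtype
        amplitude_threshold = np.iinfo(audio.dtype).max * (10 ** (silence_threshold_db / 20))
        above = np.nonzero(np.abs(audio) > amplitude_threshold)[0]
        if len(above) == 0:
            return 0, len(audio)  # entirely silent: keep everything, as the service does
        return int(above[0]), int(above[-1])

    # 0.1 s of silence, 1 s of a 440 Hz tone, 0.1 s of silence
    sr = 24000
    tone = (0.5 * 32767 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)).astype(np.int16)
    audio = np.concatenate([np.zeros(2400, np.int16), tone, np.zeros(2400, np.int16)])
    print(first_last_above_threshold(audio))  # roughly (2401, 26399)
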
diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py
index de9c84e3..71dcd32d 100644
--- a/api/src/services/streaming_audio_writer.py
+++ b/api/src/services/streaming_audio_writer.py
@@ -4,12 +4,11 @@
from io import BytesIO
from typing import Optional
-import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
-
+import av
class StreamingAudioWriter:
"""Handles streaming audio format conversions"""
@@ -19,49 +18,18 @@ def __init__(self, format: str, sample_rate: int, channels: int = 1):
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
- self.pts = 0
+ self.pts=0
- codec_map = {
- "wav": "pcm_s16le",
- "mp3": "mp3",
- "opus": "libopus",
- "flac": "flac",
- "aac": "aac",
- }
+ codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
# Format-specific setup
- if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
+ if self.format in ["wav", "opus","flac","mp3","aac","pcm"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
- container_options = {}
- # Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
- if self.format == 'mp3':
- # Disable Xing VBR header
- container_options = {'write_xing': '0'}
- logger.debug("Disabling Xing VBR header for MP3 encoding.")
-
- self.container = av.open(
- self.output_buffer,
- mode="w",
- format=self.format if self.format != "aac" else "adts",
- options=container_options # Pass options here
- )
- self.stream = self.container.add_stream(
- codec_map[self.format],
- rate=self.sample_rate,
- layout="mono" if self.channels == 1 else "stereo",
- )
- # Set bit_rate only for codecs where it's applicable and useful
- if self.format in ['mp3', 'aac', 'opus']:
- self.stream.bit_rate = 128000
+ self.container = av.open(self.output_buffer, mode="w", format=self.format)
+ self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
+ self.stream.bit_rate = 128000
else:
- raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
-
- def close(self):
- if hasattr(self, "container"):
- self.container.close()
-
- if hasattr(self, "output_buffer"):
- self.output_buffer.close()
+ raise ValueError(f"Unsupported format: {format}")
def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
@@ -75,18 +43,12 @@ def write_chunk(
if finalize:
if self.format != "pcm":
- # Flush stream encoder
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
-
- # Closing the container handles writing the trailer and finalizing the file.
- # No explicit flush method is available or needed here.
- logger.debug("Muxed final packets.")
-
- # Get the final bytes from the buffer *before* closing it
- data = self.output_buffer.getvalue()
- self.close() # Close container and buffer
+
+ data=self.output_buffer.getvalue()
+ self.container.close()
return data
if audio_data is None or len(audio_data) == 0:
@@ -96,21 +58,19 @@ def write_chunk(
# Write raw bytes
return audio_data.tobytes()
else:
- frame = av.AudioFrame.from_ndarray(
- audio_data.reshape(1, -1),
- format="s16",
- layout="mono" if self.channels == 1 else "stereo",
- )
- frame.sample_rate = self.sample_rate
+ frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
+ frame.sample_rate=self.sample_rate
+
frame.pts = self.pts
self.pts += frame.samples
-
+
packets = self.stream.encode(frame)
for packet in packets:
self.container.mux(packet)
-
+
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
- return data
+ return data
+
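
For clarity, a standalone sketch of the in-memory PyAV encode/flush cycle this writer implements (MP3, 24 kHz mono, int16 input; the constants are illustrative). It mirrors the calls above: encode each frame, mux the returned packets, then flush the encoder with encode(None), reading the buffer before closing as the code above does:

    from io import BytesIO

    import av
    import numpy as np

    buffer = BytesIO()
    container = av.open(buffer, mode="w", format="mp3")
    stream = container.add_stream("mp3", rate=24000, layout="mono")

    pts = 0
    for _ in range(3):  # three 0.5 s chunks of int16 silence
        chunk = np.zeros(12000, dtype=np.int16)
        frame = av.AudioFrame.from_ndarray(chunk.reshape(1, -1), format="s16", layout="mono")
        frame.sample_rate = 24000
        frame.pts = pts
        pts += frame.samples
        for packet in stream.encode(frame):  # the encoder may buffer, so this can be empty
            container.mux(packet)

    for packet in stream.encode(None):  # flush any buffered packets
        container.mux(packet)
    mp3_bytes = buffer.getvalue()  # read before close, matching the writer above
    container.close()
    print(len(mp3_bytes), "bytes of MP3")
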
diff --git a/api/src/services/temp_manager.py b/api/src/services/temp_manager.py
index 4d92a9e1..98d49888 100644
--- a/api/src/services/temp_manager.py
+++ b/api/src/services/temp_manager.py
@@ -81,36 +81,26 @@ def __init__(self, format: str):
self.format = format
self.temp_file = None
self._finalized = False
- self._write_error = False # Flag to track if we've had a write error
async def __aenter__(self):
"""Async context manager entry"""
- try:
- # Clean up old files first
- await cleanup_temp_files()
-
- # Create temp file with proper extension
- await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
- temp = tempfile.NamedTemporaryFile(
- dir=settings.temp_file_dir,
- delete=False,
- suffix=f".{self.format}",
- mode="wb",
- )
- self.temp_file = await aiofiles.open(temp.name, mode="wb")
- self.temp_path = temp.name
- temp.close() # Close sync file, we'll use async version
-
- # Generate download path immediately
- self.download_path = f"/download/{os.path.basename(self.temp_path)}"
- except Exception as e:
- # Handle permission issues or other errors gracefully
- logger.error(f"Failed to create temp file: {e}")
- self._write_error = True
- # Set a placeholder path so the API can still function
- self.temp_path = f"unavailable_{self.format}"
- self.download_path = f"/download/{self.temp_path}"
-
+ # Clean up old files first
+ await cleanup_temp_files()
+
+ # Create temp file with proper extension
+ await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
+ temp = tempfile.NamedTemporaryFile(
+ dir=settings.temp_file_dir,
+ delete=False,
+ suffix=f".{self.format}",
+ mode="wb",
+ )
+ self.temp_file = await aiofiles.open(temp.name, mode="wb")
+ self.temp_path = temp.name
+ temp.close() # Close sync file, we'll use async version
+
+ # Generate download path immediately
+ self.download_path = f"/download/{os.path.basename(self.temp_path)}"
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -121,7 +111,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
self._finalized = True
except Exception as e:
logger.error(f"Error closing temp file: {e}")
- self._write_error = True
async def write(self, chunk: bytes) -> None:
"""Write a chunk of audio data
@@ -132,17 +121,8 @@ async def write(self, chunk: bytes) -> None:
if self._finalized:
raise RuntimeError("Cannot write to finalized temp file")
- # Skip writing if we've already encountered an error
- if self._write_error or not self.temp_file:
- return
-
- try:
- await self.temp_file.write(chunk)
- await self.temp_file.flush()
- except Exception as e:
- # Handle permission issues or other errors gracefully
- logger.error(f"Failed to write to temp file: {e}")
- self._write_error = True
+ await self.temp_file.write(chunk)
+ await self.temp_file.flush()
async def finalize(self) -> str:
"""Close temp file and return download path
@@ -153,18 +133,7 @@ async def finalize(self) -> str:
if self._finalized:
raise RuntimeError("Temp file already finalized")
- # Skip finalizing if we've already encountered an error
- if self._write_error or not self.temp_file:
- self._finalized = True
- return self.download_path
-
- try:
- await self.temp_file.close()
- self._finalized = True
- except Exception as e:
- # Handle permission issues or other errors gracefully
- logger.error(f"Failed to finalize temp file: {e}")
- self._write_error = True
- self._finalized = True
+ await self.temp_file.close()
+ self._finalized = True
- return self.download_path
+ return f"/download/{os.path.basename(self.temp_path)}"
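
A hypothetical usage sketch of the simplified writer above. The class name TempFileWriter and the import path are assumed from the repo layout; they are not shown in this hunk:

    import asyncio

    from api.src.services.temp_manager import TempFileWriter  # assumed path

    async def save_stream(chunks):
        async with TempFileWriter("mp3") as writer:
            for chunk in chunks:            # chunks: iterable of bytes
                await writer.write(chunk)   # each write is flushed to disk
            return await writer.finalize()  # e.g. "/download/tmpabc123.mp3"

    # print(asyncio.run(save_stream([b"\x00" * 1024])))
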
diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py
index 1b1c9f7e..84c3694d 100644
--- a/api/src/services/text_processing/normalizer.py
+++ b/api/src/services/text_processing/normalizer.py
@@ -4,15 +4,10 @@
Converts them into a format suitable for text-to-speech processing.
"""
-import math
import re
from functools import lru_cache
-from typing import List, Optional, Union
-
import inflect
from numpy import number
-# from text_to_num import text2num
-from torch import mul
from ...structures.schemas import NormalizationOptions
@@ -57,101 +52,28 @@
"uk",
"us",
"io",
- "co",
]
VALID_UNITS = {
- "m": "meter",
- "cm": "centimeter",
- "mm": "millimeter",
- "km": "kilometer",
- "in": "inch",
- "ft": "foot",
- "yd": "yard",
- "mi": "mile", # Length
- "g": "gram",
- "kg": "kilogram",
- "mg": "milligram", # Mass
- "s": "second",
- "ms": "millisecond",
- "min": "minutes",
- "h": "hour", # Time
- "l": "liter",
- "ml": "mililiter",
- "cl": "centiliter",
- "dl": "deciliter", # Volume
- "kph": "kilometer per hour",
- "mph": "mile per hour",
- "mi/h": "mile per hour",
- "m/s": "meter per second",
- "km/h": "kilometer per hour",
- "mm/s": "milimeter per second",
- "cm/s": "centimeter per second",
- "ft/s": "feet per second",
- "cm/h": "centimeter per day", # Speed
- "°c": "degree celsius",
- "c": "degree celsius",
- "°f": "degree fahrenheit",
- "f": "degree fahrenheit",
- "k": "kelvin", # Temperature
- "pa": "pascal",
- "kpa": "kilopascal",
- "mpa": "megapascal",
- "atm": "atmosphere", # Pressure
- "hz": "hertz",
- "khz": "kilohertz",
- "mhz": "megahertz",
- "ghz": "gigahertz", # Frequency
- "v": "volt",
- "kv": "kilovolt",
- "mv": "mergavolt", # Voltage
- "a": "amp",
- "ma": "megaamp",
- "ka": "kiloamp", # Current
- "w": "watt",
- "kw": "kilowatt",
- "mw": "megawatt", # Power
- "j": "joule",
- "kj": "kilojoule",
- "mj": "megajoule", # Energy
- "Ω": "ohm",
- "kΩ": "kiloohm",
- "mΩ": "megaohm", # Resistance (Ohm)
- "f": "farad",
- "µf": "microfarad",
- "nf": "nanofarad",
- "pf": "picofarad", # Capacitance
- "b": "bit",
- "kb": "kilobit",
- "mb": "megabit",
- "gb": "gigabit",
- "tb": "terabit",
- "pb": "petabit", # Data size
- "kbps": "kilobit per second",
- "mbps": "megabit per second",
- "gbps": "gigabit per second",
- "tbps": "terabit per second",
- "px": "pixel", # CSS units
-}
-
-SYMBOL_REPLACEMENTS = {
- '~': ' ',
- '@': ' at ',
- '#': ' number ',
- '$': ' dollar ',
- '%': ' percent ',
- '^': ' ',
- '&': ' and ',
- '*': ' ',
- '_': ' ',
- '|': ' ',
- '\\': ' ',
- '/': ' slash ',
- '=': ' equals ',
- '+': ' plus ',
+ "m":"meter", "cm":"centimeter", "mm":"millimeter", "km":"kilometer", "in":"inch", "ft":"foot", "yd":"yard", "mi":"mile", # Length
+ "g":"gram", "kg":"kilogram", "mg":"miligram", # Mass
+ "s":"second", "ms":"milisecond", "min":"minutes", "h":"hour", # Time
+ "l":"liter", "ml":"mililiter", "cl":"centiliter", "dl":"deciliter", # Volume
+ "kph":"kilometer per hour", "mph":"mile per hour","mi/h":"mile per hour", "m/s":"meter per second", "km/h":"kilometer per hour", "mm/s":"milimeter per second","cm/s":"centimeter per second", "ft/s":"feet per second","cm/h":"centimeter per day", # Speed
+ "°c":"degree celsius","c":"degree celsius", "°f":"degree fahrenheit","f":"degree fahrenheit", "k":"kelvin", # Temperature
+ "pa":"pascal", "kpa":"kilopascal", "mpa":"megapascal", "atm":"atmosphere", # Pressure
+ "hz":"hertz", "khz":"kilohertz", "mhz":"megahertz", "ghz":"gigahertz", # Frequency
+ "v":"volt", "kv":"kilovolt", "mv":"mergavolt", # Voltage
+ "a":"amp", "ma":"megaamp", "ka":"kiloamp", # Current
+ "w":"watt", "kw":"kilowatt", "mw":"megawatt", # Power
+ "j":"joule", "kj":"kilojoule", "mj":"megajoule", # Energy
+ "Ω":"ohm", "kΩ":"kiloohm", "mΩ":"megaohm", # Resistance (Ohm)
+ "f":"farad", "µf":"microfarad", "nf":"nanofarad", "pf":"picofarad", # Capacitance
+ "b":"bit", "kb":"kilobit", "mb":"megabit", "gb":"gigabit", "tb":"terabit", "pb":"petabit", # Data size
+ "kbps":"kilobit per second","mbps":"megabit per second","gbps":"gigabit per second","tbps":"terabit per second",
+ "px":"pixel" # CSS units
}
-MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
# Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile(
@@ -164,133 +86,72 @@
re.IGNORECASE,
)
-UNIT_PATTERN = re.compile(
- r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()), reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",
- re.IGNORECASE,
-)
+def split_num(num: re.Match[str]) -> str:
+ """Handle number splitting for various formats"""
+ num = num.group()
+ if "." in num:
+ return num
+ elif ":" in num:
+ h, m = [int(n) for n in num.split(":")]
+ if m == 0:
+ return f"{h} o'clock"
+ elif m < 10:
+ return f"{h} oh {m}"
+ return f"{h} {m}"
+ year = int(num[:4])
+ if year < 1100 or year % 1000 < 10:
+ return num
+ left, right = num[:2], int(num[2:4])
+ s = "s" if num.endswith("s") else ""
+ if 100 <= year % 1000 <= 999:
+ if right == 0:
+ return f"{left} hundred{s}"
+ elif right < 10:
+ return f"{left} oh {right}{s}"
+ return f"{left} {right}{s}"
def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
- unit_string = u.group(6).strip()
- unit = unit_string
-
+ unit_string=u.group(6).strip()
+ unit=unit_string
+
if unit_string.lower() in VALID_UNITS:
- unit = VALID_UNITS[unit_string.lower()].split(" ")
-
+ unit=VALID_UNITS[unit_string.lower()].split(" ")
+
# Handles the B vs b case
if unit[0].endswith("bit"):
- b_case = unit_string[min(1, len(unit_string) - 1)]
+ b_case=unit_string[min(1,len(unit_string) - 1)]
if b_case == "B":
- unit[0] = unit[0][:-3] + "byte"
-
- number = u.group(1).strip()
- unit[0] = INFLECT_ENGINE.no(unit[0], number)
+ unit[0]=unit[0][:-3] + "byte"
+
+ number=u.group(1).strip()
+ unit[0]=INFLECT_ENGINE.no(unit[0],number)
return " ".join(unit)
-
-def conditional_int(number: float, threshold: float = 0.00001):
- if abs(round(number) - number) < threshold:
- return int(round(number))
- return number
-
-
-def translate_multiplier(multiplier: str) -> str:
- """Translate multiplier abrevations to words"""
-
- multiplier_translation = {
- "k": "thousand",
- "m": "million",
- "b": "billion",
- "t": "trillion",
- }
- if multiplier.lower() in multiplier_translation:
- return multiplier_translation[multiplier.lower()]
- return multiplier.strip()
-
-
-def split_four_digit(number: float):
- part1 = str(conditional_int(number))[:2]
- part2 = str(conditional_int(number))[2:]
- return f"{INFLECT_ENGINE.number_to_words(part1)} {INFLECT_ENGINE.number_to_words(part2)}"
-
-
-def handle_numbers(n: re.Match[str]) -> str:
- number = n.group(2)
-
- try:
- number = float(number)
- except:
- return n.group()
-
- if n.group(1) == "-":
- number *= -1
-
- multiplier = translate_multiplier(n.group(3))
-
- number = conditional_int(number)
- if multiplier != "":
- multiplier = f" {multiplier}"
- else:
- if (
- number % 1 == 0
- and len(str(number)) == 4
- and number > 1500
- and number % 1000 > 9
- ):
- return split_four_digit(number)
-
- return f"{INFLECT_ENGINE.number_to_words(number)}{multiplier}"
-
-
def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form"""
-
- bill, coin = MONEY_UNITS[m.group(2)]
-
- number = m.group(3)
-
- try:
- number = float(number)
- except:
- return m.group()
-
- if m.group(1) == "-":
- number *= -1
-
- multiplier = translate_multiplier(m.group(4))
-
- if multiplier != "":
- multiplier = f" {multiplier}"
-
- if number % 1 == 0 or multiplier != "":
- text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
- else:
- sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
-
- text_number = f"{INFLECT_ENGINE.number_to_words(int(math.floor(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
-
- return text_number
+ m = m.group()
+ bill = "dollar" if m[0] == "$" else "pound"
+ if m[-1].isalpha():
+ return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}s"
+ elif "." not in m:
+ s = "" if m[1:] == "1" else "s"
+ return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}{s}"
+ b, c = m[1:].split(".")
+ s = "" if b == "1" else "s"
+ c = int(c.ljust(2, "0"))
+ coins = (
+ f"cent{'' if c == 1 else 's'}"
+ if m[0] == "$"
+ else ("penny" if c == 1 else "pence")
+ )
+ return f"{INFLECT_ENGINE.number_to_words(b)} {bill}{s} and {INFLECT_ENGINE.number_to_words(c)} {coins}"
def handle_decimal(num: re.Match[str]) -> str:
@@ -356,59 +217,35 @@ def handle_url(u: re.Match[str]) -> str:
# Clean up extra spaces
return re.sub(r"\s+", " ", url).strip()
-
def handle_phone_number(p: re.Match[str]) -> str:
- p = list(p.groups())
-
- country_code = ""
+ p=list(p.groups())
+
+ country_code=""
if p[0] is not None:
- p[0] = p[0].replace("+", "")
+ p[0]=p[0].replace("+","")
country_code += INFLECT_ENGINE.number_to_words(p[0])
-
- area_code = INFLECT_ENGINE.number_to_words(
- p[2].replace("(", "").replace(")", ""), group=1, comma=""
- )
-
- telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")
-
- line_number = INFLECT_ENGINE.number_to_words(p[4], group=1, comma="")
-
- return ",".join([country_code, area_code, telephone_prefix, line_number])
-
+
+ area_code=INFLECT_ENGINE.number_to_words(p[2].replace("(","").replace(")",""),group=1,comma="")
+
+ telephone_prefix=INFLECT_ENGINE.number_to_words(p[3],group=1,comma="")
+
+ line_number=INFLECT_ENGINE.number_to_words(p[4],group=1,comma="")
+
+ return ",".join([country_code,area_code,telephone_prefix,line_number])
def handle_time(t: re.Match[str]) -> str:
- t = t.groups()
-
- time_parts = t[0].split(":")
-
- numbers = []
- numbers.append(INFLECT_ENGINE.number_to_words(time_parts[0].strip()))
-
- minute_number = INFLECT_ENGINE.number_to_words(time_parts[1].strip())
- if int(time_parts[1]) < 10:
- if int(time_parts[1]) != 0:
- numbers.append(f"oh {minute_number}")
- else:
- numbers.append(minute_number)
-
- half = ""
- if len(time_parts) > 2:
- seconds_number = INFLECT_ENGINE.number_to_words(time_parts[2].strip())
- second_word = INFLECT_ENGINE.plural("second", int(time_parts[2].strip()))
- numbers.append(f"and {seconds_number} {second_word}")
- else:
- if t[2] is not None:
- half = " " + t[2].strip()
- else:
- if int(time_parts[1]) == 0:
- numbers.append("o'clock")
-
- return " ".join(numbers) + half
-
+ t=t.groups()
+
+ numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")])
+
+ half=""
+ if t[2] is not None:
+ half=" " + t[2].strip()
+
+ return numbers + half
-def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
+def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
-
# Handle email addresses first if enabled
if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
@@ -419,21 +256,17 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
- text = UNIT_PATTERN.sub(handle_units, text)
-
+ text=UNIT_PATTERN.sub(handle_units,text)
+
# Replace optional pluralization
if normalization_options.optional_pluralization_normalization:
- text = re.sub(r"\(s\)", "s", text)
-
+ text = re.sub(r"\(s\)","s",text)
+
# Replace phone numbers:
if normalization_options.phone_normalization:
- text = re.sub(
- r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
- handle_phone_number,
- text,
- )
-
- # Replace quotes and brackets (additional cleanup)
+ text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
+
+ # Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
@@ -442,22 +275,14 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
for a, b in zip("、。!,:;?–", ",.!,:;?-"):
text = text.replace(a, b + " ")
- # Handle simple time in the format of HH:MM:SS (am/pm)
- text = TIME_PATTERN.sub(
- handle_time,
- text,
- )
+ # Handle simple time in the format of HH:MM:SS
+ text = TIME_PATTERN.sub(handle_time, text, )
# Clean up whitespace
text = re.sub(r"[^\S \n]", " ", text)
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
- # Handle special characters that might cause audio artifacts first
- # Replace newlines with spaces (or pauses if needed)
- text = text.replace('\n', ' ')
- text = text.replace('\r', ' ')
-
# Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
@@ -468,23 +293,21 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
# Handle common words
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
- # Handle numbers and money BEFORE replacing special characters
+ # Handle numbers and money
text = re.sub(r"(?<=\d),(?=\d)", "", text)
-
- text = MONEY_PATTERN.sub(
+
+ text = re.sub(
+ r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
handle_money,
text,
)
-
- text = NUMBER_PATTERN.sub(handle_numbers, text)
-
+
+ text = re.sub(
+ r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
+ split_num,
+ text,
+ )
text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
)
- text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
-
- text = re.sub(r"\s{2,}", " ", text)
+ text = re.sub( r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
- return text
+ return text.strip()
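
Illustrative expectations for the restored normalizer, assuming default NormalizationOptions (import paths are assumed from the repo layout; the exact wording comes from the inflect engine):

    from api.src.services.text_processing.normalizer import normalize_text  # assumed path
    from api.src.structures.schemas import NormalizationOptions             # assumed path

    opts = NormalizationOptions()

    # Units:  "5km"   -> roughly "5 kilometers" (UNIT_PATTERN + handle_units)
    # Money:  "$5.50" -> roughly "five dollars and fifty cents" (handle_money)
    # Times:  "10:30" -> roughly "ten thirty" (handle_time)
    for sample in ["The run is 5km long.", "It costs $5.50.", "Meet at 10:30."]:
        print(normalize_text(sample, opts))
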
diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py
index dabf3284..5a50d64e 100644
--- a/api/src/services/text_processing/phonemizer.py
+++ b/api/src/services/text_processing/phonemizer.py
@@ -4,7 +4,6 @@
import phonemizer
from .normalizer import normalize_text
from ...structures.schemas import NormalizationOptions
phonemizers = {}
@@ -76,7 +75,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
Phonemizer backend instance
"""
# Map language codes to espeak language codes
- lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}
+ lang_map = {"a": "en-us", "b": "en-gb"}
if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}")
@@ -84,24 +83,20 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
return EspeakBackend(lang_map[language])
-def phonemize(text: str, language: str = "a") -> str:
+def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
"""Convert text to phonemes
Args:
text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English)
+ normalize: Whether to normalize text before phonemization
Returns:
Phonemized text
"""
global phonemizers
-
- # Strip input text first to remove problematic leading/trailing spaces
- text = text.strip()
-
+ if normalize:
+ text = normalize_text(text, NormalizationOptions())
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
-
- result = phonemizers[language].phonemize(text)
- # Final strip to ensure no leading/trailing spaces in phonemes
- return result.strip()
+ return phonemizers[language].phonemize(text)
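
With the normalize flag restored, callers can let phonemize handle normalization itself or skip it when the text was already normalized upstream (as text_processor now does). A hedged usage sketch; the import path is assumed:

    from api.src.services.text_processing.phonemizer import phonemize  # assumed path

    print(phonemize("Dr. Smith", language="a"))                      # normalizes, then phonemizes
    print(phonemize("Doctor Smith", language="a", normalize=False))  # skips normalization
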
diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py
index 483618f9..0d8d36c6 100644
--- a/api/src/services/text_processing/text_processor.py
+++ b/api/src/services/text_processing/text_processor.py
@@ -2,23 +2,18 @@
import re
import time
-from typing import AsyncGenerator, Dict, List, Tuple, Optional
+from typing import AsyncGenerator, Dict, List, Tuple
from loguru import logger
from ...core.config import settings
-from ...structures.schemas import NormalizationOptions
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize
+from ...structures.schemas import NormalizationOptions
# Pre-compiled regex patterns for performance
-# Updated regex to be more strict and avoid matching isolated brackets
-# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
-CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
-# Pattern to find pause tags like [pause:0.5s]
-PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
-
+CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
def process_text_chunk(
text: str, language: str = "a", skip_phonemize: bool = False
@@ -35,12 +30,6 @@ def process_text_chunk(
"""
start_time = time.time()
- # Strip input text to remove any leading/trailing spaces that could cause artifacts
- text = text.strip()
-
- if not text:
- return []
-
if skip_phonemize:
# Input is already phonemes, just tokenize
t0 = time.time()
@@ -52,9 +41,9 @@ def process_text_chunk(
t1 = time.time()
t0 = time.time()
- phonemes = phonemize(text, language)
- # Strip phonemes result to ensure no extra spaces
- phonemes = phonemes.strip()
+ phonemes = phonemize(
+ text, language, normalize=False
+ ) # Already normalized
t1 = time.time()
t0 = time.time()
@@ -99,227 +88,180 @@ def process_text(text: str, language: str = "a") -> List[int]:
return process_text_chunk(text, language)
-def get_sentence_info(
- text: str, lang_code: str = "a"
-) -> List[Tuple[str, List[int], int]]:
- """Process all sentences and return info"""
- # Detect Chinese text
- is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
- if is_chinese:
- # Split using Chinese punctuation
- sentences = re.split(r"([,。!?;])+", text)
- else:
- sentences = re.split(r"([.!?;:])(?=\s|$)", text)
-
+def get_sentence_info(text: str, custom_phonemes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
+ """Process all sentences and return info."""
+ sentences = re.split(r"([.!?;:])(?=\s|$)", text)
+ phoneme_length, min_value = len(custom_phonemes_list), 0
+
results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
+ for replaced in range(min_value, phoneme_length):
+ current_id = f"|custom_phonemes_{replaced}|/>"
+ if current_id in sentence:
+ sentence = sentence.replace(current_id, custom_phonemes_list.pop(current_id))
+ min_value += 1
+
+
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+
if not sentence:
continue
+
full = sentence + punct
- # Strip the full sentence to remove any leading/trailing spaces before processing
- full = full.strip()
- if not full: # Skip if empty after stripping
- continue
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
- return results
+ return results
-def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
+def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> str:
latest_id = f"|custom_phonemes_{len(phenomes_list)}|/>"
phenomes_list[latest_id] = s.group(0).strip()
return latest_id
-
async def smart_split(
- text: str,
+ text: str,
max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a",
- normalization_options: NormalizationOptions = NormalizationOptions(),
-) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
- """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
-
- Yields:
- Tuple of (text_chunk, tokens, pause_duration_s).
- If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
- Otherwise, it's a text chunk containing the original text.
- """
+ normalization_options: NormalizationOptions = NormalizationOptions()
+) -> AsyncGenerator[Tuple[str, List[int]], None]:
+ """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
- # --- Step 1: Split by Pause Tags FIRST ---
- # This operates on the raw input text
- parts = PAUSE_TAG_PATTERN.split(text)
- logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
+ custom_phoneme_list = {}
- part_idx = 0
- while part_idx < len(parts):
- text_part_raw = parts[part_idx] # This part is raw text
- part_idx += 1
+ # Normalize text
+ if settings.advanced_text_normalization and normalization_options.normalize:
+ if lang_code in ["a","b","en-us","en-gb"]:
+ text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
+ text=normalize_text(text,normalization_options)
+ else:
+ logger.info("Skipping text normalization as it is only supported for English")
- # --- Process Text Part ---
- if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string
- # Strip leading and trailing spaces to prevent pause tag splitting artifacts
- text_part_raw = text_part_raw.strip()
+ # Process all sentences
+ sentences = get_sentence_info(text, custom_phoneme_list)
- # Normalize text (original logic)
- processed_text = text_part_raw
- if settings.advanced_text_normalization and normalization_options.normalize:
- if lang_code in ["a", "b", "en-us", "en-gb"]:
- processed_text = CUSTOM_PHONEMES.split(processed_text)
- for index in range(0, len(processed_text), 2):
- processed_text[index] = normalize_text(processed_text[index], normalization_options)
+ current_chunk = []
+ current_tokens = []
+ current_count = 0
-
- processed_text = "".join(processed_text).strip()
+ for sentence, tokens, count in sentences:
+ # Handle sentences that exceed max tokens
+ if count > max_tokens:
+ # Yield current chunk if any
+ if current_chunk:
+ chunk_text = " ".join(current_chunk)
+ chunk_count += 1
+ logger.debug(
+ f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+ )
+ yield chunk_text, current_tokens
+ current_chunk = []
+ current_tokens = []
+ current_count = 0
+
+ # Split long sentence on commas
+ clauses = re.split(r"([,])", sentence)
+ clause_chunk = []
+ clause_tokens = []
+ clause_count = 0
+
+ for j in range(0, len(clauses), 2):
+ clause = clauses[j].strip()
+ comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+ if not clause:
+ continue
+
+ full_clause = clause + comma
+
+ tokens = process_text_chunk(full_clause)
+ count = len(tokens)
+
+ # If adding clause keeps us under max and not optimal yet
+ if (
+ clause_count + count <= max_tokens
+ and clause_count + count <= settings.target_max_tokens
+ ):
+ clause_chunk.append(full_clause)
+ clause_tokens.extend(tokens)
+ clause_count += count
else:
- logger.info(
- "Skipping text normalization as it is only supported for english"
- )
-
- # Process all sentences (original logic)
- sentences = get_sentence_info(processed_text, lang_code=lang_code)
-
- current_chunk = []
- current_tokens = []
- current_count = 0
-
- for sentence, tokens, count in sentences:
- # Handle sentences that exceed max tokens (original logic)
- if count > max_tokens:
- # Yield current chunk if any
- if current_chunk:
- chunk_text = " ".join(current_chunk).strip()
- chunk_count += 1
- logger.debug(
- f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
- )
- yield chunk_text, current_tokens, None
- current_chunk = []
- current_tokens = []
- current_count = 0
-
- # Split long sentence on commas (original logic)
- clauses = re.split(r"([,])", sentence)
- clause_chunk = []
- clause_tokens = []
- clause_count = 0
-
- for j in range(0, len(clauses), 2):
- clause = clauses[j].strip()
- comma = clauses[j + 1] if j + 1 < len(clauses) else ""
-
- if not clause:
- continue
-
- full_clause = clause + comma
-
- tokens = process_text_chunk(full_clause)
- count = len(tokens)
-
- # If adding clause keeps us under max and not optimal yet
- if (
- clause_count + count <= max_tokens
- and clause_count + count <= settings.target_max_tokens
- ):
- clause_chunk.append(full_clause)
- clause_tokens.extend(tokens)
- clause_count += count
- else:
- # Yield clause chunk if we have one
- if clause_chunk:
- chunk_text = " ".join(clause_chunk).strip()
- chunk_count += 1
- logger.debug(
- f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
- )
- yield chunk_text, clause_tokens, None
- clause_chunk = [full_clause]
- clause_tokens = tokens
- clause_count = count
-
- # Don't forget last clause chunk
+ # Yield clause chunk if we have one
if clause_chunk:
- chunk_text = " ".join(clause_chunk).strip()
+ chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(
- f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
- )
- yield chunk_text, clause_tokens, None
-
- # Regular sentence handling (original logic)
- elif (
- current_count >= settings.target_min_tokens
- and current_count + count > settings.target_max_tokens
- ):
- # If we have a good sized chunk and adding next sentence exceeds target,
- # yield current chunk and start new one
- chunk_text = " ".join(current_chunk).strip()
- chunk_count += 1
- logger.info(
- f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
- )
- yield chunk_text, current_tokens, None
- current_chunk = [sentence]
- current_tokens = tokens
- current_count = count
- elif current_count + count <= settings.target_max_tokens:
- # Keep building chunk while under target max
- current_chunk.append(sentence)
- current_tokens.extend(tokens)
- current_count += count
- elif (
- current_count + count <= max_tokens
- and current_count < settings.target_min_tokens
- ):
- # Only exceed target max if we haven't reached minimum size yet
- current_chunk.append(sentence)
- current_tokens.extend(tokens)
- current_count += count
- else:
- # Yield current chunk and start new one
- if current_chunk:
- chunk_text = " ".join(current_chunk).strip()
- chunk_count += 1
- logger.info(
- f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+ f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
- yield chunk_text, current_tokens, None
- current_chunk = [sentence]
- current_tokens = tokens
- current_count = count
-
- # Don't forget the last chunk for this text part
+ yield chunk_text, clause_tokens
+ clause_chunk = [full_clause]
+ clause_tokens = tokens
+ clause_count = count
+
+ # Don't forget last clause chunk
+ if clause_chunk:
+ chunk_text = " ".join(clause_chunk)
+ chunk_count += 1
+ logger.debug(
+ f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
+ )
+ yield chunk_text, clause_tokens
+
+ # Regular sentence handling
+ elif (
+ current_count >= settings.target_min_tokens
+ and current_count + count > settings.target_max_tokens
+ ):
+ # If we have a good sized chunk and adding next sentence exceeds target,
+ # yield current chunk and start new one
+ chunk_text = " ".join(current_chunk)
+ chunk_count += 1
+ logger.info(
+ f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+ )
+ yield chunk_text, current_tokens
+ current_chunk = [sentence]
+ current_tokens = tokens
+ current_count = count
+ elif current_count + count <= settings.target_max_tokens:
+ # Keep building chunk while under target max
+ current_chunk.append(sentence)
+ current_tokens.extend(tokens)
+ current_count += count
+ elif (
+ current_count + count <= max_tokens
+ and current_count < settings.target_min_tokens
+ ):
+ # Only exceed target max if we haven't reached minimum size yet
+ current_chunk.append(sentence)
+ current_tokens.extend(tokens)
+ current_count += count
+ else:
+ # Yield current chunk and start new one
if current_chunk:
- chunk_text = " ".join(current_chunk).strip()
+ chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
- f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+ f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
- yield chunk_text, current_tokens, None
-
- # --- Handle Pause Part ---
- # Check if the next part is a pause duration string
- if part_idx < len(parts):
- duration_str = parts[part_idx]
- # Check if it looks like a valid number string captured by the regex group
- if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
- part_idx += 1 # Consume the duration string as it's been processed
- try:
- duration = float(duration_str)
- if duration > 0:
- chunk_count += 1
- logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
- yield "", [], duration # Yield pause chunk
- except (ValueError, TypeError):
- # This case should be rare if re.fullmatch passed, but handle anyway
- logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
+ yield chunk_text, current_tokens
+ current_chunk = [sentence]
+ current_tokens = tokens
+ current_count = count
+
+ # Don't forget the last chunk
+ if current_chunk:
+ chunk_text = " ".join(current_chunk)
+ chunk_count += 1
+ logger.info(
+ f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+ )
+ yield chunk_text, current_tokens
- # --- End of parts loop ---
total_time = time.time() - start_time
logger.info(
- f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
+ f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
)
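
A small consumer sketch for the restored two-tuple interface of smart_split (the pause-tag triples are gone after this patch). The import path is assumed, and running it requires the service's settings and phonemizer to be importable:

    import asyncio

    from api.src.services.text_processing.text_processor import smart_split  # assumed path

    async def main():
        text = "First sentence. Second! And a third, rather longer sentence to split."
        async for chunk_text, tokens in smart_split(text, lang_code="a"):
            print(f"{len(tokens):4d} tokens | {chunk_text!r}")

    # asyncio.run(main())
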
diff --git a/api/src/services/text_processing/vocabulary.py b/api/src/services/text_processing/vocabulary.py
index d6d7863e..7a128924 100644
--- a/api/src/services/text_processing/vocabulary.py
+++ b/api/src/services/text_processing/vocabulary.py
@@ -23,8 +23,6 @@ def tokenize(phonemes: str) -> list[int]:
Returns:
List of token IDs
"""
- # Strip phonemes to remove leading/trailing spaces that could cause artifacts
- phonemes = phonemes.strip()
return [i for i in map(VOCAB.get, phonemes) if i is not None]
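
The lookup above silently drops symbols missing from VOCAB rather than raising. A toy illustration with a made-up vocabulary (the real table is much larger):

    VOCAB = {"h": 1, "e": 2, "l": 3, "o": 4}  # illustrative IDs only

    def tokenize(phonemes: str) -> list[int]:
        return [i for i in map(VOCAB.get, phonemes) if i is not None]

    print(tokenize("hello!"))  # -> [1, 2, 3, 3, 4]; "!" is dropped
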
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index 46c2fb4c..a115c189 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -2,27 +2,24 @@
import asyncio
import os
-import re
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
+from ..inference.base import AudioChunk
import numpy as np
import torch
from kokoro import KPipeline
from loguru import logger
from ..core.config import settings
-from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
-from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
-from .streaming_audio_writer import StreamingAudioWriter
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
-
+from ..structures.schemas import NormalizationOptions
class TTSService:
"""Text-to-speech service."""
@@ -51,11 +48,9 @@ async def _process_chunk(
voice_name: str,
voice_path: str,
speed: float,
- writer: StreamingAudioWriter,
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
- volume_multiplier: Optional[float] = 1.0,
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False,
@@ -67,16 +62,15 @@ async def _process_chunk(
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
- yield AudioChunk(np.array([], dtype=np.int16), output=b"")
+ yield AudioChunk(np.array([], dtype=np.int16),output=b'')
return
chunk_data = await AudioService.convert_audio(
- AudioChunk(
- np.array([], dtype=np.float32)
- ), # Dummy data for type checking
+ AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
+ 24000,
output_format,
- writer,
speed,
"",
+ is_first_chunk=False,
normalizer=normalizer,
is_last_chunk=True,
)
@@ -92,7 +86,7 @@ async def _process_chunk(
# Generate audio using pre-warmed model
if isinstance(backend, KokoroV1):
- chunk_index = 0
+ chunk_index=0
# For Kokoro V1, pass text and voice info with lang_code
async for chunk_data in self.model_manager.generate(
chunk_text,
@@ -101,16 +95,16 @@ async def _process_chunk(
lang_code=lang_code,
return_timestamps=return_timestamps,
):
- chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes
if output_format:
try:
chunk_data = await AudioService.convert_audio(
chunk_data,
+ 24000,
output_format,
- writer,
speed,
chunk_text,
+ is_first_chunk=is_first and chunk_index == 0,
is_last_chunk=is_last,
normalizer=normalizer,
)
@@ -118,23 +112,23 @@ async def _process_chunk(
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
- chunk_data = AudioService.trim_audio(
- chunk_data, chunk_text, speed, is_last, normalizer
- )
+ chunk_data = AudioService.trim_audio(chunk_data,
+ chunk_text,
+ speed,
+ is_last,
+ normalizer)
yield chunk_data
- chunk_index += 1
+ chunk_index+=1
else:
+
# For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device
)
chunk_data = await self.model_manager.generate(
- tokens,
- voice_tensor,
- speed=speed,
- return_timestamps=return_timestamps,
+ tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
)
-
+
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
return
@@ -143,17 +137,16 @@ async def _process_chunk(
logger.error("Model generated empty audio chunk")
return
- chunk_data.audio*=volume_multiplier
-
# For streaming, convert to bytes
if output_format:
try:
chunk_data = await AudioService.convert_audio(
chunk_data,
+ 24000,
output_format,
- writer,
speed,
chunk_text,
+ is_first_chunk=is_first,
normalizer=normalizer,
is_last_chunk=is_last,
)
@@ -161,22 +154,16 @@ async def _process_chunk(
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
- trimmed = AudioService.trim_audio(
- chunk_data, chunk_text, speed, is_last, normalizer
- )
+ trimmed = AudioService.trim_audio(chunk_data,
+ chunk_text,
+ speed,
+ is_last,
+ normalizer)
yield trimmed
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
- async def _load_voice_from_path(self, path: str, weight: float):
- # Check if the path is None and raise a ValueError if it is not
- if not path:
- raise ValueError(f"Voice not found at path: {path}")
-
- logger.debug(f"Loading voice tensor from path: {path}")
- return torch.load(path, map_location="cpu") * weight
-
- async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
+ async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
"""Get voice path, handling combined voices.
Args:
@@ -189,68 +176,64 @@ async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
RuntimeError: If voice not found
"""
try:
- # Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
- split_voice = re.split(r"([-+])", voice)
-
- # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
- if len(split_voice) == 1:
- # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
- if (
- "(" not in voice and ")" not in voice
- ) or settings.voice_weight_normalization == True:
- path = await self._voice_manager.get_voice_path(voice)
+ # Check if it's a combined voice
+ if "+" in voice:
+ # Split on + but preserve any parentheses
+ voice_parts = []
+ weights = []
+ for part in voice.split("+"):
+ part = part.strip()
+ if not part:
+ continue
+ # Extract voice name and weight if present
+ if "(" in part and ")" in part:
+ voice_name = part.split("(")[0].strip()
+ weight = float(part.split("(")[1].split(")")[0])
+ else:
+ voice_name = part
+ weight = 1.0
+ voice_parts.append(voice_name)
+ weights.append(weight)
+
+ if len(voice_parts) < 2:
+ raise RuntimeError(f"Invalid combined voice name: {voice}")
+
+ # Normalize weights to sum to 1
+ total_weight = sum(weights)
+ weights = [w / total_weight for w in weights]
+
+ # Load and combine voices
+ voice_tensors = []
+ for v, w in zip(voice_parts, weights):
+ path = await self._voice_manager.get_voice_path(v)
if not path:
- raise RuntimeError(f"Voice not found: {voice}")
- logger.debug(f"Using single voice path: {path}")
- return voice, path
-
- total_weight = 0
-
- for voice_index in range(0, len(split_voice), 2):
- voice_object = split_voice[voice_index]
-
- if "(" in voice_object and ")" in voice_object:
- voice_name = voice_object.split("(")[0].strip()
- voice_weight = float(voice_object.split("(")[1].split(")")[0])
- else:
- voice_name = voice_object
- voice_weight = 1
-
- total_weight += voice_weight
- split_voice[voice_index] = (voice_name, voice_weight)
-
- # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
- if settings.voice_weight_normalization == False:
- total_weight = 1
-
- # Load the first voice as the starting point for voices to be combined onto
- path = await self._voice_manager.get_voice_path(split_voice[0][0])
- combined_tensor = await self._load_voice_from_path(
- path, split_voice[0][1] / total_weight
- )
-
- # Loop through each + or - in split_voice so they can be applied to combined voice
- for operation_index in range(1, len(split_voice) - 1, 2):
- # Get the voice path of the voice 1 index ahead of the operator
- path = await self._voice_manager.get_voice_path(
- split_voice[operation_index + 1][0]
- )
- voice_tensor = await self._load_voice_from_path(
- path, split_voice[operation_index + 1][1] / total_weight
+ raise RuntimeError(f"Voice not found: {v}")
+ logger.debug(f"Loading voice tensor from: {path}")
+ voice_tensor = torch.load(path, map_location="cpu")
+ voice_tensors.append(voice_tensor * w)
+
+ # Sum the weighted voice tensors
+ logger.debug(
+ f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
)
+ combined = torch.sum(torch.stack(voice_tensors), dim=0)
- # Either add or subtract the voice from the current combined voice
- if split_voice[operation_index] == "+":
- combined_tensor += voice_tensor
- else:
- combined_tensor -= voice_tensor
-
- # Save the new combined voice so it can be loaded latter
- temp_dir = tempfile.gettempdir()
- combined_path = os.path.join(temp_dir, f"{voice}.pt")
- logger.debug(f"Saving combined voice to: {combined_path}")
- torch.save(combined_tensor, combined_path)
- return voice, combined_path
+ # Save combined tensor
+ temp_dir = tempfile.gettempdir()
+ combined_path = os.path.join(temp_dir, f"{voice}.pt")
+ logger.debug(f"Saving combined voice to: {combined_path}")
+ torch.save(combined, combined_path)
+
+ return voice, combined_path
+ else:
+ # Single voice
+ if "(" in voice and ")" in voice:
+ voice = voice.split("(")[0].strip()
+ path = await self._voice_manager.get_voice_path(voice)
+ if not path:
+ raise RuntimeError(f"Voice not found: {voice}")
+ logger.debug(f"Using single voice path: {path}")
+ return voice, path
except Exception as e:
logger.error(f"Failed to get voice path: {e}")
raise
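
A standalone sketch of the weight parsing this revert reinstates for combined voices such as "af_bella(0.7)+af_sky(0.3)" (the voice names are illustrative). Weights default to 1.0 and are normalized to sum to 1 before the tensors are blended:

    def parse_combined_voice(voice: str):
        names, weights = [], []
        for part in voice.split("+"):
            part = part.strip()
            if not part:
                continue
            if "(" in part and ")" in part:
                names.append(part.split("(")[0].strip())
                weights.append(float(part.split("(")[1].split(")")[0]))
            else:
                names.append(part)
                weights.append(1.0)
        total = sum(weights)
        return names, [w / total for w in weights]

    print(parse_combined_voice("af_bella(0.7)+af_sky(0.3)"))
    # -> (['af_bella', 'af_sky'], [0.7, 0.3])
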
@@ -259,24 +242,22 @@ async def generate_audio_stream(
self,
text: str,
voice: str,
- writer: StreamingAudioWriter,
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
- volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
chunk_index = 0
- current_offset = 0.0
+ current_offset=0.0
try:
# Get backend
backend = self.model_manager.get_backend()
# Get voice path, handling combined voices
- voice_name, voice_path = await self._get_voices_path(voice)
+ voice_name, voice_path = await self._get_voice_path(voice)
logger.debug(f"Using voice path: {voice_path}")
# Use provided lang_code or determine from voice name
@@ -284,91 +265,45 @@ async def generate_audio_stream(
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
)
-
- # Process text in chunks with smart splitting, handling pause tags
- async for chunk_text, tokens, pause_duration_s in smart_split(
- text,
- lang_code=pipeline_lang_code,
- normalization_options=normalization_options,
- ):
- if pause_duration_s is not None and pause_duration_s > 0:
- # --- Handle Pause Chunk ---
- try:
- logger.debug(f"Generating {pause_duration_s}s silence chunk")
- silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate
- # Create proper silence as int16 zeros to avoid normalization artifacts
- silence_audio = np.zeros(silence_samples, dtype=np.int16)
- pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
-
- # Format and yield the silence chunk
- if output_format:
- formatted_pause_chunk = await AudioService.convert_audio(
- pause_chunk, output_format, writer, speed=speed, chunk_text="",
- is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
-
+
+
+ # Process text in chunks with smart splitting
+ async for chunk_text, tokens in smart_split(text,lang_code=pipeline_lang_code,normalization_options=normalization_options):
+ try:
+ # Process audio for chunk
+ async for chunk_data in self._process_chunk(
+ chunk_text, # Pass text for Kokoro V1
+ tokens, # Pass tokens for legacy backends
+ voice_name, # Pass voice name
+ voice_path, # Pass voice path
+ speed,
+ output_format,
+ is_first=(chunk_index == 0),
+ is_last=False, # We'll update the last chunk later
+ normalizer=stream_normalizer,
+ lang_code=pipeline_lang_code, # Pass lang_code
+ return_timestamps=return_timestamps,
+ ):
+ if chunk_data.word_timestamps is not None:
+ for timestamp in chunk_data.word_timestamps:
+ timestamp.start_time+=current_offset
+ timestamp.end_time+=current_offset
+
+ current_offset+=len(chunk_data.audio) / 24000
+
+ if chunk_data.output is not None:
+ yield chunk_data
+
+ else:
+ logger.warning(
+ f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
- if formatted_pause_chunk.output:
- yield formatted_pause_chunk
- else: # Raw audio mode
- # For raw audio mode, silence is already in the correct format (int16)
- # Skip normalization to avoid any potential artifacts
- if len(pause_chunk.audio) > 0:
- yield pause_chunk
-
- # Update offset based on silence duration
- current_offset += pause_duration_s
- chunk_index += 1 # Count pause as a yielded chunk
-
- except Exception as e:
- logger.error(f"Failed to process pause chunk: {str(e)}")
- continue
-
- elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
- # --- Handle Text Chunk ---
- try:
- # Process audio for chunk
- async for chunk_data in self._process_chunk(
- chunk_text, # Pass text for Kokoro V1
- tokens, # Pass tokens for legacy backends
- voice_name, # Pass voice name
- voice_path, # Pass voice path
- speed,
- writer,
- output_format,
- is_first=(chunk_index == 0),
- volume_multiplier=volume_multiplier,
- is_last=False, # We'll update the last chunk later
- normalizer=stream_normalizer,
- lang_code=pipeline_lang_code, # Pass lang_code
- return_timestamps=return_timestamps,
- ):
- if chunk_data.word_timestamps is not None:
- for timestamp in chunk_data.word_timestamps:
- timestamp.start_time += current_offset
- timestamp.end_time += current_offset
-
- # Update offset based on the actual duration of the generated audio chunk
- chunk_duration = 0
- if chunk_data.audio is not None and len(chunk_data.audio) > 0:
- chunk_duration = len(chunk_data.audio) / 24000
- current_offset += chunk_duration
-
- # Yield the processed chunk (either formatted or raw)
- if chunk_data.output is not None:
- yield chunk_data
- elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
- yield chunk_data
- else:
- logger.warning(
- f"No audio generated for chunk: '{chunk_text[:100]}...'"
- )
-
- chunk_index += 1 # Increment chunk index after processing text
- except Exception as e:
- logger.error(
- f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
- )
- continue
+ chunk_index += 1
+ except Exception as e:
+ logger.error(
+ f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
+ )
+ continue
# Only finalize if we successfully processed at least one chunk
if chunk_index > 0:
@@ -380,11 +315,9 @@ async def generate_audio_stream(
voice_name,
voice_path,
speed,
- writer,
output_format,
is_first=False,
is_last=True, # Signal this is the last chunk
- volume_multiplier=volume_multiplier,
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
):
@@ -396,41 +329,32 @@ async def generate_audio_stream(
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise e
+
async def generate_audio(
self,
text: str,
voice: str,
- writer: StreamingAudioWriter,
speed: float = 1.0,
return_timestamps: bool = False,
- volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None,
) -> AudioChunk:
"""Generate complete audio for text using streaming internally."""
- audio_data_chunks = []
-
+ audio_data_chunks = []
+
try:
- async for audio_stream_data in self.generate_audio_stream(
- text,
- voice,
- writer,
- speed=speed,
- volume_multiplier=volume_multiplier,
- normalization_options=normalization_options,
- return_timestamps=return_timestamps,
- lang_code=lang_code,
- output_format=None,
- ):
+ async for audio_stream_data in self.generate_audio_stream(text, voice, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
- combined_audio_data = AudioChunk.combine(audio_data_chunks)
+
+ combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
+
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
"""Combine multiple voices.
@@ -438,7 +362,6 @@ async def combine_voices(self, voices: List[str]) -> torch.Tensor:
Returns:
Combined voice tensor
"""
-
return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
@@ -467,7 +390,7 @@ async def generate_from_phonemes(
try:
# Get backend and voice path
backend = self.model_manager.get_backend()
- voice_name, voice_path = await self._get_voices_path(voice)
+ voice_name, voice_path = await self._get_voice_path(voice)
if isinstance(backend, KokoroV1):
# For Kokoro V1, use generate_from_tokens with raw phonemes
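The stream above re-bases each chunk's word timestamps onto the global timeline before yielding, then advances the offset by the chunk's real duration. A minimal sketch of that bookkeeping, assuming the 24 kHz sample rate used throughout this service; `WordTimestamp` here is a stand-in dataclass, not the project's schema:

```python
from dataclasses import dataclass


@dataclass
class WordTimestamp:
    word: str
    start_time: float
    end_time: float


def shift_timestamps(chunks, sample_rate=24000):
    """Re-base per-chunk word timestamps onto the global stream timeline."""
    current_offset = 0.0
    for audio, timestamps in chunks:  # audio: sequence of int16 samples
        for ts in timestamps:
            ts.start_time += current_offset
            ts.end_time += current_offset
        # Advance by the real duration of this chunk, in seconds.
        current_offset += len(audio) / sample_rate
        yield audio, timestamps
```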
diff --git a/api/src/structures/custom_responses.py b/api/src/structures/custom_responses.py
index 0f838829..a996e5dc 100644
--- a/api/src/structures/custom_responses.py
+++ b/api/src/structures/custom_responses.py
@@ -1,7 +1,7 @@
-import json
-import typing
from collections.abc import AsyncIterable, Iterable
+import json
+import typing
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
@@ -24,27 +24,28 @@ def __init__(
else:
self._content_iterable = iterate_in_threadpool(content)
+
+
async def body_iterator() -> AsyncIterable[bytes]:
async for content_ in self._content_iterable:
if isinstance(content_, BaseModel):
content_ = content_.model_dump()
yield self.render(content_)
-
+
+
+
self.body_iterator = body_iterator()
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.init_headers(headers)
-
+
def render(self, content: typing.Any) -> bytes:
- return (
- json.dumps(
+ return (json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
- )
- + "\n"
- ).encode("utf-8")
+ ) + "\n").encode("utf-8")
\ No newline at end of file
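`render` terminates every JSON document with a newline, so the streamed body is newline-delimited JSON. A hedged sketch of how a client could split such a stream back into objects; the buffering helper is illustrative, not part of this codebase:

```python
import json


def iter_ndjson(byte_stream):
    """Yield one parsed object per newline-terminated JSON line."""
    buffer = b""
    for chunk in byte_stream:
        buffer += chunk
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            if line.strip():
                yield json.loads(line)


# Usage with an in-memory stand-in for the response body:
stream = [b'{"a": 1}\n{"b"', b': 2}\n']
assert list(iter_ndjson(stream)) == [{"a": 1}, {"b": 2}]
```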
diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py
index 0aeab7b7..fa65b4d8 100644
--- a/api/src/structures/schemas.py
+++ b/api/src/structures/schemas.py
@@ -1,4 +1,3 @@
-from email.policy import default
from enum import Enum
from typing import List, Literal, Optional, Union
@@ -36,43 +35,17 @@ class CaptionedSpeechResponse(BaseModel):
audio: str = Field(..., description="The generated audio data encoded in base 64")
audio_format: str = Field(..., description="The format of the output audio")
- timestamps: Optional[List[WordTimestamp]] = Field(
- ..., description="Word-level timestamps"
- )
-
+ timestamps: Optional[List[WordTimestamp]] = Field(..., description="Word-level timestamps")
class NormalizationOptions(BaseModel):
"""Options for the normalization system"""
-
- normalize: bool = Field(
- default=True,
- description="Normalizes input text to make it easier for the model to say",
- )
- unit_normalization: bool = Field(
- default=False, description="Transforms units like 10KB to 10 kilobytes"
- )
- url_normalization: bool = Field(
- default=True,
- description="Changes urls so they can be properly pronounced by kokoro",
- )
- email_normalization: bool = Field(
- default=True,
- description="Changes emails so they can be properly pronouced by kokoro",
- )
- optional_pluralization_normalization: bool = Field(
- default=True,
- description="Replaces (s) with s so some words get pronounced correctly",
- )
- phone_normalization: bool = Field(
- default=True,
- description="Changes phone numbers so they can be properly pronouced by kokoro",
- )
- replace_remaining_symbols: bool = Field(
- default=True,
- description="Replaces the remaining symbols after normalization with their words"
- )
-
-
+ normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
+ unit_normalization: bool = Field(default=False, description="Transforms units like 10KB to 10 kilobytes")
+ url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronounced by kokoro")
+ email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronounced by kokoro")
+ optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")
+ phone_normalization: bool = Field(default=True, description="Changes phone numbers so they can be properly pronounced by kokoro")
+
class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint"""
@@ -89,11 +62,9 @@ class OpenAISpeechRequest(BaseModel):
default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
)
- download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
- Field(
- default=None,
- description="Optional different format for the final download. If not provided, uses response_format.",
- )
+ download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = Field(
+ default=None,
+ description="Optional different format for the final download. If not provided, uses response_format.",
)
speed: float = Field(
default=1.0,
@@ -113,13 +84,9 @@ class OpenAISpeechRequest(BaseModel):
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
- volume_multiplier: Optional[float] = Field(
- default = 1.0,
- description="A volume multiplier to multiply the output audio by."
- )
normalization_options: Optional[NormalizationOptions] = Field(
- default=NormalizationOptions(),
- description="Options for the normalization system",
+ default=NormalizationOptions(),
+ description="Options for the normalization system",
)
@@ -161,11 +128,7 @@ class CaptionedSpeechRequest(BaseModel):
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
- volume_multiplier: Optional[float] = Field(
- default = 1.0,
- description="A volume multiplier to multiply the output audio by."
- )
normalization_options: Optional[NormalizationOptions] = Field(
- default=NormalizationOptions(),
- description="Options for the normalization system",
+ default=NormalizationOptions(),
+ description="Options for the normalization system",
)
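For reference, a usage sketch for the collapsed `NormalizationOptions` fields above. The payload keys follow the `OpenAISpeechRequest` fields shown in this diff, but the voice name is an assumption:

```python
from api.src.structures.schemas import NormalizationOptions

opts = NormalizationOptions(
    unit_normalization=True,   # read "10KB" as "10 kilobytes"
    url_normalization=False,   # leave URLs as written
)
payload = {
    "model": "kokoro",
    "input": "Visit https://example.com",
    "voice": "af_heart",  # assumed voice name
    "response_format": "mp3",
    "normalization_options": opts.model_dump(),
}
```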
diff --git a/api/tests/conftest.py b/api/tests/conftest.py
index 2e3bba8a..b8dd7613 100644
--- a/api/tests/conftest.py
+++ b/api/tests/conftest.py
@@ -69,3 +69,17 @@ async def tts_service(mock_model_manager, mock_voice_manager):
def test_voice():
"""Return a test voice name."""
return "voice1"
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+ """Create an instance of the default event loop for the test session."""
+ import asyncio
+
+ try:
+ loop = asyncio.get_event_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ yield loop
+ loop.close()
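With this session-scoped `event_loop` override, pytest-asyncio runs every coroutine test in the session on the same loop. A hedged sketch of a test relying on that, assuming a pytest-asyncio version where overriding `event_loop` is still supported:

```python
import asyncio

import pytest


@pytest.mark.asyncio
async def test_runs_on_session_loop(event_loop):
    # The running loop should be the exact object the fixture yielded.
    assert asyncio.get_running_loop() is event_loop
```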
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index 5ba53928..ca2d25da 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -5,10 +5,8 @@
import numpy as np
import pytest
-from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioNormalizer, AudioService
-from api.src.services.streaming_audio_writer import StreamingAudioWriter
-
+from api.src.inference.base import AudioChunk
@pytest.fixture(autouse=True)
def mock_settings():
@@ -32,15 +30,10 @@ def sample_audio():
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("wav", sample_rate=24000)
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "wav", writer, is_last_chunk=False
+ AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=False
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -53,38 +46,23 @@ async def test_convert_to_wav(sample_audio):
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("mp3", sample_rate=24000)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "mp3", writer
+ AudioChunk(audio_data), sample_rate, "mp3"
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check MP3 header (ID3 or MPEG frame sync)
- assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
- b"\xff\xfb"
- )
+ assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
-
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("opus", sample_rate=24000)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "opus", writer
+ AudioChunk(audio_data), sample_rate, "opus"
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -96,15 +74,9 @@ async def test_convert_to_opus(sample_audio):
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("flac", sample_rate=24000)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "flac", writer
+ AudioChunk(audio_data), sample_rate, "flac"
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -116,37 +88,23 @@ async def test_convert_to_flac(sample_audio):
async def test_convert_to_aac(sample_audio):
"""Test converting to M4A format"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("aac", sample_rate=24000)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "aac", writer
+ AudioChunk(audio_data), sample_rate, "aac"
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check ADTS header (AAC)
- assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
- b"\xff\xf1"
- )
+ assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("pcm", sample_rate=24000)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "pcm", writer
+ AudioChunk(audio_data), sample_rate, "pcm"
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -156,27 +114,21 @@ async def test_convert_to_pcm(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
- # audio_data, sample_rate = sample_audio
- with pytest.raises(ValueError, match="Unsupported format: invalid"):
- writer = StreamingAudioWriter("invalid", sample_rate=24000)
+ audio_data, sample_rate = sample_audio
+ with pytest.raises(ValueError, match="Format invalid not supported"):
+ await AudioService.convert_audio(audio_data, sample_rate, "invalid")
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("wav", sample_rate=24000)
-
# Create audio data outside int16 range
large_audio = audio_data * 1e5
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
- AudioChunk(large_audio), "wav", writer
+ AudioChunk(large_audio), sample_rate, "wav", is_first_chunk=True
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -186,13 +138,10 @@ async def test_normalization_wav(sample_audio):
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("pcm", sample_rate=24000)
-
# Create audio data outside int16 range
large_audio = audio_data * 1e5
audio_chunk = await AudioService.convert_audio(
- AudioChunk(large_audio), "pcm", writer
+ AudioChunk(large_audio), sample_rate, "pcm"
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
@@ -204,11 +153,8 @@ async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array
sample_rate = 24000
-
- writer = StreamingAudioWriter("wav", sample_rate=24000)
-
with pytest.raises(ValueError):
- await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)
+ await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
@pytest.mark.asyncio
@@ -218,14 +164,9 @@ async def test_different_sample_rates(sample_audio):
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
- writer = StreamingAudioWriter("wav", sample_rate=rate)
-
audio_chunk = await AudioService.convert_audio(
- AudioChunk(audio_data), "wav", writer
+ AudioChunk(audio_data), rate, "wav", is_first_chunk=True
)
-
- writer.close()
-
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@@ -235,21 +176,15 @@ async def test_different_sample_rates(sample_audio):
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
-
- writer = StreamingAudioWriter("wav", sample_rate=24000)
-
# Write and finalize in one step for first conversion
audio_chunk1 = await AudioService.convert_audio(
- AudioChunk(audio_data), "wav", writer, is_last_chunk=True
+ AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(audio_chunk1.output, bytes)
assert isinstance(audio_chunk1, AudioChunk)
# Convert again to ensure buffer was properly reset
-
- writer = StreamingAudioWriter("wav", sample_rate=24000)
-
audio_chunk2 = await AudioService.convert_audio(
- AudioChunk(audio_data), "wav", writer, is_last_chunk=True
+ AudioChunk(audio_data), sample_rate, "wav", is_first_chunk=True, is_last_chunk=True
)
assert isinstance(audio_chunk2.output, bytes)
assert isinstance(audio_chunk2, AudioChunk)
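The normalization tests feed samples scaled by 1e5 and still expect valid output, implying peak normalization into the int16 range. A minimal sketch of that idea; the real logic lives in `AudioNormalizer` and may differ:

```python
import numpy as np


def normalize_to_int16(audio: np.ndarray) -> np.ndarray:
    """Scale audio so its peak maps into the int16 range."""
    peak = np.max(np.abs(audio))
    if peak == 0:
        return audio.astype(np.int16)
    return (audio / peak * 32767).astype(np.int16)


loud = np.sin(np.linspace(0, 2 * np.pi, 48)) * 1e5
out = normalize_to_int16(loud)
assert out.dtype == np.int16 and int(np.abs(out).max()) <= 32767
```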
diff --git a/api/tests/test_development.py b/api/tests/test_development.py
index a03b3baa..4760347d 100644
--- a/api/tests/test_development.py
+++ b/api/tests/test_development.py
@@ -1,10 +1,8 @@
-import base64
-import json
-from unittest.mock import MagicMock, patch
-
import pytest
+from unittest.mock import patch, MagicMock
import requests
-
+import base64
+import json
def test_generate_captioned_speech():
"""Test the generate_captioned_speech function with mocked responses"""
@@ -14,21 +12,20 @@ def test_generate_captioned_speech():
mock_timestamps_response = MagicMock()
mock_timestamps_response.status_code = 200
- mock_timestamps_response.content = json.dumps(
- {
- "audio": base64.b64encode(b"mock audio data").decode("utf-8"),
- "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
- }
- )
+ mock_timestamps_response.content = json.dumps({
+ "audio":base64.b64encode(b"mock audio data").decode("utf-8"),
+ "timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
+ })
# Patch the HTTP requests
- with patch("requests.post", return_value=mock_timestamps_response):
+ with patch("requests.post", return_value=mock_timestamps_response):
+
# Import here to avoid module-level import issues
from examples.captioned_speech_example import generate_captioned_speech
-
+
# Test the function
audio, timestamps = generate_captioned_speech("test text")
-
+
# Verify we got both audio and timestamps
assert audio == b"mock audio data"
- assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
+ assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
\ No newline at end of file
diff --git a/api/tests/test_kokoro_v1.py b/api/tests/test_kokoro_v1.py
index 29d83c5a..850ed05e 100644
--- a/api/tests/test_kokoro_v1.py
+++ b/api/tests/test_kokoro_v1.py
@@ -23,18 +23,19 @@ def test_initial_state(kokoro_backend):
@patch("torch.cuda.is_available", return_value=True)
-@patch("torch.cuda.memory_allocated", return_value=5e9)
+@patch("torch.cuda.memory_allocated")
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
"""Test GPU memory management functions."""
- # Patch backend so it thinks we have cuda
- with patch.object(kokoro_backend, "_device", "cuda"):
- # Test memory check
- with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
- mock_config.pytorch_gpu.memory_threshold = 4
- assert kokoro_backend._check_memory() == True
+ # Mock GPU memory usage
+ mock_memory.return_value = 5e9 # 5GB
+
+ # Test memory check
+ with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
+ mock_config.pytorch_gpu.memory_threshold = 4
+ assert kokoro_backend._check_memory() == True
- mock_config.pytorch_gpu.memory_threshold = 6
- assert kokoro_backend._check_memory() == False
+ mock_config.pytorch_gpu.memory_threshold = 6
+ assert kokoro_backend._check_memory() == False
@patch("torch.cuda.empty_cache")
diff --git a/api/tests/test_normalizer.py b/api/tests/test_normalizer.py
index 6b5a8bfb..0aa963ee 100644
--- a/api/tests/test_normalizer.py
+++ b/api/tests/test_normalizer.py
@@ -5,48 +5,27 @@
from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions
-
def test_url_protocols():
"""Test URL protocol handling"""
assert (
- normalize_text(
- "Check out https://example.com",
- normalization_options=NormalizationOptions(),
- )
+ normalize_text("Check out https://example.com",normalization_options=NormalizationOptions())
== "Check out https example dot com"
)
+ assert normalize_text("Visit http://site.com",normalization_options=NormalizationOptions()) == "Visit http site dot com"
assert (
- normalize_text(
- "Visit http://site.com", normalization_options=NormalizationOptions()
- )
- == "Visit http site dot com"
- )
- assert (
- normalize_text(
- "Go to https://test.org/path", normalization_options=NormalizationOptions()
- )
+ normalize_text("Go to https://test.org/path",normalization_options=NormalizationOptions())
== "Go to https test dot org slash path"
)
def test_url_www():
"""Test www prefix handling"""
+ assert normalize_text("Go to www.example.com",normalization_options=NormalizationOptions()) == "Go to www example dot com"
assert (
- normalize_text(
- "Go to www.example.com", normalization_options=NormalizationOptions()
- )
- == "Go to www example dot com"
+ normalize_text("Visit www.test.org/docs",normalization_options=NormalizationOptions()) == "Visit www test dot org slash docs"
)
assert (
- normalize_text(
- "Visit www.test.org/docs", normalization_options=NormalizationOptions()
- )
- == "Visit www test dot org slash docs"
- )
- assert (
- normalize_text(
- "Check www.site.com?q=test", normalization_options=NormalizationOptions()
- )
+ normalize_text("Check www.site.com?q=test",normalization_options=NormalizationOptions())
== "Check www site dot com question-mark q equals test"
)
@@ -54,280 +33,59 @@ def test_url_www():
def test_url_localhost():
"""Test localhost URL handling"""
assert (
- normalize_text(
- "Running on localhost:7860", normalization_options=NormalizationOptions()
- )
- == "Running on localhost colon seventy-eight sixty"
+ normalize_text("Running on localhost:7860",normalization_options=NormalizationOptions())
+ == "Running on localhost colon 78 60"
)
assert (
- normalize_text(
- "Server at localhost:8080/api", normalization_options=NormalizationOptions()
- )
- == "Server at localhost colon eighty eighty slash api"
+ normalize_text("Server at localhost:8080/api",normalization_options=NormalizationOptions())
+ == "Server at localhost colon 80 80 slash api"
)
assert (
- normalize_text(
- "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
- )
- == "Test localhost colon three thousand slash test question-mark v equals one"
+ normalize_text("Test localhost:3000/test?v=1",normalization_options=NormalizationOptions())
+ == "Test localhost colon 3000 slash test question-mark v equals 1"
)
def test_url_ip_addresses():
"""Test IP address URL handling"""
assert (
- normalize_text(
- "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
- )
- == "Access zero dot zero dot zero dot zero colon ninety ninety slash test"
+ normalize_text("Access 0.0.0.0:9090/test",normalization_options=NormalizationOptions())
+ == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
)
assert (
- normalize_text(
- "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
- )
- == "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand"
- )
- assert (
- normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
- == "Server one hundred and twenty-seven dot zero dot zero dot one"
+ normalize_text("API at 192.168.1.1:8000",normalization_options=NormalizationOptions())
+ == "API at 192 dot 168 dot 1 dot 1 colon 8000"
)
+ assert normalize_text("Server 127.0.0.1",normalization_options=NormalizationOptions()) == "Server 127 dot 0 dot 0 dot 1"
def test_url_raw_domains():
"""Test raw domain handling"""
assert (
- normalize_text(
- "Visit google.com/search", normalization_options=NormalizationOptions()
- )
- == "Visit google dot com slash search"
+ normalize_text("Visit google.com/search",normalization_options=NormalizationOptions()) == "Visit google dot com slash search"
)
assert (
- normalize_text(
- "Go to example.com/path?q=test",
- normalization_options=NormalizationOptions(),
- )
+ normalize_text("Go to example.com/path?q=test",normalization_options=NormalizationOptions())
== "Go to example dot com slash path question-mark q equals test"
)
- assert (
- normalize_text(
- "Check docs.test.com", normalization_options=NormalizationOptions()
- )
- == "Check docs dot test dot com"
- )
+ assert normalize_text("Check docs.test.com",normalization_options=NormalizationOptions()) == "Check docs dot test dot com"
def test_url_email_addresses():
"""Test email address handling"""
assert (
- normalize_text(
- "Email me at user@example.com", normalization_options=NormalizationOptions()
- )
+ normalize_text("Email me at user@example.com",normalization_options=NormalizationOptions())
== "Email me at user at example dot com"
)
+ assert normalize_text("Contact admin@test.org",normalization_options=NormalizationOptions()) == "Contact admin at test dot org"
assert (
- normalize_text(
- "Contact admin@test.org", normalization_options=NormalizationOptions()
- )
- == "Contact admin at test dot org"
- )
- assert (
- normalize_text(
- "Send to test.user@site.com", normalization_options=NormalizationOptions()
- )
+ normalize_text("Send to test.user@site.com",normalization_options=NormalizationOptions())
== "Send to test dot user at site dot com"
)
-def test_money():
- """Test that money text is normalized correctly"""
- assert (
- normalize_text(
- "He lost $5.3 thousand.", normalization_options=NormalizationOptions()
- )
- == "He lost five point three thousand dollars."
- )
-
- assert (
- normalize_text(
- "He went gambling and lost about $25.05k.",
- normalization_options=NormalizationOptions(),
- )
- == "He went gambling and lost about twenty-five point zero five thousand dollars."
- )
-
- assert (
- normalize_text(
- "To put it weirdly -$6.9 million",
- normalization_options=NormalizationOptions(),
- )
- == "To put it weirdly minus six point nine million dollars"
- )
-
- assert (
- normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
- == "It costs fifty dollars and thirty cents."
- )
-
- assert (
- normalize_text(
- "The plant cost $200,000.8.", normalization_options=NormalizationOptions()
- )
- == "The plant cost two hundred thousand dollars and eighty cents."
- )
-
- assert (
- normalize_text(
- "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
- )
- == "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
- )
-
- assert (
- normalize_text(
- "€30.2 is in euros", normalization_options=NormalizationOptions()
- )
- == "thirty euros and twenty cents is in euros"
- )
-
-
-def test_time():
- """Test time normalization"""
-
- assert (
- normalize_text(
- "Your flight leaves at 10:35 pm",
- normalization_options=NormalizationOptions(),
- )
- == "Your flight leaves at ten thirty-five pm"
- )
-
- assert (
- normalize_text(
- "He departed for london around 5:03 am.",
- normalization_options=NormalizationOptions(),
- )
- == "He departed for london around five oh three am."
- )
-
- assert (
- normalize_text(
- "Only the 13:42 and 15:12 slots are available.",
- normalization_options=NormalizationOptions(),
- )
- == "Only the thirteen forty-two and fifteen twelve slots are available."
- )
-
- assert (
- normalize_text(
- "It is currently 1:00 pm", normalization_options=NormalizationOptions()
- )
- == "It is currently one pm"
- )
-
- assert (
- normalize_text(
- "It is currently 3:00", normalization_options=NormalizationOptions()
- )
- == "It is currently three o'clock"
- )
-
- assert (
- normalize_text(
- "12:00 am is midnight", normalization_options=NormalizationOptions()
- )
- == "twelve am is midnight"
- )
-
-
-def test_number():
- """Test number normalization"""
-
- assert (
- normalize_text(
- "I bought 1035 cans of soda", normalization_options=NormalizationOptions()
- )
- == "I bought one thousand and thirty-five cans of soda"
- )
-
- assert (
- normalize_text(
- "The bus has a maximum capacity of 62 people",
- normalization_options=NormalizationOptions(),
- )
- == "The bus has a maximum capacity of sixty-two people"
- )
-
- assert (
- normalize_text(
- "There are 1300 products left in stock",
- normalization_options=NormalizationOptions(),
- )
- == "There are one thousand, three hundred products left in stock"
- )
-
- assert (
- normalize_text(
- "The population is 7,890,000 people.",
- normalization_options=NormalizationOptions(),
- )
- == "The population is seven million, eight hundred and ninety thousand people."
- )
-
- assert (
- normalize_text(
- "He looked around but only found 1.6k of the 10k bricks",
- normalization_options=NormalizationOptions(),
- )
- == "He looked around but only found one point six thousand of the ten thousand bricks"
- )
-
- assert (
- normalize_text(
- "The book has 342 pages.", normalization_options=NormalizationOptions()
- )
- == "The book has three hundred and forty-two pages."
- )
-
- assert (
- normalize_text(
- "He made -50 sales today.", normalization_options=NormalizationOptions()
- )
- == "He made minus fifty sales today."
- )
-
- assert (
- normalize_text(
- "56.789 to the power of 1.35 million",
- normalization_options=NormalizationOptions(),
- )
- == "fifty-six point seven eight nine to the power of one point three five million"
- )
-
-
def test_non_url_text():
"""Test that non-URL text is unaffected"""
- assert (
- normalize_text(
- "This is not.a.url text", normalization_options=NormalizationOptions()
- )
- == "This is not-a-url text"
- )
- assert (
- normalize_text(
- "Hello, how are you today?", normalization_options=NormalizationOptions()
- )
- == "Hello, how are you today?"
- )
- assert (
- normalize_text("It costs $50.", normalization_options=NormalizationOptions())
- == "It costs fifty dollars."
- )
-
-def test_remaining_symbol():
- """Test that remaining symbols are replaced"""
- assert (
- normalize_text(
- "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
- )
- == "I love buying products at good store here and at other store"
- )
+ assert normalize_text("This is not.a.url text",normalization_options=NormalizationOptions()) == "This is not-a-url text"
+ assert normalize_text("Hello, how are you today?",normalization_options=NormalizationOptions()) == "Hello, how are you today?"
+ assert normalize_text("It costs $50.",normalization_options=NormalizationOptions()) == "It costs fifty dollars."
diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py
index d5c7efcb..527cb1fe 100644
--- a/api/tests/test_openai_endpoints.py
+++ b/api/tests/test_openai_endpoints.py
@@ -4,19 +4,18 @@
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
+from api.src.inference.base import AudioChunk
import numpy as np
import pytest
from fastapi.testclient import TestClient
from api.src.core.config import settings
-from api.src.inference.base import AudioChunk
from api.src.main import app
from api.src.routers.openai_compatible import (
get_tts_service,
load_openai_mappings,
stream_audio_chunks,
)
-from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService
from api.src.structures.schemas import OpenAISpeechRequest
@@ -79,13 +78,13 @@ def test_list_models(mock_openai_mappings):
assert data["object"] == "list"
assert isinstance(data["data"], list)
assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
-
+
# Verify all expected models are present
model_ids = [model["id"] for model in data["data"]]
assert "tts-1" in model_ids
assert "tts-1-hd" in model_ids
assert "kokoro" in model_ids
-
+
# Verify model format
for model in data["data"]:
assert model["object"] == "model"
@@ -113,6 +112,7 @@ def test_retrieve_model(mock_openai_mappings):
assert error["detail"]["type"] == "invalid_request_error"
+
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
@@ -145,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs):
for i in range(5):
- yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
+ yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
@@ -159,14 +159,10 @@ async def mock_stream(*args, **kwargs):
speed=1.0,
)
- writer = StreamingAudioWriter("mp3", 24000)
-
chunks = []
- async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
+ async for chunk in stream_audio_chunks(mock_service, request, mock_request):
chunks.append(chunk)
- writer.close()
-
assert len(chunks) == 0 # Should stop immediately due to disconnect
@@ -241,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
- service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
+ service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
- yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
+ yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@@ -261,10 +257,8 @@ def test_openai_speech_endpoint(
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
- mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
- mock_convert.return_value = AudioChunk(
- np.zeros(1000, np.int16), output=mock_audio_bytes
- )
+ mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
+ mock_convert.return_value = AudioChunk(np.zeros(1000, np.int16), output=mock_audio_bytes)
response = client.post(
"/v1/audio/speech",
@@ -489,11 +483,7 @@ async def mock_error_stream(*args, **kwargs):
speed=1.0,
)
- writer = StreamingAudioWriter("mp3", 24000)
-
with pytest.raises(RuntimeError) as exc:
- async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
+ async for _ in stream_audio_chunks(mock_service, request, MagicMock()):
pass
-
- writer.close()
assert "Failed to initialize stream" in str(exc.value)
diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py
index 3fc8a87c..3d844b15 100644
--- a/api/tests/test_text_processor.py
+++ b/api/tests/test_text_processor.py
@@ -44,45 +44,19 @@ def test_get_sentence_info():
assert count == len(tokens)
assert count > 0
+
@pytest.mark.asyncio
async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens."""
text = "This is a short test sentence."
chunks = []
- async for chunk_text, chunk_tokens, _ in smart_split(text):
+ async for chunk_text, chunk_tokens in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
-@pytest.mark.asyncio
-async def test_smart_custom_phenomes():
- """Test smart splitting with text under max tokens."""
- text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
- chunks = []
- async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
- chunks.append((chunk_text, chunk_tokens, pause_duration))
-
- # Should have 1 chunks: text
- assert len(chunks) == 1
-
- # First chunk: text
- assert chunks[0][2] is None # No pause
- assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
- assert len(chunks[0][1]) > 0
-
-@pytest.mark.asyncio
-async def test_smart_split_only_phenomes():
- """Test input that is entirely made of phenome annotations."""
- text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
- chunks = []
- async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
- chunks.append((chunk_text, chunk_tokens, pause_duration))
-
- assert len(chunks) == 1
- assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
-
@pytest.mark.asyncio
async def test_smart_split_long_text():
@@ -91,7 +65,7 @@ async def test_smart_split_long_text():
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
chunks = []
- async for chunk_text, chunk_tokens, _ in smart_split(text):
+ async for chunk_text, chunk_tokens in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
@@ -107,127 +81,8 @@ async def test_smart_split_with_punctuation():
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
chunks = []
- async for chunk_text, chunk_tokens, _ in smart_split(text):
+ async for chunk_text, chunk_tokens in smart_split(text):
chunks.append(chunk_text)
# Verify punctuation is preserved
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
-
-
-def test_process_text_chunk_chinese_phonemes():
- """Test processing with Chinese pinyin phonemes."""
- pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
- tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
- assert isinstance(tokens, list)
- assert len(tokens) > 0
-
-
-def test_get_sentence_info_chinese():
- """Test Chinese sentence splitting and info extraction."""
- text = "这是一个句子。这是第二个句子!第三个问题?"
- results = get_sentence_info(text, lang_code="z")
-
- assert len(results) == 3
- for sentence, tokens, count in results:
- assert isinstance(sentence, str)
- assert isinstance(tokens, list)
- assert isinstance(count, int)
- assert count == len(tokens)
- assert count > 0
-
-
-@pytest.mark.asyncio
-async def test_smart_split_chinese_short():
- """Test Chinese smart splitting with short text."""
- text = "这是一句话。"
- chunks = []
- async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
- chunks.append((chunk_text, chunk_tokens))
-
- assert len(chunks) == 1
- assert isinstance(chunks[0][0], str)
- assert isinstance(chunks[0][1], list)
-
-
-@pytest.mark.asyncio
-async def test_smart_split_chinese_long():
- """Test Chinese smart splitting with longer text."""
- text = "。".join([f"测试句子 {i}" for i in range(20)])
-
- chunks = []
- async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
- chunks.append((chunk_text, chunk_tokens))
-
- assert len(chunks) > 1
- for chunk_text, chunk_tokens in chunks:
- assert isinstance(chunk_text, str)
- assert isinstance(chunk_tokens, list)
- assert len(chunk_tokens) > 0
-
-
-@pytest.mark.asyncio
-async def test_smart_split_chinese_punctuation():
- """Test Chinese smart splitting with punctuation preservation."""
- text = "第一句!第二问?第三句;第四句:第五句。"
-
- chunks = []
- async for chunk_text, _, _ in smart_split(text, lang_code="z"):
- chunks.append(chunk_text)
-
- # Verify Chinese punctuation is preserved
- assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
-
-
-@pytest.mark.asyncio
-async def test_smart_split_with_pause():
- """Test smart splitting with pause tags."""
- text = "Hello world [pause:2.5s] How are you?"
-
- chunks = []
- async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
- chunks.append((chunk_text, chunk_tokens, pause_duration))
-
- # Should have 3 chunks: text, pause, text
- assert len(chunks) == 3
-
- # First chunk: text
- assert chunks[0][2] is None # No pause
- assert "Hello world" in chunks[0][0]
- assert len(chunks[0][1]) > 0
-
- # Second chunk: pause
- assert chunks[1][2] == 2.5 # 2.5 second pause
- assert chunks[1][0] == "" # Empty text
- assert len(chunks[1][1]) == 0 # No tokens
-
- # Third chunk: text
- assert chunks[2][2] is None # No pause
- assert "How are you?" in chunks[2][0]
- assert len(chunks[2][1]) > 0
-
-@pytest.mark.asyncio
-async def test_smart_split_with_two_pause():
- """Test smart splitting with two pause tags."""
- text = "[pause:0.5s][pause:1.67s]0.5"
-
- chunks = []
- async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
- chunks.append((chunk_text, chunk_tokens, pause_duration))
-
- # Should have 3 chunks: pause, pause, text
- assert len(chunks) == 3
-
- # First chunk: pause
- assert chunks[0][2] == 0.5 # 0.5 second pause
- assert chunks[0][0] == "" # Empty text
- assert len(chunks[0][1]) == 0
-
- # Second chunk: pause
- assert chunks[1][2] == 1.67 # 1.67 second pause
- assert chunks[1][0] == "" # Empty text
- assert len(chunks[1][1]) == 0 # No tokens
-
- # Third chunk: text
- assert chunks[2][2] is None # No pause
- assert "zero point five" in chunks[2][0]
- assert len(chunks[2][1]) > 0
\ No newline at end of file
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
index 968c31f0..0d458240 100644
--- a/api/tests/test_tts_service.py
+++ b/api/tests/test_tts_service.py
@@ -3,7 +3,6 @@
import numpy as np
import pytest
import torch
-import os
from api.src.services.tts_service import TTSService
@@ -75,7 +74,7 @@ async def test_get_voice_path_single():
mock_get_voice.return_value = voice_manager
service = await TTSService.create("test_output")
- name, path = await service._get_voices_path("voice1")
+ name, path = await service._get_voice_path("voice1")
assert name == "voice1"
assert path == "/path/to/voice1.pt"
voice_manager.get_voice_path.assert_called_once_with("voice1")
@@ -101,12 +100,9 @@ async def test_get_voice_path_combined():
mock_load.return_value = torch.ones(10)
service = await TTSService.create("test_output")
- name, path = await service._get_voices_path("voice1+voice2")
+ name, path = await service._get_voice_path("voice1+voice2")
assert name == "voice1+voice2"
- # Verify the path points to a temporary file with expected format
- assert path.startswith("/tmp/")
- assert "voice1+voice2" in path
- assert path.endswith(".pt")
+ assert path.endswith("voice1+voice2.pt")
mock_save.assert_called_once()
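The updated assertion only checks that the combined path ends with `voice1+voice2.pt`. For context, a hedged sketch of the tensor averaging that `+`-combined voices imply here; the actual weighting in `VoiceManager.combine_voices` may differ:

```python
import torch


def combine_voice_tensors(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Average voice embedding tensors into one combined voice."""
    return torch.mean(torch.stack(tensors), dim=0)


combined = combine_voice_tensors([torch.ones(10), torch.zeros(10)])
assert torch.allclose(combined, torch.full((10,), 0.5))
```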
diff --git a/charts/kokoro-fastapi/Chart.yaml b/charts/kokoro-fastapi/Chart.yaml
index ed6f6753..bd0cf5da 100644
--- a/charts/kokoro-fastapi/Chart.yaml
+++ b/charts/kokoro-fastapi/Chart.yaml
@@ -1,12 +1,24 @@
apiVersion: v2
name: kokoro-fastapi
-description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
+description: A Helm chart for kokoro-fastapi
+
+# A chart can be either an 'application' or a 'library' chart.
+#
+# Application charts are a collection of templates that can be packaged into versioned archives
+# to be deployed.
+#
+# Library charts provide useful utilities or functions for the chart developer. They're included as
+# a dependency of application charts to inject those utilities and functions into the rendering
+# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
-version: 0.3.0
-appVersion: "0.3.0"
-keywords:
- - tts
- - fastapi
- - gpu
- - kokoro
+# This is the chart version. This version number should be incremented each time you make changes
+# to the chart and its templates, including the app version.
+# Versions are expected to follow Semantic Versioning (https://semver.org/)
+version: 0.1.0
+
+# This is the version number of the application being deployed. This version number should be
+# incremented each time you make changes to the application. Versions are not expected to
+# follow Semantic Versioning. They should reflect the version the application is using.
+# It is recommended to use it with quotes.
+appVersion: "1.16.0"
diff --git a/charts/kokoro-fastapi/examples/aks-tls-values.yaml b/charts/kokoro-fastapi/examples/aks-tls-values.yaml
deleted file mode 100644
index 236af0a8..00000000
--- a/charts/kokoro-fastapi/examples/aks-tls-values.yaml
+++ /dev/null
@@ -1,54 +0,0 @@
-# Tested on
-# - Azure AKS with GPU node pool with Nvidia GPU operator
-# - This setup uses 1 ingress and load balances between 2 replicas, enabling simultaneous requests
-#
-# Azure CLI command to create a GPU node pool:
-# az aks nodepool add \
-# --resource-group $AZ_RESOURCE_GROUP \
-# --cluster-name $CLUSTER_NAME \
-# --name t4gpus \
-# --node-vm-size Standard_NC4as_T4_v3 \
-# --node-count 2 \
-# --enable-cluster-autoscaler \
-# --min-count 1 \
-# --max-count 2 \
-# --priority Spot \
-# --eviction-policy Delete \
-# --spot-max-price -1 \
-# --node-taints "sku=gpu:NoSchedule,kubernetes.azure.com/scalesetpriority=spot:NoSchedule" \
-# --skip-gpu-driver-install
-
-kokoroTTS:
- replicaCount: 8
- port: 8880
- tag: v0.2.0
- pullPolicy: IfNotPresent
-
-# Azure specific settings for spot t4 GPU nodes with Nvidia GPU operator
-tolerations:
- - key: "kubernetes.azure.com/scalesetpriority"
- operator: Equal
- value: "spot"
- effect: NoSchedule
- - key: "sku"
- operator: Equal
- value: "gpu"
- effect: NoSchedule
-
-ingress:
- enabled: true
- className: "nginx"
- annotations:
- # Requires cert-manager and external-dns to be in the cluster for TLS and DNS
- cert-manager.io/cluster-issuer: letsencrypt-prod
- external-dns.alpha.kubernetes.io/hostname: your-external-dns-enabled-hostname
- external-dns.alpha.kubernetes.io/cloudflare-proxied: "false"
- hosts:
- - host: your-external-dns-enabled-hostname
- paths:
- - path: /
- pathType: Prefix
- tls:
- - secretName: kokoro-fastapi-tls
- hosts:
- - your-external-dns-enabled-hostname
\ No newline at end of file
diff --git a/charts/kokoro-fastapi/examples/gpu-operator-values.yaml b/charts/kokoro-fastapi/examples/gpu-operator-values.yaml
deleted file mode 100644
index b74667f9..00000000
--- a/charts/kokoro-fastapi/examples/gpu-operator-values.yaml
+++ /dev/null
@@ -1,56 +0,0 @@
-# Follow the official NVIDIA GPU Operator documentation
-# to install the GPU operator with these settings:
-# https://docs.nvidia.com/datacenter/cloud-native/gpu-operator/latest/getting-started.html
-#
-# This example is for a Nvidia T4 16gb GPU node pool with only 1 GPU on each node on Azure AKS.
-# It uses time-slicing to share the a and claim to the system that 1 GPU is 4 GPUs.
-# So each pod has access to a smaller gpu with 4gb of memory.
-#
-devicePlugin: # Remove this if you dont want to use time-slicing
- config:
- create: true
- name: "time-slicing-config"
- default: "any"
- data:
- any: |-
- version: v1
- flags:
- migStrategy: none
- sharing:
- timeSlicing:
- resources:
- - name: nvidia.com/gpu
- replicas: 4
-
-daemonsets:
- tolerations:
- - key: "sku"
- operator: Equal
- value: "gpu"
- effect: NoSchedule
- - key: "kubernetes.azure.com/scalesetpriority"
- operator: Equal
- value: "spot"
- effect: NoSchedule
-
-node-feature-discovery:
- master:
- tolerations:
- - key: "sku"
- operator: Equal
- value: "gpu"
- effect: NoSchedule
- - key: "kubernetes.azure.com/scalesetpriority"
- operator: Equal
- value: "spot"
- effect: NoSchedule
- worker:
- tolerations:
- - key: "sku"
- operator: Equal
- value: "gpu"
- effect: NoSchedule
- - key: "kubernetes.azure.com/scalesetpriority"
- operator: Equal
- value: "spot"
- effect: NoSchedule
\ No newline at end of file
diff --git a/charts/kokoro-fastapi/templates/NOTES.txt b/charts/kokoro-fastapi/templates/NOTES.txt
index bc009b80..88b89806 100644
--- a/charts/kokoro-fastapi/templates/NOTES.txt
+++ b/charts/kokoro-fastapi/templates/NOTES.txt
@@ -13,10 +13,10 @@
NOTE: It may take a few minutes for the LoadBalancer IP to be available.
You can watch the status of by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "kokoro-fastapi.fullname" . }}'
export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "kokoro-fastapi.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}")
- echo http://$SERVICE_IP:{{ .Values.kokoroTTS.port }}
+ echo http://$SERVICE_IP:{{ .Values.service.port }}
{{- else if contains "ClusterIP" .Values.service.type }}
export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "kokoro-fastapi.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}")
export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}")
- echo "Visit http://127.0.0.1:8880 to use your application"
- kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8880:$CONTAINER_PORT
+ echo "Visit http://127.0.0.1:8080 to use your application"
+ kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT
{{- end }}
diff --git a/charts/kokoro-fastapi/templates/ingress.yaml b/charts/kokoro-fastapi/templates/ingress.yaml
index a9c9f4e8..09a8fb5a 100644
--- a/charts/kokoro-fastapi/templates/ingress.yaml
+++ b/charts/kokoro-fastapi/templates/ingress.yaml
@@ -1,43 +1,82 @@
{{- if .Values.ingress.enabled -}}
+{{- $fullName := include "kokoro-fastapi.fullname" . -}}
+{{- $svcPort := .Values.service.port -}}
+{{- $rewriteTargets := (list) -}}
+{{- with .Values.ingress.host }}
+ {{- range .endpoints }}
+ {{- $serviceName := default $fullName .serviceName -}}
+ {{- $rewrite := .rewrite | default "none" -}}
+ {{- if not (has $rewrite $rewriteTargets) -}}
+ {{- $rewriteTargets = append $rewriteTargets $rewrite -}}
+ {{- end -}}
+ {{- end}}
+{{- end }}
+{{- range $key := $rewriteTargets }}
+{{- $expandedRewrite := regexReplaceAll "/(.*)$" $key "slash${1}" -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
- name: {{ include "kokoro-fastapi.fullname" . }}
+{{- if eq $key "none" }}
+ name: {{ $fullName }}
+{{- else }}
+ name: {{ $fullName }}-{{ $expandedRewrite }}
+{{- end }}
labels:
- {{- include "kokoro-fastapi.labels" . | nindent 4 }}
- {{- with .Values.ingress.annotations }}
+ {{- include "kokoro-fastapi.labels" $ | nindent 4 }}
+ {{- if ne $key "none" }}
annotations:
- {{- toYaml . | nindent 4 }}
- {{- end }}
+ nginx.ingress.kubernetes.io/rewrite-target: {{ regexReplaceAll "/$" $key "" }}/$2
+ {{- end }}
spec:
- {{- with .Values.ingress.className }}
- ingressClassName: {{ . }}
- {{- end }}
- {{- if .Values.ingress.tls }}
+{{- if $.Values.ingress.tls }}
tls:
- {{- range .Values.ingress.tls }}
+ {{- range $.Values.ingress.tls }}
- hosts:
- {{- range .hosts }}
+ {{- range .hosts }}
- {{ . | quote }}
- {{- end }}
+ {{- end }}
secretName: {{ .secretName }}
- {{- end }}
{{- end }}
+{{- end }}
rules:
- {{- range .Values.ingress.hosts }}
- - host: {{ .host | quote }}
+ {{- with $.Values.ingress.host }}
+ - host: {{ .name | quote }}
http:
paths:
- {{- range .paths }}
- - path: {{ .path }}
- {{- with .pathType }}
- pathType: {{ . }}
+ {{- range .endpoints }}
+ {{- $serviceName := default $fullName .serviceName -}}
+ {{- $servicePort := default (print "http") .servicePort -}}
+ {{- if eq ( .rewrite | default "none" ) $key }}
+ {{- range .paths }}
+ {{- if not (contains "@" .) }}
+ {{- if eq $key "none" }}
+ - path: {{ . }}
+ {{- else }}
+ - path: {{ regexReplaceAll "(.*)/$" . "${1}" }}(/|$)(.*)
+ {{- end }}
+ pathType: Prefix
+ backend:
+ service:
+ name: "{{ $fullName }}-{{ $serviceName }}"
+ port:
+ number: {{ $servicePort }}
+ {{- else }}
+ {{- $path := . -}}
+ {{- $replicaCount := include "getServiceNameReplicaCount" (dict "global" $.Values "serviceName" $serviceName) -}}
+ {{- range $count, $e := until ($replicaCount|int) }}
+ - path: {{ $path | replace "@" ( . | toString ) }}(/|$)(.*)
+ pathType: Prefix
+ backend:
+ service:
+ name: "{{ $fullName }}-{{ $serviceName }}-{{ . }}"
+ port:
+ number: {{ $servicePort }}
+ {{- end }}
+ {{- end }}
{{- end }}
- backend:
- service:
- name: {{ include "kokoro-fastapi.fullname" $ }}-kokoro-tts-service
- port:
- number: {{ $.Values.kokoroTTS.port }}
{{- end }}
- {{- end }}
+ {{- end }}
+ {{- end }}
+---
+{{- end }}
{{- end }}
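The template's two Sprig rewrites are easy to misread. A hedged illustration of what they compute, reproduced with Python's `re` module for clarity; the `/api/` key is an example value, not taken from the chart:

```python
import re

key = "/api/"

# regexReplaceAll "/(.*)$" $key "slash${1}"  -> name-safe suffix
expanded = re.sub(r"/(.*)$", r"slash\1", key)  # "slashapi/"

# regexReplaceAll "/$" $key "" plus the appended "/$2"  -> rewrite target
target = re.sub(r"/$", "", key) + "/$2"        # "/api/$2"

print(expanded, target)
```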
diff --git a/charts/kokoro-fastapi/templates/kokoro-tts-deployment.yaml b/charts/kokoro-fastapi/templates/kokoro-tts-deployment.yaml
index 2178a08a..be1f67b7 100644
--- a/charts/kokoro-fastapi/templates/kokoro-tts-deployment.yaml
+++ b/charts/kokoro-fastapi/templates/kokoro-tts-deployment.yaml
@@ -20,7 +20,7 @@ spec:
labels:
{{- include "kokoro-fastapi.selectorLabels" . | nindent 8 }}
spec:
- {{- with .Values.kokoroTTS.imagePullSecrets }}
+ {{- with .Values.images.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
@@ -49,16 +49,10 @@ spec:
httpGet:
path: /health
port: kokoro-tts-http
- initialDelaySeconds: 30
- periodSeconds: 30
- timeoutSeconds: 5
readinessProbe:
httpGet:
path: /health
port: kokoro-tts-http
- initialDelaySeconds: 30
- periodSeconds: 30
- timeoutSeconds: 5
resources:
{{- toYaml .Values.kokoroTTS.resources | nindent 12 }}
volumeMounts: []
diff --git a/charts/kokoro-fastapi/templates/tests/test-connection.yaml b/charts/kokoro-fastapi/templates/tests/test-connection.yaml
index 8b912c6a..120583fe 100644
--- a/charts/kokoro-fastapi/templates/tests/test-connection.yaml
+++ b/charts/kokoro-fastapi/templates/tests/test-connection.yaml
@@ -11,5 +11,5 @@ spec:
- name: wget
image: busybox
command: ['wget']
- args: ['{{ include "kokoro-fastapi.fullname" . }}:{{ .Values.kokoroTTS.port }}']
+ args: ['{{ include "kokoro-fastapi.fullname" . }}:{{ .Values.service.port }}']
restartPolicy: Never
diff --git a/charts/kokoro-fastapi/values.yaml b/charts/kokoro-fastapi/values.yaml
index e2e37e44..0db2f95d 100644
--- a/charts/kokoro-fastapi/values.yaml
+++ b/charts/kokoro-fastapi/values.yaml
@@ -1,19 +1,12 @@
# Default values for kokoro-fastapi.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
-kokoroTTS:
- replicaCount: 1
- # The name of the deployment repository
- repository: "ghcr.io/remsky/kokoro-fastapi-gpu"
- imagePullSecrets: [] # Set if using a private image or getting rate limited
- tag: "latest"
- pullPolicy: Always
- port: 8880
- resources:
- limits:
- nvidia.com/gpu: 1
- requests:
- nvidia.com/gpu: 1
+
+replicaCount: 1
+
+images:
+ pullPolicy: "Always"
+ imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
@@ -45,21 +38,47 @@ service:
ingress:
enabled: false
- className: "nginx"
+ className: ""
annotations: {}
- # cert-manager.io/cluster-issuer: letsencrypt-prod
- # external-dns.alpha.kubernetes.io/hostname: kokoro.example.com
- # external-dns.alpha.kubernetes.io/cloudflare-proxied: "false"
- hosts:
- - host: kokoro.example.com
- paths:
- - path: /
- pathType: Prefix
+ # kubernetes.io/ingress.class: nginx
+ # kubernetes.io/tls-acme: "true"
+ host:
+ name: kokoro.example.com
+ endpoints:
+ - paths:
+ - "/"
+ serviceName: "fastapi"
+ servicePort: 8880
tls: []
- # - secretName: kokoro-fastapi-tls
+ # - secretName: chart-example-tls
# hosts:
- # - kokoro.example.com
+ # - chart-example.local
+
+kokoroTTS:
+ repository: "ghcr.io/remsky/kokoro-fastapi-gpu"
+ tag: "latest"
+ pullPolicy: Always
+ serviceName: "fastapi"
+ port: 8880
+ replicaCount: 1
+ resources:
+ limits:
+ nvidia.com/gpu: 1
+ requests:
+ nvidia.com/gpu: 1
+
+  # A GPU is requested by default. To run on CPU-only clusters, replace the
+  # values above with CPU/memory resources, for example:
+  # limits:
+  #   cpu: 100m
+  #   memory: 128Mi
+  # requests:
+  #   cpu: 100m
+  #   memory: 128Mi
autoscaling:
enabled: false
diff --git a/dev/Test Phon.py b/dev/Test Phon.py
deleted file mode 100644
index d3ba7839..00000000
--- a/dev/Test Phon.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import base64
-import json
-
-import pydub
-import requests
-
-def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
- """Generate audio from phonemes"""
- response = requests.post(
- "http://localhost:8880/dev/generate_from_phonemes",
- json={"phonemes": phonemes, "voice": voice},
- headers={"Accept": "audio/wav"}
- )
- if response.status_code != 200:
- print(f"Error: {response.text}")
- return None
- return response.content
-
-
-
-
-with open(f"outputnostreammoney.wav", "wb") as f:
- f.write(generate_audio_from_phonemes(r"mɪsəki ɪz ɐn ɪkspˌɛɹəmˈɛntᵊl ʤˈitəpˈi ˈɛnʤən dəzˈInd tə pˈWəɹ fjˈuʧəɹ vˈɜɹʒənz ʌv kəkˈɔɹO mˈɑdᵊlz."))
\ No newline at end of file
diff --git a/dev/Test Threads.py b/dev/Test Threads.py
deleted file mode 100644
index bb7caee1..00000000
--- a/dev/Test Threads.py
+++ /dev/null
@@ -1,378 +0,0 @@
-#!/usr/bin/env python3
-# Compatible with both Windows and Linux
-"""
-Kokoro TTS Race Condition Test
-
-This script creates multiple concurrent requests to a Kokoro TTS service
-to reproduce a race condition where audio outputs don't match the requested text.
-Each thread generates a simple numbered sentence, which should make mismatches
-easy to identify through listening.
-
-To run:
-python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
-"""
-
-import argparse
-import base64
-import concurrent.futures
-import json
-import os
-import sys
-import time
-import wave
-from pathlib import Path
-
-import requests
-
-
-def setup_args():
- """Parse command line arguments"""
- parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
- parser.add_argument(
- "--url",
- default="http://localhost:8880",
- help="Base URL of the Kokoro TTS service",
- )
- parser.add_argument(
- "--threads", type=int, default=8, help="Number of concurrent threads to use"
- )
- parser.add_argument(
- "--iterations", type=int, default=5, help="Number of iterations per thread"
- )
- parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
- parser.add_argument(
- "--output-dir",
- default="./tts_test_output",
- help="Directory to save output files",
- )
- parser.add_argument("--debug", action="store_true", help="Enable debug logging")
- return parser.parse_args()
-
-
-def generate_test_sentence(thread_id, iteration):
- """Generate a simple test sentence with numbers to make mismatches easily identifiable"""
- return (
- f"This is test sentence number {thread_id}-{iteration}. "
- f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
- )
-
-
-def log_message(message, debug=False, is_error=False):
- """Log messages with timestamps"""
- timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
- prefix = "[ERROR]" if is_error else "[INFO]"
- if is_error or debug:
- print(f"{prefix} {timestamp} - {message}")
- sys.stdout.flush() # Ensure logs are visible in Docker output
-
-
-def request_tts(url, test_id, text, voice, output_dir, debug=False):
- """Request TTS from the Kokoro API and save the WAV output"""
- start_time = time.time()
- output_file = os.path.join(output_dir, f"test_{test_id}.wav")
- text_file = os.path.join(output_dir, f"test_{test_id}.txt")
-
- # Log output paths for debugging
- log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
- log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)
-
- # Save the text for later comparison
- try:
- with open(text_file, "w") as f:
- f.write(text)
- log_message(f"Thread {test_id}: Successfully saved text file", debug)
- except Exception as e:
- log_message(
- f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True
- )
-
- # Make the TTS request
- try:
- log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
-
- response = requests.post(
- f"{url}/v1/audio/speech",
- json={
- "model": "kokoro",
- "input": text,
- "voice": voice,
- "response_format": "wav",
- },
- headers={"Accept": "audio/wav"},
- timeout=60, # Increase timeout to 60 seconds
- )
-
- log_message(
- f"Thread {test_id}: Response status code: {response.status_code}", debug
- )
- log_message(
- f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}",
- debug,
- )
- log_message(
- f"Thread {test_id}: Response content length: {len(response.content)} bytes",
- debug,
- )
-
- if response.status_code != 200:
- log_message(
- f"Thread {test_id}: API error: {response.status_code} - {response.text}",
- debug,
- is_error=True,
- )
- return False
-
- # Check if we got valid audio data
- if (
- len(response.content) < 100
- ): # Sanity check - WAV files should be larger than this
- log_message(
- f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes",
- debug,
- is_error=True,
- )
- log_message(
- f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}",
- debug,
- is_error=True,
- )
- return False
-
- # Save the audio output with explicit error handling
- try:
- with open(output_file, "wb") as f:
- bytes_written = f.write(response.content)
- log_message(
- f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}",
- debug,
- )
-
- # Verify the WAV file exists and has content
- if os.path.exists(output_file):
- file_size = os.path.getsize(output_file)
- log_message(
- f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
- debug,
- )
-
- # Validate WAV file by reading its headers
- try:
- with wave.open(output_file, "rb") as wav_file:
- channels = wav_file.getnchannels()
- sample_width = wav_file.getsampwidth()
- framerate = wav_file.getframerate()
- frames = wav_file.getnframes()
- log_message(
- f"Thread {test_id}: Valid WAV file - channels: {channels}, "
- f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
- debug,
- )
- except Exception as wav_error:
- log_message(
- f"Thread {test_id}: Invalid WAV file: {str(wav_error)}",
- debug,
- is_error=True,
- )
- else:
- log_message(
- f"Thread {test_id}: File was not created: {output_file}",
- debug,
- is_error=True,
- )
- except Exception as save_error:
- log_message(
- f"Thread {test_id}: Error saving audio file: {str(save_error)}",
- debug,
- is_error=True,
- )
- return False
-
- end_time = time.time()
- log_message(
- f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)",
- debug,
- )
- return True
-
- except requests.exceptions.Timeout:
- log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
- return False
- except Exception as e:
- log_message(f"Thread {test_id}: Exception: {str(e)}", debug, is_error=True)
- return False
-
-
-def worker_task(thread_id, args):
- """Worker task for each thread"""
- for i in range(args.iterations):
- iteration = i + 1
- test_id = f"{thread_id:02d}_{iteration:02d}"
- text = generate_test_sentence(thread_id, iteration)
- success = request_tts(
- args.url, test_id, text, args.voice, args.output_dir, args.debug
- )
-
- if not success:
- log_message(
- f"Thread {thread_id}: Iteration {iteration} failed",
- args.debug,
- is_error=True,
- )
-
- # Small delay between iterations to avoid overwhelming the API
- time.sleep(0.1)
-
-
-def run_test(args):
- """Run the test with the specified parameters"""
- # Ensure output directory exists and check permissions
- os.makedirs(args.output_dir, exist_ok=True)
-
- # Test write access to the output directory
- test_file = os.path.join(args.output_dir, "write_test.txt")
- try:
- with open(test_file, "w") as f:
- f.write("Testing write access\n")
- os.remove(test_file)
- log_message(
- f"Successfully verified write access to output directory: {args.output_dir}"
- )
- except Exception as e:
- log_message(
- f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}",
- is_error=True,
- )
- log_message(f"Current directory: {os.getcwd()}", is_error=True)
- log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
-
- # Test connection to Kokoro TTS service
- try:
- response = requests.get(f"{args.url}/health", timeout=5)
- if response.status_code == 200:
- log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
- else:
- log_message(
- f"Warning: Kokoro TTS service health check returned status {response.status_code}",
- is_error=True,
- )
- except Exception as e:
- log_message(
- f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}",
- is_error=True,
- )
-
- # Record start time
- start_time = time.time()
- log_message(
- f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
- )
-
- # Create and start worker threads
- with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
- futures = []
- for thread_id in range(1, args.threads + 1):
- futures.append(executor.submit(worker_task, thread_id, args))
-
- # Wait for all tasks to complete
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- log_message(
- f"Thread execution failed: {str(e)}", args.debug, is_error=True
- )
-
- # Record end time and print summary
- end_time = time.time()
- total_time = end_time - start_time
- total_requests = args.threads * args.iterations
- log_message(f"Test completed in {total_time:.2f} seconds")
- log_message(f"Total requests: {total_requests}")
- log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
- log_message(f"Requests per second: {total_requests / total_time:.2f}")
- log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
- log_message(
- "To verify, listen to the audio files and check if they match the text files"
- )
- log_message(
- "If you hear audio describing a different test number than the filename, you've found a race condition"
- )
-
-
-def analyze_audio_files(output_dir):
- """Provide summary of the generated audio files"""
- # Look for both WAV and TXT files
- wav_files = list(Path(output_dir).glob("*.wav"))
- txt_files = list(Path(output_dir).glob("*.txt"))
-
- log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")
-
- if len(wav_files) == 0:
- log_message(
- "No WAV files found! This indicates the TTS service requests may be failing.",
- is_error=True,
- )
- log_message(
- "Check the connection to the TTS service and the response status codes above.",
- is_error=True,
- )
-
- file_stats = []
- for wav_path in wav_files:
- try:
- with wave.open(str(wav_path), "rb") as wav_file:
- frames = wav_file.getnframes()
- rate = wav_file.getframerate()
- duration = frames / rate
-
- # Get corresponding text
- text_path = wav_path.with_suffix(".txt")
- if text_path.exists():
- with open(text_path, "r") as text_file:
- text = text_file.read().strip()
- else:
- text = "N/A"
-
- file_stats.append(
- {"filename": wav_path.name, "duration": duration, "text": text}
- )
- except Exception as e:
- log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)
-
- # Print summary table
- if file_stats:
- log_message("\nAudio File Summary:")
- log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
- log_message("-" * 92)
- for stat in file_stats:
- log_message(
- f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57] + '...' if len(stat['text']) > 60 else stat['text']:<60}"
- )
-
- # List missing WAV files where text files exist
- missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
- if missing_wavs:
- log_message(
- f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
- is_error=True,
- )
- for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability
- log_message(f" - {stem}.txt (no WAV file)", is_error=True)
- if len(missing_wavs) > 10:
- log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)
-
-
-if __name__ == "__main__":
- args = setup_args()
- run_test(args)
- analyze_audio_files(args.output_dir)
-
- log_message("\nNext Steps:")
- log_message("1. Listen to the generated audio files")
- log_message("2. Verify if each audio correctly says its ID number")
- log_message(
- "3. Check for any mismatches between the audio content and the text files"
- )
- log_message(
- "4. If mismatches are found, you've successfully reproduced the race condition"
- )
diff --git a/dev/Test copy 2.py b/dev/Test copy 2.py
deleted file mode 100644
index 52634ec2..00000000
--- a/dev/Test copy 2.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import base64
-import json
-
-import pydub
-import requests
-
-text = """Running on localhost:7860"""
-
-
-Type = "wav"
-response = requests.post(
- "http://localhost:8880/dev/captioned_speech",
- json={
- "model": "kokoro",
- "input": text,
- "voice": "af_heart+af_sky",
- "speed": 1.0,
- "response_format": Type,
- "stream": True,
- },
- stream=True,
-)
-
-f = open(f"outputstream.{Type}", "wb")
-for chunk in response.iter_lines(decode_unicode=True):
- if chunk:
- temp_json = json.loads(chunk)
- if temp_json["timestamps"] != []:
- chunk_json = temp_json
-
- # Decode base 64 stream to bytes
- chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8"))
-
- # Process streaming chunks
- f.write(chunk_audio)
-
- # Print word level timestamps
- print(chunk_json["timestamps"])
diff --git a/dev/Test money.py b/dev/Test money.py
deleted file mode 100644
index 57d1fa64..00000000
--- a/dev/Test money.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import base64
-import json
-
-import requests
-
-text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""
-
-
-Type = "wav"
-
-response = requests.post(
- "http://localhost:8880/v1/audio/speech",
- json={
- "model": "kokoro",
- "input": text,
- "voice": "zf_xiaobei",
- "speed": 1.0,
- "response_format": Type,
- "stream": False,
- },
- stream=True,
-)
-
-with open(f"outputnostreammoney.{Type}", "wb") as f:
- f.write(response.content)
diff --git a/dev/Test num.py b/dev/Test num.py
deleted file mode 100644
index f7485aee..00000000
--- a/dev/Test num.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import re
-
-import inflect
-from text_to_num import text2num
-from torch import mul
-
-INFLECT_ENGINE = inflect.engine()
-
-
-def conditional_int(number: float, threshold: float = 0.00001):
- if abs(round(number) - number) < threshold:
- return int(round(number))
- return number
-
-
-def handle_money(m: re.Match[str]) -> str:
- """Convert money expressions to spoken form"""
-
- bill = "dollar" if m.group(2) == "$" else "pound"
- coin = "cent" if m.group(2) == "$" else "pence"
- number = m.group(3)
-
- multiplier = m.group(4)
- try:
- number = float(number)
- except:
- return m.group()
-
- if m.group(1) == "-":
- number *= -1
-
- if number % 1 == 0 or multiplier != "":
- text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
- else:
- sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
-
- text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
-
- return text_number
-
-
-text = re.sub(
- r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
- handle_money,
- "he administration has offered up a platter of repression for more than a year and is still slated to lose -$5.3 billion",
-)
-print(text)
diff --git a/docker-bake.hcl b/docker-bake.hcl
index e29599a0..76eb4f04 100644
--- a/docker-bake.hcl
+++ b/docker-bake.hcl
@@ -1,14 +1,14 @@
# Variables for reuse
variable "VERSION" {
- default = "latest"
+ default = "20250826"
}
variable "REGISTRY" {
- default = "ghcr.io"
+ default = "everymatrix.jfrog.io"
}
variable "OWNER" {
- default = "remsky"
+ default = "emlab-docker"
}
variable "REPO" {
@@ -43,7 +43,7 @@ target "_gpu_base" {
# CPU target with multi-platform support
target "cpu" {
inherits = ["_cpu_base"]
- platforms = ["linux/amd64", "linux/arm64"]
+ platforms = ["linux/amd64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
@@ -53,10 +53,9 @@ target "cpu" {
# GPU target with multi-platform support
target "gpu" {
inherits = ["_gpu_base"]
- platforms = ["linux/amd64", "linux/arm64"]
+ platforms = ["linux/amd64"]
tags = [
- "${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
- "${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
+    "${REGISTRY}/${OWNER}/ayida/kokoro:${VERSION}"
]
}
diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile
index 15495429..d770a6c8 100644
--- a/docker/cpu/Dockerfile
+++ b/docker/cpu/Dockerfile
@@ -1,17 +1,26 @@
FROM python:3.10-slim
# Install dependencies and check espeak location
-# Rust is required to build sudachipy and pyopenjtalk-plus
-RUN apt-get update -y && \
- apt-get install -y espeak-ng espeak-ng-data git libsndfile1 curl ffmpeg g++ && \
- apt-get clean && rm -rf /var/lib/apt/lists/* && \
- mkdir -p /usr/share/espeak-ng-data && \
- ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/ && \
- curl -LsSf https://astral.sh/uv/install.sh | sh && \
+RUN apt-get update && apt-get install -y \
+ espeak-ng \
+ espeak-ng-data \
+ git \
+ libsndfile1 \
+ curl \
+ ffmpeg \
+ g++ \
+&& apt-get clean \
+&& rm -rf /var/lib/apt/lists/* \
+&& mkdir -p /usr/share/espeak-ng-data \
+&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
+
+# Install UV using the installer script
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv /usr/local/bin/ && \
- mv /root/.local/bin/uvx /usr/local/bin/ && \
- curl https://sh.rustup.rs -sSf | sh -s -- -y && \
- useradd -m -u 1000 appuser && \
+ mv /root/.local/bin/uvx /usr/local/bin/
+
+# Create non-root user and set up directories and permissions
+RUN useradd -m -u 1000 appuser && \
mkdir -p /app/api/src/models/v1_0 && \
chown -R appuser:appuser /app
@@ -21,9 +30,10 @@ WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
-# Install dependencies with CPU extras
-RUN uv venv --python 3.10 && \
- uv sync --extra cpu --no-cache
+# Install dependencies
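+# The BuildKit cache mount persists uv's download cache across builds without
+# baking it into the image layers.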
+RUN --mount=type=cache,target=/root/.cache/uv \
+ uv venv --python 3.10 && \
+ uv sync --extra cpu
# Copy project files including models
COPY --chown=appuser:appuser api ./api
@@ -32,15 +42,14 @@ COPY --chown=appuser:appuser docker/scripts/ ./
RUN chmod +x ./entrypoint.sh
# Set environment variables
-ENV PATH="/home/appuser/.cargo/bin:/app/.venv/bin:$PATH" \
- PYTHONUNBUFFERED=1 \
+ENV PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
+ PATH="/app/.venv/bin:$PATH" \
UV_LINK_MODE=copy \
USE_GPU=false \
PHONEMIZER_ESPEAK_PATH=/usr/bin \
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
- ESPEAK_DATA_PATH=/usr/share/espeak-ng-data \
- DEVICE="cpu"
+ ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
ENV DOWNLOAD_MODEL=true
# Download model if enabled
@@ -48,5 +57,6 @@ RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
python download_model.py --output api/src/models/v1_0; \
fi
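+# Set after the model download so changing DEVICE does not invalidate that cached layer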
+ENV DEVICE="cpu"
# Run FastAPI server through entrypoint.sh
CMD ["./entrypoint.sh"]
diff --git a/docker/cpu/docker-compose.yml b/docker/cpu/docker-compose.yml
index 8ca8821b..ed15540b 100644
--- a/docker/cpu/docker-compose.yml
+++ b/docker/cpu/docker-compose.yml
@@ -20,7 +20,7 @@ services:
# # Gradio UI service [Comment out everything below if you don't need it]
# gradio-ui:
- # image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
+ # image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
# # Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../../ui
diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile
index 572540cb..6e19679b 100644
--- a/docker/gpu/Dockerfile
+++ b/docker/gpu/Dockerfile
@@ -1,15 +1,29 @@
-FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.1-base-ubuntu24.04
+FROM --platform=$BUILDPLATFORM everymatrix.jfrog.io/emlab-docker-remote-hub/nvidia/cuda:12.9.1-base-ubuntu24.04
+# Set non-interactive frontend
+ENV DEBIAN_FRONTEND=noninteractive
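+# (ENV persists into the final image; switch to ARG if build-time-only is preferred)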
# Install Python and other dependencies
-RUN apt-get update -y && \
- apt-get install -y python3.10 python3-venv espeak-ng espeak-ng-data git libsndfile1 curl ffmpeg g++ && \
- apt-get clean && rm -rf /var/lib/apt/lists/* && \
- mkdir -p /usr/share/espeak-ng-data && \
- ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/ && \
- curl -LsSf https://astral.sh/uv/install.sh | sh && \
+RUN apt-get update && apt-get install -y \
+ python3.10 \
+ python3-venv \
+ espeak-ng \
+ espeak-ng-data \
+ git \
+ libsndfile1 \
+ curl \
+ ffmpeg \
+ g++ \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* \
+ && mkdir -p /usr/share/espeak-ng-data \
+ && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
+
+# Install UV using the installer script
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv /usr/local/bin/ && \
- mv /root/.local/bin/uvx /usr/local/bin/ && \
- useradd -m -u 1001 appuser && \
+ mv /root/.local/bin/uvx /usr/local/bin/
+
+# Create non-root user and set up directories and permissions
+RUN useradd -m -u 1001 appuser && \
mkdir -p /app/api/src/models/v1_0 && \
chown -R appuser:appuser /app
@@ -19,9 +33,14 @@ WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
-# Install dependencies with GPU extras
-RUN uv venv --python 3.10 && \
- uv sync --extra gpu --no-cache
+ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
+ PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
+ ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
+
+# Install dependencies with GPU extras (using cache mounts)
+RUN --mount=type=cache,target=/root/.cache/uv \
+ uv venv --python 3.10 && \
+ uv sync --extra gpu
# Copy project files including models
COPY --chown=appuser:appuser api ./api
@@ -31,21 +50,26 @@ RUN chmod +x ./entrypoint.sh
# Set all environment variables in one go
-ENV PATH="/app/.venv/bin:$PATH" \
- PYTHONUNBUFFERED=1 \
+ENV PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
+ PATH="/app/.venv/bin:$PATH" \
UV_LINK_MODE=copy \
- USE_GPU=true \
- PHONEMIZER_ESPEAK_PATH=/usr/bin \
- PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
- ESPEAK_DATA_PATH=/usr/share/espeak-ng-data \
- DEVICE="gpu"
-
+ USE_GPU=true
+
ENV DOWNLOAD_MODEL=true
# Download model if enabled
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
+ export HTTPS_PROXY=http://10.0.10.4:3128; \
python download_model.py --output api/src/models/v1_0; \
fi
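+# NOTE: the HTTPS_PROXY above points at an internal egress proxy; adjust or
+# remove it when building outside that network.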
+ENV DEVICE="gpu"
# Run FastAPI server through entrypoint.sh
CMD ["./entrypoint.sh"]
+
+USER root
+
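+# Remove the CUDA compat stubs so they cannot shadow the driver libraries that
+# the NVIDIA container runtime mounts at run time (avoids "driver/library
+# version mismatch" errors).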
+RUN rm -f /usr/local/cuda-12.9/compat/libcuda.so* && \
+ rm -f /usr/local/cuda-12.9/compat/libnvidia-ml.so*
+
+USER appuser
\ No newline at end of file
diff --git a/docker/gpu/docker-compose.yml b/docker/gpu/docker-compose.yml
index 9faddd89..762aca69 100644
--- a/docker/gpu/docker-compose.yml
+++ b/docker/gpu/docker-compose.yml
@@ -1,13 +1,12 @@
name: kokoro-tts-gpu
services:
kokoro-tts:
- # image: ghcr.io/remsky/kokoro-fastapi-gpu:v${VERSION}
+ # image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.0
build:
context: ../..
dockerfile: docker/gpu/Dockerfile
volumes:
- ../../api:/app/api
- user: "1001:1001" # Ensure container runs as UID 1001 (appuser)
ports:
- "8880:8880"
environment:
@@ -24,7 +23,7 @@ services:
# # Gradio UI service
# gradio-ui:
- # image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
+ # image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
# # Uncomment below to build from source instead of using the released image
# # build:
# # context: ../../ui
diff --git a/docker/scripts/download_model.py b/docker/scripts/download_model.py
index 67a9409c..c406a17d 100644
--- a/docker/scripts/download_model.py
+++ b/docker/scripts/download_model.py
@@ -11,11 +11,11 @@
def verify_files(model_path: str, config_path: str) -> bool:
"""Verify that model files exist and are valid.
-
+
Args:
model_path: Path to model file
config_path: Path to config file
-
+
Returns:
True if files exist and are valid
"""
@@ -25,15 +25,15 @@ def verify_files(model_path: str, config_path: str) -> bool:
return False
if not os.path.exists(config_path):
return False
-
+
# Verify config file is valid JSON
with open(config_path) as f:
config = json.load(f)
-
+
# Check model file size (should be non-zero)
if os.path.getsize(model_path) == 0:
return False
-
+
return True
except Exception:
return False
@@ -41,45 +41,45 @@ def verify_files(model_path: str, config_path: str) -> bool:
def download_model(output_dir: str) -> None:
"""Download model files from GitHub release.
-
+
Args:
output_dir: Directory to save model files
"""
try:
# Create output directory
os.makedirs(output_dir, exist_ok=True)
-
+
# Define file paths
model_file = "kokoro-v1_0.pth"
config_file = "config.json"
model_path = os.path.join(output_dir, model_file)
config_path = os.path.join(output_dir, config_file)
-
+
# Check if files already exist and are valid
if verify_files(model_path, config_path):
logger.info("Model files already exist and are valid")
return
-
+
logger.info("Downloading Kokoro v1.0 model files")
-
+
# GitHub release URLs (to be updated with v0.2.0 release)
base_url = "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.4"
model_url = f"{base_url}/{model_file}"
config_url = f"{base_url}/{config_file}"
-
+
# Download files
logger.info("Downloading model file...")
urlretrieve(model_url, model_path)
-
+
logger.info("Downloading config file...")
urlretrieve(config_url, config_path)
-
+
# Verify downloaded files
if not verify_files(model_path, config_path):
raise RuntimeError("Failed to verify downloaded files")
-
+
logger.info(f"✓ Model files prepared in {output_dir}")
-
+
except Exception as e:
logger.error(f"Failed to download model: {e}")
raise
@@ -88,15 +88,17 @@ def download_model(output_dir: str) -> None:
def main():
"""Main entry point."""
import argparse
-
+
parser = argparse.ArgumentParser(description="Download Kokoro v1.0 model")
parser.add_argument(
- "--output", required=True, help="Output directory for model files"
+ "--output",
+ required=True,
+ help="Output directory for model files"
)
-
+
args = parser.parse_args()
download_model(args.output)
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/examples/stream_tts_playback.py b/examples/stream_tts_playback.py
index 4e5ac5ea..b4a34d92 100644
--- a/examples/stream_tts_playback.py
+++ b/examples/stream_tts_playback.py
@@ -123,7 +123,7 @@ def main():
with open(wells_path, "r", encoding="utf-8") as f:
full_text = f.read()
# Take first few paragraphs
- text = " ".join(full_text.split("\n\n")[1:3])
+ text = " ".join(full_text.split("\n\n")[:2])
print("\nStarting TTS stream playback...")
print(f"Text length: {len(text)} characters")
diff --git a/pyproject.toml b/pyproject.toml
index ffbefe54..7d77ec5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kokoro-fastapi"
-version = "0.3.0"
+version = "0.1.4"
description = "FastAPI TTS Service"
readme = "README.md"
requires-python = ">=3.10"
@@ -31,31 +31,29 @@ dependencies = [
"matplotlib>=3.10.0",
"mutagen>=1.47.0",
"psutil>=6.1.1",
- "espeakng-loader==0.2.4",
- "kokoro==0.9.2",
- "misaki[en,ja,ko,zh]==0.9.3",
- "spacy==3.8.5",
- "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
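+    # kokoro and misaki are pinned to exact upstream commits for reproducible builds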
+ "kokoro @ git+https://github.com/hexgrad/kokoro.git@31a2b6337b8c1b1418ef68c48142328f640da938",
+ 'misaki[en,ja,ko,zh] @ git+https://github.com/hexgrad/misaki.git@ebc76c21b66c5fc4866ed0ec234047177b396170',
+ "spacy==3.7.2",
+ "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
"inflect>=7.5.0",
"phonemizer-fork>=3.3.2",
- "av>=14.2.0",
- "text2num>=2.5.1",
+ "av>=14.1.0",
]
[project.optional-dependencies]
gpu = [
- "torch==2.7.1+cu128",
+ "torch==2.8.0+cu129",
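+    # the +cu129 build must match the pytorch-cuda index URL configured below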
]
cpu = [
- "torch==2.7.1",
+ "torch==2.6.0",
]
test = [
- "pytest==8.3.5",
- "pytest-cov==6.0.0",
+ "pytest==8.0.0",
+ "pytest-cov==4.1.0",
"httpx==0.26.0",
- "pytest-asyncio==0.25.3",
+ "pytest-asyncio==0.23.5",
+ "openai>=1.59.6",
"tomli>=2.0.1",
- "jinja2>=3.1.6"
]
[tool.uv]
@@ -79,7 +77,7 @@ explicit = true
[[tool.uv.index]]
name = "pytorch-cuda"
-url = "https://download.pytorch.org/whl/cu128"
+url = "https://download.pytorch.org/whl/cu129"
explicit = true
[build-system]
@@ -93,5 +91,5 @@ packages.find = {where = ["api/src"], namespaces = true}
[tool.pytest.ini_options]
testpaths = ["api/tests", "ui/tests"]
python_files = ["test_*.py"]
-addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace"
-asyncio_mode = "auto"
+addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
+asyncio_mode = "strict"
diff --git a/scripts/fix_misaki.py b/scripts/fix_misaki.py
deleted file mode 100644
index e7eb2a69..00000000
--- a/scripts/fix_misaki.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""
-Patch for misaki package to fix the EspeakWrapper.set_data_path issue.
-"""
-
-import importlib.util
-import os
-import sys
-
-# Find the misaki package
-try:
- import misaki
-
- misaki_path = os.path.dirname(misaki.__file__)
- print(f"Found misaki package at: {misaki_path}")
-except ImportError:
- print("Misaki package not found. Make sure it's installed.")
- sys.exit(1)
-
-# Path to the espeak.py file
-espeak_file = os.path.join(misaki_path, "espeak.py")
-
-if not os.path.exists(espeak_file):
- print(f"Could not find {espeak_file}")
- sys.exit(1)
-
-# Read the current content
-with open(espeak_file, "r") as f:
- content = f.read()
-
-# Check if the problematic line exists
-if "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())" in content:
- # Replace the problematic line
- new_content = content.replace(
- "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())",
- "# Fixed line to use data_path attribute instead of set_data_path method\n"
- "EspeakWrapper.data_path = espeakng_loader.get_data_path()",
- )
-
- # Write the modified content back
- with open(espeak_file, "w") as f:
- f.write(new_content)
-
- print(f"Successfully patched {espeak_file}")
-else:
- print(f"The problematic line was not found in {espeak_file}")
- print("The file may have already been patched or the issue is different.")
diff --git a/scripts/update_badges.py b/scripts/update_badges.py
index bd41b43b..dd5d6c43 100644
--- a/scripts/update_badges.py
+++ b/scripts/update_badges.py
@@ -1,139 +1,138 @@
import re
import subprocess
-from pathlib import Path
-
import tomli
-
+from pathlib import Path
def extract_dependency_info():
- """Extract version for kokoro and misaki from pyproject.toml"""
+ """Extract version and commit hash for kokoro and misaki from pyproject.toml"""
with open("pyproject.toml", "rb") as f:
pyproject = tomli.load(f)
-
+
deps = pyproject["project"]["dependencies"]
info = {}
- kokoro_found = False
- misaki_found = False
-
+
+ # Extract kokoro info
for dep in deps:
- # Match kokoro==version
- kokoro_match = re.match(r"^kokoro==(.+)$", dep)
- if kokoro_match:
- info["kokoro"] = {"version": kokoro_match.group(1)}
- kokoro_found = True
-
- # Match misaki[...] ==version or misaki==version
- misaki_match = re.match(r"^misaki(?:\[.*?\])?==(.+)$", dep)
- if misaki_match:
- info["misaki"] = {"version": misaki_match.group(1)}
- misaki_found = True
-
- # Stop if both found
- if kokoro_found and misaki_found:
- break
-
- if not kokoro_found:
- raise ValueError("Kokoro version not found in pyproject.toml dependencies")
- if not misaki_found:
- raise ValueError("Misaki version not found in pyproject.toml dependencies")
-
+ if dep.startswith("kokoro @"):
+ # Extract version from the dependency string if available
+ version_match = re.search(r"kokoro @ git\+https://github\.com/hexgrad/kokoro\.git@", dep)
+ if version_match:
+                # The git URL carries no version tag; fall back to the v0.7.9 noted in the README
+ version = "v0.7.9"
+ commit_match = re.search(r"@([a-f0-9]{7})", dep)
+ if commit_match:
+ info["kokoro"] = {
+ "version": version,
+ "commit": commit_match.group(1)
+ }
+ elif dep.startswith("misaki["):
+ # Extract version from the dependency string if available
+ version_match = re.search(r"misaki\[.*?\] @ git\+https://github\.com/hexgrad/misaki\.git@", dep)
+ if version_match:
+                # The git URL carries no version tag; fall back to the v0.7.9 noted in the README
+ version = "v0.7.9"
+ commit_match = re.search(r"@([a-f0-9]{7})", dep)
+ if commit_match:
+ info["misaki"] = {
+ "version": version,
+ "commit": commit_match.group(1)
+ }
+
return info
-
def run_pytest_with_coverage():
"""Run pytest with coverage and return the results"""
try:
# Run pytest with coverage
result = subprocess.run(
- ["pytest", "--cov=api", "-v"], capture_output=True, text=True, check=True
+ ["pytest", "--cov=api", "-v"],
+ capture_output=True,
+ text=True,
+ check=True
)
-
+
# Extract test results
test_output = result.stdout
passed_tests = len(re.findall(r"PASSED", test_output))
-
+
# Extract coverage from .coverage file
coverage_output = subprocess.run(
- ["coverage", "report"], capture_output=True, text=True, check=True
+ ["coverage", "report"],
+ capture_output=True,
+ text=True,
+ check=True
).stdout
-
+
# Extract total coverage percentage
coverage_match = re.search(r"TOTAL\s+\d+\s+\d+\s+(\d+)%", coverage_output)
coverage_percentage = coverage_match.group(1) if coverage_match else "0"
-
+
return passed_tests, coverage_percentage
except subprocess.CalledProcessError as e:
print(f"Error running tests: {e}")
print(f"Output: {e.output}")
return 0, "0"
-
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
"""Update the badges in the README file"""
readme_path = Path("README.md")
if not readme_path.exists():
print("README.md not found")
return False
-
+
content = readme_path.read_text()
-
+
# Update tests badge
content = re.sub(
- r"!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)",
- f"",
- content,
+ r'!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)',
+ f'',
+ content
)
-
+
# Update coverage badge
content = re.sub(
- r"!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)",
- f"",
- content,
+ r'!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)',
+ f'',
+ content
)
-
+
# Update kokoro badge
if "kokoro" in dep_info:
- # Find badge like kokoro-v0.9.2::abcdefg-BB5420 or kokoro-v0.9.2-BB5420
- kokoro_version = dep_info["kokoro"]["version"]
content = re.sub(
- r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
- lambda m: f"{m.group(1)}{kokoro_version}{m.group(2)}",
- content,
+ r'!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-[^)]+\)',
+ f'',
+ content
)
-
+
# Update misaki badge
if "misaki" in dep_info:
- # Find badge like misaki-v0.9.3::abcdefg-B8860B or misaki-v0.9.3-B8860B
- misaki_version = dep_info["misaki"]["version"]
content = re.sub(
- r"(!\[Misaki\]\(https://img\.shields\.io/badge/misaki-)[^)-]+(-B8860B\))",
- lambda m: f"{m.group(1)}{misaki_version}{m.group(2)}",
- content,
+ r'!\[Misaki\]\(https://img\.shields\.io/badge/misaki-[^)]+\)',
+ f'',
+ content
)
-
+
readme_path.write_text(content)
return True
-
def main():
# Get dependency info
dep_info = extract_dependency_info()
-
+
# Run tests and get coverage
passed_tests, coverage_percentage = run_pytest_with_coverage()
-
+
# Update badges
if update_readme_badges(passed_tests, coverage_percentage, dep_info):
print(f"Updated badges:")
print(f"- Tests: {passed_tests} passed")
print(f"- Coverage: {coverage_percentage}%")
if "kokoro" in dep_info:
- print(f"- Kokoro: {dep_info['kokoro']['version']}")
+ print(f"- Kokoro: {dep_info['kokoro']['version']}::{dep_info['kokoro']['commit']}")
if "misaki" in dep_info:
- print(f"- Misaki: {dep_info['misaki']['version']}")
+ print(f"- Misaki: {dep_info['misaki']['version']}::{dep_info['misaki']['commit']}")
else:
print("Failed to update badges")
-
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/scripts/update_version.py b/scripts/update_version.py
deleted file mode 100755
index e204a56f..00000000
--- a/scripts/update_version.py
+++ /dev/null
@@ -1,234 +0,0 @@
-#!/usr/bin/env python3
-"""
-Version Update Script
-
-This script reads the version from the VERSION file and updates references
-in pyproject.toml, the Helm chart, and README.md.
-"""
-
-import re
-from pathlib import Path
-
-import yaml
-
-# Get the project root directory
-ROOT_DIR = Path(__file__).parent.parent
-
-# --- Configuration ---
-VERSION_FILE = ROOT_DIR / "VERSION"
-PYPROJECT_FILE = ROOT_DIR / "pyproject.toml"
-HELM_CHART_FILE = ROOT_DIR / "charts" / "kokoro-fastapi" / "Chart.yaml"
-README_FILE = ROOT_DIR / "README.md"
-# --- End Configuration ---
-
-
-def update_pyproject(version: str):
- """Updates the version in pyproject.toml"""
- if not PYPROJECT_FILE.exists():
- print(f"Skipping: {PYPROJECT_FILE} not found.")
- return
-
- try:
- content = PYPROJECT_FILE.read_text()
- # Regex to find and capture current version = "X.Y.Z" under [project]
- pattern = r'(^\[project\]\s*(?:.*\s)*?version\s*=\s*)"([^"]+)"'
- match = re.search(pattern, content, flags=re.MULTILINE)
-
- if not match:
- print(f"Warning: Version pattern not found in {PYPROJECT_FILE}")
- return
-
- current_version = match.group(2)
- if current_version == version:
- print(f"Already up-to-date: {PYPROJECT_FILE} (version {version})")
- else:
- # Perform replacement
- new_content = re.sub(
- pattern, rf'\1"{version}"', content, count=1, flags=re.MULTILINE
- )
- PYPROJECT_FILE.write_text(new_content)
- print(f"Updated {PYPROJECT_FILE} from {current_version} to {version}")
-
- except Exception as e:
- print(f"Error processing {PYPROJECT_FILE}: {e}")
-
-
-def update_helm_chart(version: str):
- """Updates the version and appVersion in the Helm chart"""
- if not HELM_CHART_FILE.exists():
- print(f"Skipping: {HELM_CHART_FILE} not found.")
- return
-
- try:
- content = HELM_CHART_FILE.read_text()
- original_content = content
- updated_count = 0
-
- # Update 'version:' line (unquoted)
- # Looks for 'version:' followed by optional whitespace and the version number
- version_pattern = r"^(version:\s*)(\S+)"
- current_version_match = re.search(version_pattern, content, flags=re.MULTILINE)
- if current_version_match and current_version_match.group(2) != version:
- content = re.sub(
- version_pattern,
- rf"\g<1>{version}",
- content,
- count=1,
- flags=re.MULTILINE,
- )
- print(
- f"Updating 'version' in {HELM_CHART_FILE} from {current_version_match.group(2)} to {version}"
- )
- updated_count += 1
- elif current_version_match:
- print(f"Already up-to-date: 'version' in {HELM_CHART_FILE} is {version}")
- else:
- print(f"Warning: 'version:' pattern not found in {HELM_CHART_FILE}")
-
- # Update 'appVersion:' line (quoted or unquoted)
- # Looks for 'appVersion:' followed by optional whitespace, optional quote, the version, optional quote
- app_version_pattern = r"^(appVersion:\s*)(\"?)([^\"\s]+)(\"?)"
- current_app_version_match = re.search(
- app_version_pattern, content, flags=re.MULTILINE
- )
-
- if current_app_version_match:
- leading_whitespace = current_app_version_match.group(
- 1
- ) # e.g., "appVersion: "
- opening_quote = current_app_version_match.group(2) # e.g., '"' or ''
- current_app_ver = current_app_version_match.group(3) # e.g., '0.2.0'
- closing_quote = current_app_version_match.group(4) # e.g., '"' or ''
-
- # Check if quotes were consistent (both present or both absent)
- if opening_quote != closing_quote:
- print(
- f"Warning: Inconsistent quotes found for appVersion in {HELM_CHART_FILE}. Skipping update for this line."
- )
- elif (
- current_app_ver == version and opening_quote == '"'
- ): # Check if already correct *and* quoted
- print(
- f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is \"{version}\""
- )
- else:
- # Always replace with the quoted version
- replacement = f'{leading_whitespace}"{version}"' # Ensure quotes
- original_display = f"{opening_quote}{current_app_ver}{closing_quote}" # How it looked before
- target_display = f'"{version}"' # How it should look
-
- # Only report update if the displayed value actually changes
- if original_display != target_display:
- content = re.sub(
- app_version_pattern,
- replacement,
- content,
- count=1,
- flags=re.MULTILINE,
- )
- print(
- f"Updating 'appVersion' in {HELM_CHART_FILE} from {original_display} to {target_display}"
- )
- updated_count += 1
- else:
- # It matches the target version but might need quoting fixed silently if we didn't update
- # Or it was already correct. Check if content changed. If not, report up-to-date.
- if not (
- content != original_content and updated_count > 0
- ): # Avoid double message if version also changed
- print(
- f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is {target_display}"
- )
-
- else:
- print(f"Warning: 'appVersion:' pattern not found in {HELM_CHART_FILE}")
-
- # Write back only if changes were made
- if content != original_content:
- HELM_CHART_FILE.write_text(content)
- # Confirmation message printed above during the specific update
- elif updated_count == 0 and current_version_match and current_app_version_match:
- # If no updates were made but patterns were found, confirm it's up-to-date overall
- print(f"Already up-to-date: {HELM_CHART_FILE} (version {version})")
-
- except Exception as e:
- print(f"Error processing {HELM_CHART_FILE}: {e}")
-
-
-def update_readme(version_with_v: str):
- """Updates Docker image tags in README.md"""
- if not README_FILE.exists():
- print(f"Skipping: {README_FILE} not found.")
- return
-
- try:
- content = README_FILE.read_text()
- # Regex to find and capture current ghcr.io/.../kokoro-fastapi-(cpu|gpu):vX.Y.Z
- pattern = r"(ghcr\.io/remsky/kokoro-fastapi-(?:cpu|gpu)):(v\d+\.\d+\.\d+)"
- matches = list(re.finditer(pattern, content)) # Find all occurrences
-
- if not matches:
- print(f"Warning: Docker image tag pattern not found in {README_FILE}")
- else:
- updated_needed = False
- for match in matches:
- current_tag = match.group(2)
- if current_tag != version_with_v:
- updated_needed = True
- break # Only need one mismatch to trigger update
-
- if updated_needed:
- # Perform replacement on all occurrences
- new_content = re.sub(pattern, rf"\1:{version_with_v}", content)
- README_FILE.write_text(new_content)
- print(f"Updated Docker image tags in {README_FILE} to {version_with_v}")
- else:
- print(
- f"Already up-to-date: Docker image tags in {README_FILE} (version {version_with_v})"
- )
-
- # Check for ':latest' tag usage remains the same
- if ":latest" in content:
- print(
- f"Warning: Found ':latest' tag in {README_FILE}. Consider updating manually if needed."
- )
-
- except Exception as e:
- print(f"Error processing {README_FILE}: {e}")
-
-
-def main():
- # Read the version from the VERSION file
- if not VERSION_FILE.exists():
- print(f"Error: {VERSION_FILE} not found.")
- return
-
- try:
- version = VERSION_FILE.read_text().strip()
- if not re.match(r"^\d+\.\d+\.\d+$", version):
- print(
- f"Error: Invalid version format '{version}' in {VERSION_FILE}. Expected X.Y.Z"
- )
- return
- except Exception as e:
- print(f"Error reading {VERSION_FILE}: {e}")
- return
-
- print(f"Read version: {version} from {VERSION_FILE}")
- print("-" * 20)
-
- # Prepare versions (with and without 'v')
- version_plain = version
- version_with_v = f"v{version}"
-
- # Update files
- update_pyproject(version_plain)
- update_helm_chart(version_plain)
- update_readme(version_with_v)
-
- print("-" * 20)
- print("Version update script finished.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/slim.report.json b/slim.report.json
new file mode 100644
index 00000000..415c3815
--- /dev/null
+++ b/slim.report.json
@@ -0,0 +1,49 @@
+{
+ "document": "doc.report.command",
+ "version": "ov/command/slim/1.1",
+ "engine": "linux/amd64|ALP|x.1.42.2|29e62e7836de7b1004607c51c502537ffe1969f0|2025-01-16_07:48:54AM|x",
+ "containerized": false,
+ "host_distro": {
+ "name": "Ubuntu",
+ "version": "22.04",
+ "display_name": "Ubuntu 22.04.5 LTS"
+ },
+ "type": "slim",
+ "state": "error",
+ "target_reference": "kokoro-fastapi:latest",
+ "system": {
+ "type": "",
+ "release": "",
+ "distro": {
+ "name": "",
+ "version": "",
+ "display_name": ""
+ }
+ },
+ "source_image": {
+ "identity": {
+ "id": ""
+ },
+ "size": 0,
+ "size_human": "",
+ "create_time": "",
+ "architecture": "",
+ "container_entry": {
+ "exe_path": ""
+ }
+ },
+ "minified_image_size": 0,
+ "minified_image_size_human": "",
+ "minified_image": "",
+ "minified_image_id": "",
+ "minified_image_digest": "",
+ "minified_image_has_data": false,
+ "minified_by": 0,
+ "artifact_location": "",
+ "container_report_name": "",
+ "seccomp_profile_name": "",
+ "apparmor_profile_name": "",
+ "image_stack": null,
+ "image_created": false,
+ "image_build_engine": ""
+}
diff --git a/start-cpu.ps1 b/start-cpu.ps1
deleted file mode 100644
index 5a5df265..00000000
--- a/start-cpu.ps1
+++ /dev/null
@@ -1,13 +0,0 @@
-$env:PHONEMIZER_ESPEAK_LIBRARY="C:\Program Files\eSpeak NG\libespeak-ng.dll"
-$env:PYTHONUTF8=1
-$Env:PROJECT_ROOT="$pwd"
-$Env:USE_GPU="false"
-$Env:USE_ONNX="false"
-$Env:PYTHONPATH="$Env:PROJECT_ROOT;$Env:PROJECT_ROOT/api"
-$Env:MODEL_DIR="src/models"
-$Env:VOICES_DIR="src/voices/v1_0"
-$Env:WEB_PLAYER_PATH="$Env:PROJECT_ROOT/web"
-
-uv pip install -e ".[cpu]"
-uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
-uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
\ No newline at end of file
diff --git a/start-cpu.sh b/start-cpu.sh
index 98fae6de..651f645c 100755
--- a/start-cpu.sh
+++ b/start-cpu.sh
@@ -10,17 +10,8 @@ export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
-# Set the espeak-ng data path to your location
-export ESPEAK_DATA_PATH=/usr/lib/x86_64-linux-gnu/espeak-ng-data
# Run FastAPI with CPU extras using uv run
# Note: espeak may still require manual installation,
uv pip install -e ".[cpu]"
-uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
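+# Assumes model files already exist under api/src/models/v1_0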
-
-# Apply the misaki patch to fix possible EspeakWrapper issue in older versions
-# echo "Applying misaki patch..."
-# python scripts/fix_misaki.py
-
-# Start the server
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
diff --git a/start-gpu.ps1 b/start-gpu.ps1
deleted file mode 100644
index 7b161a5a..00000000
--- a/start-gpu.ps1
+++ /dev/null
@@ -1,13 +0,0 @@
-$env:PHONEMIZER_ESPEAK_LIBRARY="C:\Program Files\eSpeak NG\libespeak-ng.dll"
-$env:PYTHONUTF8=1
-$Env:PROJECT_ROOT="$pwd"
-$Env:USE_GPU="true"
-$Env:USE_ONNX="false"
-$Env:PYTHONPATH="$Env:PROJECT_ROOT;$Env:PROJECT_ROOT/api"
-$Env:MODEL_DIR="src/models"
-$Env:VOICES_DIR="src/voices/v1_0"
-$Env:WEB_PLAYER_PATH="$Env:PROJECT_ROOT/web"
-
-uv pip install -e ".[gpu]"
-uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
-uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
\ No newline at end of file
diff --git a/start-gpu.sh b/start-gpu.sh
index a3a2e687..b0799788 100755
--- a/start-gpu.sh
+++ b/start-gpu.sh
@@ -12,7 +12,5 @@ export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
# Run FastAPI with GPU extras using uv run
-# Note: espeak may still require manual installation,
uv pip install -e ".[gpu]"
-uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
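+# Assumes the model has already been downloaded to api/src/models/v1_0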
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
diff --git a/start-gpu_mac.sh b/start-gpu_mac.sh
deleted file mode 100755
index 9d00063d..00000000
--- a/start-gpu_mac.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/bin/bash
-
-# Get project root directory
-PROJECT_ROOT=$(pwd)
-
-# Set other environment variables
-export USE_GPU=true
-export USE_ONNX=false
-export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
-export MODEL_DIR=src/models
-export VOICES_DIR=src/voices/v1_0
-export WEB_PLAYER_PATH=$PROJECT_ROOT/web
-
-export DEVICE_TYPE=mps
-# Enable MPS fallback for unsupported operations
-export PYTORCH_ENABLE_MPS_FALLBACK=1
-
-# Run FastAPI with GPU extras using uv run
-uv pip install -e .
-uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
-uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
diff --git a/ui/depr_tests/conftest.py b/ui/depr_tests/conftest.py
index 3a65b691..f0c2a2ef 100644
--- a/ui/depr_tests/conftest.py
+++ b/ui/depr_tests/conftest.py
@@ -1,6 +1,5 @@
-from unittest.mock import AsyncMock, Mock
-
import pytest
+from unittest.mock import AsyncMock, Mock
from api.src.services.tts_service import TTSService
@@ -31,22 +30,17 @@ async def mock_tts_service(mock_model_manager, mock_voice_manager):
@pytest.fixture(autouse=True)
-async def setup_mocks(
- monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service
-):
+async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service):
"""Setup global mocks for UI tests"""
-
async def mock_get_model():
return mock_model_manager
-
+
async def mock_get_voice():
return mock_voice_manager
-
+
async def mock_create_service():
return mock_tts_service
-
+
monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
monkeypatch.setattr("api.src.inference.voice_manager.get_manager", mock_get_voice)
- monkeypatch.setattr(
- "api.src.services.tts_service.TTSService.create", mock_create_service
- )
+ monkeypatch.setattr("api.src.services.tts_service.TTSService.create", mock_create_service)
diff --git a/ui/depr_tests/test_api.py b/ui/depr_tests/test_api.py
index 37157f02..d6823268 100644
--- a/ui/depr_tests/test_api.py
+++ b/ui/depr_tests/test_api.py
@@ -1,4 +1,4 @@
-from unittest.mock import mock_open, patch
+from unittest.mock import patch, mock_open
import pytest
import requests
@@ -59,11 +59,9 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation"""
- with (
- patch("requests.post", return_value=mock_response({})),
- patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
- patch("builtins.open", mock_open()) as mock_file,
- ):
+ with patch("requests.post", return_value=mock_response({})), patch(
+ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+ ), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None
@@ -118,11 +116,9 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
]
for input_voice, expected_voice in test_cases:
- with (
- patch("requests.post") as mock_post,
- patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
- patch("builtins.open", mock_open()),
- ):
+ with patch("requests.post") as mock_post, patch(
+ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+ ), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", input_voice, "mp3", 1.5)
@@ -153,15 +149,11 @@ def test_text_to_speech_output_filename(mock_response, tmp_path):
]
for input_voice, filename_check in test_cases:
- with (
- patch("requests.post", return_value=mock_response({})),
- patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
- patch("builtins.open", mock_open()) as mock_file,
- ):
+ with patch("requests.post", return_value=mock_response({})), patch(
+ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+ ), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
assert result is not None
- assert filename_check(result), (
- f"Expected voice pattern not found in filename: {result}"
- )
+ assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
mock_file.assert_called_once()
diff --git a/ui/depr_tests/test_components.py b/ui/depr_tests/test_components.py
index ddd831b8..9e2b796e 100644
--- a/ui/depr_tests/test_components.py
+++ b/ui/depr_tests/test_components.py
@@ -1,9 +1,9 @@
import gradio as gr
import pytest
+from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column
-from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure():
diff --git a/ui/depr_tests/test_files.py b/ui/depr_tests/test_files.py
index 30be2931..2e7e0389 100644
--- a/ui/depr_tests/test_files.py
+++ b/ui/depr_tests/test_files.py
@@ -15,9 +15,8 @@ def mock_dirs(tmp_path):
inputs_dir.mkdir()
outputs_dir.mkdir()
- with (
- patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)),
- patch("ui.lib.files.OUTPUTS_DIR", str(outputs_dir)),
+ with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
+ "ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
):
yield inputs_dir, outputs_dir
diff --git a/ui/depr_tests/test_interface.py b/ui/depr_tests/test_interface.py
index d9c49629..15c60ba3 100644
--- a/ui/depr_tests/test_interface.py
+++ b/ui/depr_tests/test_interface.py
@@ -62,9 +62,8 @@ def test_interface_html_links():
def test_update_status_available(mock_timer):
"""Test status update when service is available"""
voices = ["voice1", "voice2"]
- with (
- patch("ui.lib.api.check_api_status", return_value=(True, voices)),
- patch("gradio.Timer", return_value=mock_timer),
+ with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
+ "gradio.Timer", return_value=mock_timer
):
demo = create_interface()
@@ -82,9 +81,8 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable"""
- with (
- patch("ui.lib.api.check_api_status", return_value=(False, [])),
- patch("gradio.Timer", return_value=mock_timer),
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
+ "gradio.Timer", return_value=mock_timer
):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@@ -99,10 +97,9 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer):
"""Test status update when an error occurs"""
- with (
- patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")),
- patch("gradio.Timer", return_value=mock_timer),
- ):
+ with patch(
+ "ui.lib.api.check_api_status", side_effect=Exception("Test error")
+ ), patch("gradio.Timer", return_value=mock_timer):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@@ -116,9 +113,8 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer):
"""Test timer configuration"""
- with (
- patch("ui.lib.api.check_api_status", return_value=(False, [])),
- patch("gradio.Timer", return_value=mock_timer),
+ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
+ "gradio.Timer", return_value=mock_timer
):
demo = create_interface()
diff --git a/ui/lib/api.py b/ui/lib/api.py
index 8bb8b87c..ca0d7e8b 100644
--- a/ui/lib/api.py
+++ b/ui/lib/api.py
@@ -1,6 +1,6 @@
-import datetime
import os
-from typing import List, Optional, Tuple
+import datetime
+from typing import List, Tuple, Optional
import requests
diff --git a/ui/lib/components/input.py b/ui/lib/components/input.py
index b830b568..a2c4d336 100644
--- a/ui/lib/components/input.py
+++ b/ui/lib/components/input.py
@@ -11,10 +11,12 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
text_input = gr.Textbox(
label="Text to speak", placeholder="Enter text here...", lines=4
)
-
+
# Always show file upload but handle differently based on disable_local_saving
- file_upload = gr.File(label="Upload Text File (.txt)", file_types=[".txt"])
-
+ file_upload = gr.File(
+ label="Upload Text File (.txt)", file_types=[".txt"]
+ )
+
if not disable_local_saving:
# Show full interface with tabs when saving is enabled
with gr.Tabs() as tabs:
@@ -22,9 +24,7 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
tabs.selected = 0
# Direct Input Tab
with gr.TabItem("Direct Input"):
- text_submit_direct = gr.Button(
- "Generate Speech", variant="primary", size="lg"
- )
+ text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
# File Input Tab
with gr.TabItem("From File"):
@@ -48,9 +48,7 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
)
else:
# Just show the generate button when saving is disabled
- text_submit_direct = gr.Button(
- "Generate Speech", variant="primary", size="lg"
- )
+ text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
tabs = None
input_files_list = None
file_preview = None
diff --git a/ui/lib/components/model.py b/ui/lib/components/model.py
index d3426bc8..a659d2c1 100644
--- a/ui/lib/components/model.py
+++ b/ui/lib/components/model.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Tuple, Optional
import gradio as gr
diff --git a/ui/lib/components/output.py b/ui/lib/components/output.py
index 5e7412cd..083829e0 100644
--- a/ui/lib/components/output.py
+++ b/ui/lib/components/output.py
@@ -12,7 +12,7 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
audio_output = gr.Audio(
label="Generated Speech",
type="filepath",
- waveform_options={"waveform_color": "#4C87AB"},
+ waveform_options={"waveform_color": "#4C87AB"}
)
# Create file-related components with visible=False when local saving is disabled
@@ -26,14 +26,14 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
)
play_btn = gr.Button(
- "▶️ Play Selected",
+ "▶️ Play Selected",
size="sm",
visible=not disable_local_saving,
)
selected_audio = gr.Audio(
- label="Selected Output",
- type="filepath",
+ label="Selected Output",
+ type="filepath",
visible=False, # Always initially hidden
)
diff --git a/ui/lib/files.py b/ui/lib/files.py
index f79b88fa..1391e0ac 100644
--- a/ui/lib/files.py
+++ b/ui/lib/files.py
@@ -1,8 +1,8 @@
-import datetime
import os
-from typing import List, Optional, Tuple
+import datetime
+from typing import List, Tuple, Optional
-from .config import AUDIO_FORMATS, INPUTS_DIR, OUTPUTS_DIR
+from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
def list_input_files() -> List[str]:
diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py
index 224f6509..71b8d9b4 100644
--- a/ui/lib/handlers.py
+++ b/ui/lib/handlers.py
@@ -58,21 +58,17 @@ def handle_file_select(filename):
def handle_file_upload(file):
if file is None:
- return (
- ""
- if disable_local_saving
- else [gr.update(choices=files.list_input_files())]
- )
+ return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
try:
# Read the file content
- with open(file.name, "r", encoding="utf-8") as f:
+ with open(file.name, 'r', encoding='utf-8') as f:
text_content = f.read()
if disable_local_saving:
# When saving is disabled, put content directly in text input
# Normalize whitespace by replacing newlines with spaces
- normalized_text = " ".join(text_content.split())
+ normalized_text = ' '.join(text_content.split())
return normalized_text
else:
# When saving is enabled, save file and update dropdown
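As the normalization comment above notes, str.split() with no arguments splits on any run of whitespace, so the single join collapses newlines, tabs, and repeated spaces in one pass:

```python
text = "line one\nline two\t\tand   more"
assert " ".join(text.split()) == "line one line two and more"
```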
@@ -92,11 +88,7 @@ def handle_file_upload(file):
except Exception as e:
print(f"Error handling file: {e}")
- return (
- ""
- if disable_local_saving
- else [gr.update(choices=files.list_input_files())]
- )
+ return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
def generate_from_text(text, voice, format, speed):
"""Generate speech from direct text input"""
@@ -112,7 +104,7 @@ def generate_from_text(text, voice, format, speed):
# Only save text if local saving is enabled
if not disable_local_saving:
files.save_text(text)
-
+
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
@@ -211,11 +203,7 @@ def clear_outputs():
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
- outputs=[
- components["input"]["text_input"]
- if disable_local_saving
- else components["input"]["file_select"]
- ],
+ outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
)
if components["output"]["play_btn"] is not None:
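The reflowed ternaries and the conditional outputs=[...] wiring above encode one contract: with local saving disabled the handler returns a bare string for the text input, otherwise a one-element list carrying a dropdown update. A shape-only sketch (the plain dict stands in for gradio's gr.update, and upload_result is a hypothetical name):

```python
def upload_result(file, disable_local_saving, current_files):
    """'' feeds the text box; [update] feeds the file dropdown."""
    if file is None:
        return "" if disable_local_saving else [{"choices": current_files}]
    return "file contents"  # placeholder for the real read/save logic

assert upload_result(None, True, []) == ""
assert upload_result(None, False, ["a.txt"]) == [{"choices": ["a.txt"]}]
```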
diff --git a/ui/lib/interface.py b/ui/lib/interface.py
index b35bee8e..1ae344ca 100644
--- a/ui/lib/interface.py
+++ b/ui/lib/interface.py
@@ -1,10 +1,9 @@
-import os
-
import gradio as gr
+import os
from . import api
-from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers
+from .components import create_input_column, create_model_column, create_output_column
def create_interface():
diff --git a/web/index.html b/web/index.html
index 3f9db97c..eb998169 100644
--- a/web/index.html
+++ b/web/index.html
@@ -16,7 +16,7 @@
-
+
diff --git a/web/src/services/AudioService.js b/web/src/services/AudioService.js
index 6e0e90b4..cee33d46 100644
--- a/web/src/services/AudioService.js
+++ b/web/src/services/AudioService.js
@@ -264,46 +264,19 @@ export class AudioService {
// Don't process if audio is in error state
if (this.audio.error) {
- console.warn("Skipping operation due to audio error");
+ console.warn('Skipping operation due to audio error');
return;
}
const operation = this.pendingOperations.shift();
-
+
try {
this.sourceBuffer.appendBuffer(operation.chunk);
-
- // Set up event listeners
- const onUpdateEnd = () => {
- operation.resolve();
- this.sourceBuffer.removeEventListener("updateend", onUpdateEnd);
- this.sourceBuffer.removeEventListener(
- "updateerror",
- onUpdateError
- );
- // Process the next operation
- this.processNextOperation();
- };
-
- const onUpdateError = (event) => {
- operation.reject(event);
- this.sourceBuffer.removeEventListener("updateend", onUpdateEnd);
- this.sourceBuffer.removeEventListener(
- "updateerror",
- onUpdateError
- );
- // Decide whether to continue processing
- if (event.name !== "InvalidStateError") {
- this.processNextOperation();
- }
- };
-
- this.sourceBuffer.addEventListener("updateend", onUpdateEnd);
- this.sourceBuffer.addEventListener("updateerror", onUpdateError);
+ operation.resolve();
} catch (error) {
operation.reject(error);
// Only continue processing if it's not a fatal error
- if (error.name !== "InvalidStateError") {
+ if (error.name !== 'InvalidStateError') {
this.processNextOperation();
}
}
@@ -391,14 +364,14 @@ export class AudioService {
this.controller.abort();
this.controller = null;
}
-
+
if (this.audio) {
this.audio.pause();
- this.audio.src = "";
+ this.audio.src = '';
this.audio = null;
}
- if (this.mediaSource && this.mediaSource.readyState === "open") {
+ if (this.mediaSource && this.mediaSource.readyState === 'open') {
try {
this.mediaSource.endOfStream();
} catch (e) {
@@ -407,11 +380,7 @@ export class AudioService {
}
this.mediaSource = null;
- if (this.sourceBuffer) {
- this.sourceBuffer.removeEventListener("updateend", () => {});
- this.sourceBuffer.removeEventListener("updateerror", () => {});
- this.sourceBuffer = null;
- }
+ this.sourceBuffer = null;
this.serverDownloadPath = null;
this.pendingOperations = [];
}
@@ -419,17 +388,17 @@ export class AudioService {
cleanup() {
if (this.audio) {
this.eventListeners.forEach((listeners, event) => {
- listeners.forEach((callback) => {
+ listeners.forEach(callback => {
this.audio.removeEventListener(event, callback);
});
});
-
+
this.audio.pause();
- this.audio.src = "";
+ this.audio.src = '';
this.audio = null;
}
- if (this.mediaSource && this.mediaSource.readyState === "open") {
+ if (this.mediaSource && this.mediaSource.readyState === 'open') {
try {
this.mediaSource.endOfStream();
} catch (e) {
@@ -438,11 +407,7 @@ export class AudioService {
}
this.mediaSource = null;
- if (this.sourceBuffer) {
- this.sourceBuffer.removeEventListener("updateend", () => {});
- this.sourceBuffer.removeEventListener("updateerror", () => {});
- this.sourceBuffer = null;
- }
+ this.sourceBuffer = null;
this.serverDownloadPath = null;
this.pendingOperations = [];
}