diff --git a/.claude/settings.local.json b/.claude/settings.local.json
new file mode 100644
index 00000000..1764128c
--- /dev/null
+++ b/.claude/settings.local.json
@@ -0,0 +1,15 @@
+{
+  "permissions": {
+    "allow": [
+      "WebFetch(domain:github.com)",
+      "Bash(chmod:*)",
+      "Bash(rg:*)",
+      "Bash(grep:*)",
+      "Bash(ruff check:*)",
+      "Bash(python:*)",
+      "Bash(rm:*)",
+      "Bash(sed:*)"
+    ],
+    "deny": []
+  }
+}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 14466656..c57815a9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ build/
 dist/
 wheels/
 *.egg-info
+.claude
 
 # Virtual environments
 .venv
@@ -18,4 +19,6 @@ wheels/
 .ruff_cache
 .ipynb_checkpoints
 
-config.json
\ No newline at end of file
+config.json
+CLAUDE.md
+.claude/settings.local.json
diff --git a/API_INTEGRATION.md b/API_INTEGRATION.md
new file mode 100644
index 00000000..103b4c50
--- /dev/null
+++ b/API_INTEGRATION.md
@@ -0,0 +1,707 @@
+# Dia TTS API Integration Guide
+
+## What's New
+
+- **Role-Based Input**: Automatically map chat roles (user/assistant/system) to speaker tags
+- **Improved Voice Cloning**: Now follows Dia's reference implementation exactly
+- **Persistent Audio Prompts**: Audio files are stored on disk and persist across restarts
+- **Audio Prompt Transcripts**: Include transcripts for significantly better voice cloning
+
+### Migration Notes
+
+If you were using a previous version:
+1. Audio prompts are now stored as files in the `audio_prompts/` directory
+2. The `/audio_prompts/re-encode` endpoint has been removed (no longer needed)
+3. Voice mappings now support an `audio_prompt_transcript` field for better results
+4. The `role` parameter is now available for automatic speaker tag mapping
+
+## Overview
+
+The Dia TTS FastAPI server provides a comprehensive text-to-speech API with advanced features:
+- **Voice cloning** with audio prompts (following Dia's reference implementation)
+- **Role-based input** for chat applications (user/assistant/system)
+- **Configurable model parameters** for quality control
+- **Debug logging** and file management
+- **Custom voice creation** and management
+- **SillyTavern integration** support
+- **Persistent audio prompt storage** across server restarts
+
+## Quick Start
+
+### 1. Start the Server
+```bash
+# Recommended: Use startup script (includes environment check)
+python start_server.py
+
+# Development mode (debug + save outputs + prompts + reload)
+python start_server.py --dev
+
+# Production mode (optimized settings)
+python start_server.py --production
+
+# Custom configuration
+python start_server.py --debug --save-outputs --workers 8 --retention-hours 48
+
+# Direct server launch (bypass startup script)
+python fastapi_server.py --debug --save-outputs --port 7860
+```
+
+### 2. Basic TTS Request
+```bash
+# Synchronous request (immediate response)
+curl -X POST "http://localhost:7860/generate" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "text": "Hello, this is a test message.",
+    "voice_id": "alloy",
+    "role": "user",
+    "response_format": "wav",
+    "temperature": 1.0,
+    "cfg_scale": 2.5
+  }' \
+  --output speech.wav
+
+# With role-based input (new!)
+curl -X POST "http://localhost:7860/generate" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "I can help you with that request.", + "voice_id": "nova", + "role": "assistant" + }' \ + --output assistant_speech.wav + +# Asynchronous request (job-based) +curl -X POST "http://localhost:7860/generate?async_mode=true" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Hello, this is a test message.", + "voice_id": "alloy", + "role": "user" + }' +``` + +## Core API Endpoints + +### Text-to-Speech +- **POST** `/generate` - Generate speech from text (with model parameters) +- **GET** `/models` - List available models +- **GET** `/voices` - List available voices + +### Voice Management +- **GET** `/voice_mappings` - List voice configurations +- **POST** `/voice_mappings` - Create custom voice +- **PUT** `/voice_mappings/{voice_id}` - Update voice +- **DELETE** `/voice_mappings/{voice_id}` - Delete custom voice + +### Audio Prompts +- **POST** `/audio_prompts/upload` - Upload voice sample (saved as WAV file) +- **GET** `/audio_prompts` - List uploaded prompts with file info +- **DELETE** `/audio_prompts/{prompt_id}` - Delete prompt and file + +### Debug & Configuration +- **GET** `/config` - Get server configuration +- **PUT** `/config` - Update server configuration +- **GET** `/logs` - List generation logs +- **GET** `/logs/{id}/download` - Download saved audio +- **POST** `/cleanup` - Manual file cleanup + +### Job Queue Management +- **GET** `/jobs` - List jobs with status filtering +- **GET** `/jobs/{id}` - Get job status and details +- **GET** `/jobs/{id}/result` - Download completed job result +- **DELETE** `/jobs/{id}` - Cancel pending job +- **GET** `/queue/stats` - Get queue statistics +- **DELETE** `/jobs` - Clear completed jobs + +## SillyTavern Integration + +### Configuration +1. Navigate to **Settings → Text-to-Speech** +2. Set **TTS Provider**: `Custom` +3. Set **Endpoint URL**: `http://localhost:7860/generate` +4. Choose **Voice**: Built-in (`alloy`, `echo`, `fable`, `nova`, `onyx`, `shimmer`) or custom voices + +### Advanced Features +- **Custom Voices**: Upload audio samples and create character-specific voices +- **Model Parameters**: SillyTavern may pass through additional parameters +- **Debug Mode**: Enable server debug logging to troubleshoot issues +- **Role Support**: API automatically maps chat roles to appropriate speakers + +### Role-Based Generation Example +When SillyTavern sends requests with roles, they are automatically mapped: +```json +// User message (maps to [S1]) +{ + "text": "Tell me a story about dragons.", + "voice_id": "alloy", + "role": "user" +} + +// Character/Assistant response (maps to [S2]) +{ + "text": "Once upon a time, in a land far away...", + "voice_id": "nova", + "role": "assistant" +} +``` + +### Custom Voice Setup +```bash +# 1. Upload voice sample +curl -X POST "http://localhost:7860/audio_prompts/upload" \ + -F "prompt_id=character_voice" \ + -F "audio_file=@sample.wav" + +# 2. Create voice mapping with transcript +curl -X POST "http://localhost:7860/voice_mappings" \ + -H "Content-Type: application/json" \ + -d '{ + "voice_id": "my_character", + "style": "expressive", + "primary_speaker": "S1", + "audio_prompt": "character_voice", + "audio_prompt_transcript": "[S1] Sample text from the character." + }' + +# 3. 
Use in SillyTavern by setting Voice: "my_character"
+```
+
+## Request/Response Formats
+
+### TTS Request
+```json
+{
+  "text": "Your text here",
+  "voice_id": "alloy",
+  "response_format": "wav",
+  "speed": 1.0,
+  "role": "user",              // Optional: "user", "assistant", or "system"
+  "temperature": 1.2,
+  "cfg_scale": 3.0,
+  "top_p": 0.95,
+  "max_tokens": 2000,
+  "use_torch_compile": true
+}
+```
+
+### Voice Mapping
+```json
+{
+  "voice_id": "custom_voice",
+  "style": "neutral",
+  "primary_speaker": "S1",
+  "audio_prompt": "prompt_id",
+  "audio_prompt_transcript": "[S1] This is what I said in the audio sample."
+}
+```
+
+## Text Format Guidelines
+
+### Role-Based Input (NEW)
+The API now supports role-based text formatting for easier integration with chat applications:
+- **`role: "user"`** - Maps to `[S1]` speaker tag
+- **`role: "assistant"`** - Maps to `[S2]` speaker tag
+- **`role: "system"`** - Maps to `[S2]` speaker tag
+
+Example with roles:
+```json
+// User speaking
+{
+  "text": "What's the weather like today?",
+  "voice_id": "alloy",
+  "role": "user"
+}
+
+// Assistant responding
+{
+  "text": "The weather is sunny with a high of 75 degrees.",
+  "voice_id": "nova",
+  "role": "assistant"
+}
+```
+
+### Speaker Tags (Manual)
+- Always start with `[S1]` or `[S2]`
+- Alternate speakers: `[S1] Hello! [S2] Hi there!`
+- Single speaker: `[S1] This is a monologue.`
+
+### Nonverbal Sounds
+- `(laughs)`, `(coughs)`, `(sighs)`, `(gasps)`
+- `(clears throat)`, `(whispers)`, `(shouts)`
+
+### Best Practices
+- Keep input under 4096 characters
+- Use clear, natural speech patterns
+- For voice cloning: provide 3-10 seconds of clean audio
+
+## Error Handling
+
+### Common Status Codes
+- `400` - Invalid request format
+- `404` - Voice/prompt not found
+- `500` - Generation or server error
+
+### Example Error Response
+```json
+{
+  "detail": "Voice 'unknown_voice' not found"
+}
+```
+
+## Asynchronous Processing
+
+### Job-Based Workflow
+```bash
+# 1. Submit job
+response=$(curl -X POST "http://localhost:7860/generate?async_mode=true" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "text": "[S1] Hello world!",
+    "voice_id": "alloy"
+  }')
+
+job_id=$(echo "$response" | jq -r '.job_id')
+
+# 2. Check job status
+curl "http://localhost:7860/jobs/$job_id"
+
+# 3. Download result when completed
+curl "http://localhost:7860/jobs/$job_id/result" -o output.wav
+
+# 4. Get queue statistics
+curl "http://localhost:7860/queue/stats"
+```
+
+### Job Management
+```bash
+# List all jobs
+curl "http://localhost:7860/jobs"
+
+# List only pending jobs
+curl "http://localhost:7860/jobs?status=pending"
+
+# Cancel a pending job
+curl -X DELETE "http://localhost:7860/jobs/$job_id"
+
+# Clear completed jobs
+curl -X DELETE "http://localhost:7860/jobs"
+```
+
+## Advanced Configuration
+
+### Server Configuration
+```bash
+# Enable debug features
+curl -X PUT "http://localhost:7860/config" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "debug_mode": true,
+    "save_outputs": true,
+    "show_prompts": true,
+    "output_retention_hours": 48
+  }'
+
+# Get current configuration
+curl "http://localhost:7860/config"
+```
+
+### Generation Monitoring
+```bash
+# View recent generations
+curl "http://localhost:7860/logs?limit=10"
+
+# Download specific generation
+curl "http://localhost:7860/logs/{log_id}/download" -o output.wav
+
+# Clear all logs
+curl -X DELETE "http://localhost:7860/logs"
+```
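+
+For scripted use, the asynchronous job workflow shown earlier can be wrapped in a small Python helper. This is a minimal sketch assuming the endpoints documented in this guide; the `job_id`, `status`, and `error_message` fields follow the job schema used elsewhere here:
+
+```python
+import time
+
+import requests
+
+BASE_URL = "http://localhost:7860"
+
+
+def generate_async(text: str, voice_id: str = "alloy", timeout: float = 120.0) -> bytes:
+    """Submit a TTS job and poll until it completes, returning WAV bytes."""
+    job = requests.post(
+        f"{BASE_URL}/generate",
+        params={"async_mode": "true"},
+        json={"text": text, "voice_id": voice_id},
+    ).json()
+    job_id = job["job_id"]
+
+    deadline = time.time() + timeout
+    while time.time() < deadline:
+        status = requests.get(f"{BASE_URL}/jobs/{job_id}").json()
+        if status["status"] == "completed":
+            return requests.get(f"{BASE_URL}/jobs/{job_id}/result").content
+        if status["status"] == "failed":
+            raise RuntimeError(f"Job failed: {status.get('error_message')}")
+        time.sleep(1)  # poll once per second, like the JavaScript client below
+    raise TimeoutError(f"Job {job_id} did not finish within {timeout}s")
+```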
+
+## Voice Cloning Workflow (Enhanced)
+
+### 1. Prepare Audio Sample
+- **Formats**: WAV, MP3, M4A, FLAC, OGG, AAC (auto-converted to WAV)
+- **Duration**: 0.5-60 seconds (3-10 seconds recommended)
+- **Quality**: Clear speech, minimal background noise
+- **Content**: Speaker saying target text or similar phrases
+- **Storage**: Audio files are saved in the `audio_prompts/` directory as WAV files
+- **Persistence**: Audio prompts persist across server restarts (loaded on startup)
+
+### 2. Upload and Configure
+```python
+import requests
+
+# Upload audio prompt (saved as WAV file)
+with open("voice_sample.wav", "rb") as f:
+    response = requests.post(
+        "http://localhost:7860/audio_prompts/upload",
+        data={"prompt_id": "my_voice"},
+        files={"audio_file": f}
+    )
+
+# Create voice mapping WITH transcript for better results
+# IMPORTANT: The transcript should match what's in the audio file
+requests.post(
+    "http://localhost:7860/voice_mappings",
+    json={
+        "voice_id": "cloned_voice",
+        "style": "natural",
+        "primary_speaker": "S1",
+        "audio_prompt": "my_voice",
+        "audio_prompt_transcript": "[S1] This is what I said in the audio sample."
+    }
+)
+
+# The model will:
+# 1. Load the audio file directly
+# 2. Prepend the transcript to your text
+# 3. Generate audio only for the new text portion
+```
+
+### 3. Generate Speech
+```python
+response = requests.post(
+    "http://localhost:7860/generate",
+    json={
+        "text": "[S1] Hello, this is my cloned voice!",
+        "voice_id": "cloned_voice",
+        "temperature": 1.1,
+        "cfg_scale": 2.8,
+        "top_p": 0.9
+    }
+)
+
+with open("output.wav", "wb") as f:
+    f.write(response.content)
+```
+
+### How Voice Cloning Works
+
+Based on the Dia model's reference implementation:
+
+1. **Audio Prompt**: The model loads your audio file directly
+2. **Transcript Concatenation**: Your audio transcript is prepended to the target text
+3. **Generation**: The model generates audio for the ENTIRE concatenated text
+4. **Output**: Only the audio for the new text (after the transcript) is returned
+
+Example flow:
+```
+Audio prompt transcript: "[S1] Hi, this is my voice sample."
+Target text:             "[S1] Hello world!"
+Model input:             "[S1] Hi, this is my voice sample. [S1] Hello world!"
+Generated output:        Audio for "Hello world!" in the cloned voice
+```
+
+**Best Practices**:
+- Ensure the transcript exactly matches the audio content
+- Use the same speaker tag ([S1] or [S2]) consistently
+- Keep audio samples clear and noise-free
+- 3-10 seconds of speech works best
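+
+The same flow can also be driven through the raw Dia Python API instead of the HTTP endpoints. A minimal sketch, assuming a local reference file `voice_sample.wav` whose content matches the transcript (this mirrors the usage shown in AUDIO_PROMPT_EXPLANATION.md):
+
+```python
+from dia.model import Dia
+
+model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+
+# The transcript must match the reference audio exactly, including the speaker tag.
+transcript = "[S1] Hi, this is my voice sample."
+target = "[S1] Hello world!"
+
+# Prepending the transcript conditions generation on the prompt audio;
+# only the audio for the new text is returned.
+output = model.generate(transcript + " " + target, audio_prompt="voice_sample.wav")
+```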
+
+### Complete Voice Cloning Example with Roles
+
+```python
+import requests
+
+# 1. Upload character voice sample
+with open("character_voice.wav", "rb") as f:
+    requests.post(
+        "http://localhost:7860/audio_prompts/upload",
+        data={"prompt_id": "my_character"},
+        files={"audio_file": f}
+    )
+
+# 2. Create voice mapping with transcript
+requests.post(
+    "http://localhost:7860/voice_mappings",
+    json={
+        "voice_id": "ai_assistant",
+        "style": "friendly",
+        "primary_speaker": "S2",  # Assistant uses S2
+        "audio_prompt": "my_character",
+        "audio_prompt_transcript": "[S2] Hello, I am your AI assistant."
+    }
+)
+
+# 3. Generate responses with role-based input
+# User question (uses S1 automatically)
+user_response = requests.post(
+    "http://localhost:7860/generate",
+    json={
+        "text": "Can you help me with Python?",
+        "voice_id": "alloy",
+        "role": "user"  # Maps to S1
+    }
+)
+
+# AI assistant response (uses cloned voice)
+assistant_response = requests.post(
+    "http://localhost:7860/generate",
+    json={
+        "text": "Of course! I'd be happy to help you with Python programming.",
+        "voice_id": "ai_assistant",
+        "role": "assistant"  # Maps to S2
+    }
+)
+```
+
+## Integration Examples
+
+### Python Client
+```python
+import requests
+
+class DiaTTSClient:
+    def __init__(self, base_url="http://localhost:7860"):
+        self.base_url = base_url
+
+    def generate_speech(self, text, voice="alloy", **kwargs):
+        payload = {
+            "text": text,
+            "voice_id": voice
+        }
+        # Add any additional parameters
+        payload.update(kwargs)
+
+        response = requests.post(
+            f"{self.base_url}/generate",
+            json=payload
+        )
+        return response.content
+
+    def upload_voice(self, prompt_id, audio_file_path):
+        with open(audio_file_path, "rb") as f:
+            return requests.post(
+                f"{self.base_url}/audio_prompts/upload",
+                data={"prompt_id": prompt_id},
+                files={"audio_file": f}
+            ).json()
+
+# Usage
+client = DiaTTSClient()
+audio = client.generate_speech(
+    "[S1] Hello world!",
+    "alloy",
+    temperature=1.1,
+    cfg_scale=2.5,
+    top_p=0.9
+)
+```
+
+### JavaScript/Node.js
+```javascript
+// Uses the global fetch API (Node 18+)
+class DiaTTSClient {
+  constructor(baseUrl = 'http://localhost:7860') {
+    this.baseUrl = baseUrl;
+  }
+
+  async generateSpeech(text, voice = 'alloy', options = {}) {
+    const payload = {
+      text: text,
+      voice_id: voice,
+      ...options
+    };
+
+    const response = await fetch(`${this.baseUrl}/generate`, {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/json' },
+      body: JSON.stringify(payload)
+    });
+    return response.arrayBuffer();
+  }
+
+  async generateSpeechAsync(text, voice = 'alloy', options = {}) {
+    const payload = {
+      text: text,
+      voice_id: voice,
+      ...options
+    };
+
+    // Submit job
+    const response = await fetch(`${this.baseUrl}/generate?async_mode=true`, {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/json' },
+      body: JSON.stringify(payload)
+    });
+
+    const { job_id } = await response.json();
+
+    // Poll for completion
+    while (true) {
+      const statusResponse = await fetch(`${this.baseUrl}/jobs/${job_id}`);
+      const job = await statusResponse.json();
+
+      if (job.status === 'completed') {
+        const resultResponse = await fetch(`${this.baseUrl}/jobs/${job_id}/result`);
+        return resultResponse.arrayBuffer();
+      } else if (job.status === 'failed') {
+        throw new Error(`Job failed: ${job.error_message}`);
+      }
+
+      // Wait 1 second before polling again
+      await new Promise(resolve => setTimeout(resolve, 1000));
+    }
+  }
+
+  async getQueueStats() {
+    const response = await fetch(`${this.baseUrl}/queue/stats`);
+    return response.json();
+  }
+}
+
+// Usage examples
+const client = new DiaTTSClient();
+
+// Synchronous generation
+const syncAudio = await client.generateSpeech('[S1] Hello world!', 'alloy', {
+  temperature: 1.0,
+  cfg_scale: 2.5
+});
+
+// Asynchronous generation
+const asyncAudio = await client.generateSpeechAsync('[S1] Long text...', 'nova');
+
+// Monitor queue
+const stats = await client.getQueueStats();
+console.log(`Pending jobs: ${stats.pending_jobs}, Active workers: ${stats.active_workers}`);
+```
+
+## Model Parameters
+
+### Available Parameters
+- **temperature** (0.1-2.0): Controls randomness in generation. Higher = more creative, lower = more consistent
+- **cfg_scale** (1.0-10.0): Classifier-free guidance scale. Higher = stronger conditioning on input text
+- **top_p** (0.0-1.0): Nucleus sampling threshold.
Lower = more focused sampling +- **max_tokens** (100-10000): Maximum number of tokens to generate +- **use_torch_compile** (boolean): Enable PyTorch compilation for faster inference + +### Default Values +- temperature: 1.2 +- cfg_scale: 3.0 +- top_p: 0.95 +- max_tokens: auto (based on text length) +- use_torch_compile: auto-detected + +### Parameter Effects +- **Lower temperature + higher cfg_scale**: More consistent, text-faithful speech +- **Higher temperature + lower cfg_scale**: More expressive, natural-sounding speech +- **Lower top_p**: More focused vocabulary, clearer pronunciation +- **Higher top_p**: More varied expressions and natural speech patterns + +### Recommended Settings +```json +// High quality, consistent output +{ + "temperature": 0.8, + "cfg_scale": 4.0, + "top_p": 0.8 +} + +// Natural, expressive speech +{ + "temperature": 1.4, + "cfg_scale": 2.5, + "top_p": 0.95 +} + +// Fast generation +{ + "temperature": 1.0, + "cfg_scale": 2.0, + "max_tokens": 1000, + "use_torch_compile": true +} +``` + +## File Storage Structure + +``` +dia/ +├── audio_prompts/ # Uploaded voice samples (persistent) +│ ├── my_voice.wav +│ ├── character_voice.wav +│ └── ... +├── audio_outputs/ # Generated audio (when save_outputs=true) +│ ├── 20240115_123456_abc123_alloy.wav +│ └── ... +└── fastapi_server.py +``` + +## Performance Considerations + +- **Model Loading**: ~30 seconds on first request +- **Generation Speed**: ~2-5x real-time on GPU +- **Memory Usage**: ~10GB VRAM (GPU) or ~16GB RAM (CPU) +- **Concurrent Requests**: Up to 4 workers by default (configurable) +- **Parameter Impact**: Higher temperature/top_p may increase generation time slightly +- **Queue Benefits**: Non-blocking request handling, better resource utilization +- **Worker Scaling**: Limited by GPU memory (single model instance shared) + +## Troubleshooting + +### Common Issues +1. **Model not loading**: Check `HF_TOKEN` environment variable +2. **Audio quality**: Ensure input text has proper speaker tags (`[S1]`, `[S2]`) +3. **Slow generation**: Use GPU with CUDA for faster inference +4. **Voice not found**: Check `/voice_mappings` for available voices +5. **Torch compile errors**: Server auto-retries without compilation +6. **File upload failures**: Check audio format and file size limits +7. **Queue full**: Use async mode or check `/queue/stats` for worker availability +8. **Job not found**: Jobs expire after 1 hour, check job list with `/jobs` +9. 
**Audio prompt not influencing output**: + - Ensure audio file exists in audio_prompts directory + - Include accurate audio prompt transcript (must match audio content) + - The transcript is prepended to your text for conditioning + - Model generates audio only for text after the transcript + +### Debug Commands +```bash +# Health check +curl "http://localhost:7860/health" + +# List available models +curl "http://localhost:7860/models" + +# List available voices +curl "http://localhost:7860/voices" + +# Check voice mappings +curl "http://localhost:7860/voice_mappings" + +# View server configuration +curl "http://localhost:7860/config" + +# Check queue statistics +curl "http://localhost:7860/queue/stats" + +# List current jobs +curl "http://localhost:7860/jobs" + +# Enable debug mode +curl -X PUT "http://localhost:7860/config" \ + -d '{"debug_mode": true}' \ + -H "Content-Type: application/json" +``` + +### Log Analysis +- Monitor server console for detailed error information +- Check generation logs via `/logs` endpoint +- Enable `show_prompts` to see text processing +- Use `save_outputs` to verify audio generation + +### Performance Tuning +- Use `use_torch_compile: false` on Windows or unsupported systems +- Lower `max_tokens` for faster generation +- Reduce `temperature` and `cfg_scale` for simpler processing +- Enable debug logging to identify bottlenecks +- Use async mode (`?async_mode=true`) for concurrent requests +- Monitor queue stats to optimize worker utilization +- Consider increasing worker count for high-throughput scenarios \ No newline at end of file diff --git a/AUDIO_PROMPT_EXPLANATION.md b/AUDIO_PROMPT_EXPLANATION.md new file mode 100644 index 00000000..d574d9f3 --- /dev/null +++ b/AUDIO_PROMPT_EXPLANATION.md @@ -0,0 +1,134 @@ +# How audio_prompt Works in Dia Model + +## Overview + +The `audio_prompt` parameter in the Dia model allows for **voice cloning** - using an audio sample to influence the voice characteristics of the generated speech. This enables the model to mimic the voice, tone, and speaking style from a reference audio file. + +## How It Works + +### 1. Audio Prompt Processing + +When you provide an `audio_prompt` to the `generate()` method: + +1. **Audio Loading**: The audio file is loaded and encoded using the DAC (Descript Audio Codec) model + - Converted to 44.1kHz mono if needed + - Encoded into discrete audio tokens (codebook indices) + +2. **Prefill Phase**: The encoded audio tokens are used to "prefill" the decoder + - These tokens provide context about voice characteristics + - The model learns the speaking style, tone, and voice qualities from these tokens + +3. **Generation**: The model continues generating new audio tokens that match: + - The text content you want to synthesize + - The voice characteristics from the audio prompt + +### 2. Technical Implementation + +From `dia/model.py`: + +```python +def _prepare_audio_prompt(self, audio_prompts: list[torch.Tensor | None]) -> tuple[torch.Tensor, list[int]]: + """Prepares the audio prompt tensor for the decoder. 
+ + - Adds beginning-of-sequence (BOS) token + - Applies delay pattern for multi-channel generation + - Returns prefilled audio tokens and prefill steps + """ +``` + +The audio prompt influences generation through: +- **Decoder Prefill**: Audio prompt tokens are fed to the decoder before text generation begins +- **Cross-Attention**: The decoder attends to both text encoding and prefilled audio context +- **Conditional Generation**: The model generates audio that matches both text and voice style + +### 3. Usage Examples + +#### Basic Voice Cloning +```python +# Load a reference audio file +audio_prompt = model.load_audio("reference_voice.mp3") + +# Generate speech with the cloned voice +output = model.generate( + "[S1] Hello, this is a test.", + audio_prompt=audio_prompt +) +``` + +#### Voice Cloning with Context (Recommended) +```python +# Provide the transcript of the reference audio +clone_text = "[S1] This is what the reference speaker said." +clone_audio = "reference_voice.mp3" + +# Generate new content with the same voice +new_text = "[S1] This is new content in the same voice." +output = model.generate( + clone_text + new_text, + audio_prompt=clone_audio +) +``` + +### 4. FastAPI Server Implementation + +The FastAPI server provides endpoints for voice cloning: + +#### Upload Audio Prompt +```bash +curl -X POST "http://localhost:7860/audio_prompts/upload" \ + -F "prompt_id=my_voice" \ + -F "audio_file=@voice_sample.wav" +``` + +#### Create Voice Mapping +```bash +curl -X POST "http://localhost:7860/voice_mappings" \ + -H "Content-Type: application/json" \ + -d '{ + "voice_id": "custom_voice", + "audio_prompt": "my_voice" + }' +``` + +### 5. Important Considerations + +1. **Audio Quality**: Better quality reference audio produces better voice cloning + - Recommended: 3-10 seconds of clean speech + - Avoid background noise or music + +2. **Transcript Accuracy**: Including the transcript of the reference audio improves results + - The model can better understand the voice-to-text mapping + - Helps maintain consistency between voice and content + +3. **Speaker Consistency**: The audio prompt primarily affects the speaker tagged in the text + - If using `[S1]`, the audio prompt influences Speaker 1's voice + - Multiple speakers can have different audio prompts + +4. **Performance**: Using audio prompts increases generation time slightly + - Additional encoding step for the reference audio + - Larger context window during generation + +## How Voice Characteristics Are Captured + +The model captures several aspects from the audio prompt: + +1. **Voice Timbre**: The unique quality of the voice +2. **Speaking Style**: Pace, rhythm, and articulation patterns +3. **Emotional Tone**: The emotional coloring of the speech +4. **Prosody**: Intonation and stress patterns +5. **Speaker Identity**: Gender, age, and individual voice characteristics + +## Limitations + +- Cannot perfectly clone all voices (especially very unique voices) +- Works best with voices similar to those in training data +- Requires clean, high-quality reference audio +- The generated voice may not be 100% identical but will have similar characteristics + +## Best Practices + +1. **Use High-Quality Audio**: Clean recordings without background noise +2. **Match Content Style**: Reference audio should have similar speaking style to target +3. **Appropriate Length**: 3-10 seconds is optimal for most use cases +4. **Include Transcript**: Always provide the transcript of reference audio when possible +5. 
**Test Different Samples**: Try multiple reference samples to find the best match \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..300c7119 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,394 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Dia is a 1.6B parameter text-to-speech model that directly generates realistic dialogue from text transcripts. The model supports emotion/tone control through audio prompts and can produce nonverbal communications like laughter, coughing, etc. + +## Development Commands + +### Environment Setup +```bash +# Using uv (recommended) +uv run app.py + +# Or traditional Python +python -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +pip install -e . +``` + +### Running the Application +```bash +# Start Gradio UI +python app.py +# With custom device: python app.py --device cuda +# With sharing: python app.py --share + +# CLI usage +python cli.py "your text here" --output output.wav + +# FastAPI Server (SillyTavern compatible) +# Simple startup with environment check +python start_server.py + +# Development mode (debug + save outputs + prompts + reload) +python start_server.py --dev + +# Production mode (optimized for deployment) +python start_server.py --production + +# Custom configuration +python start_server.py --debug --save-outputs --workers 8 --retention-hours 48 + +# Direct server launch (without startup script) +python fastapi_server.py --debug --save-outputs --port 7860 +``` + +### Code Quality +```bash +# Linting (configured in pyproject.toml) +ruff check +ruff format + +# The project uses ruff for linting with custom rules: +# - Ignores line length violations (E501) +# - Ignores complexity (C901), naming (E741), regex (W605) +# - Line length set to 119 characters +``` + +## Architecture Overview + +### Core Components + +1. **DiaModel** (`dia/layers.py`): Main transformer architecture with encoder-decoder structure + - Encoder: Processes text input with standard transformer layers + - Decoder: Generates audio tokens using grouped-query attention and cross-attention to encoder + +2. **Audio Processing** (`dia/audio.py`): Handles delay patterns for multi-channel audio generation + - `apply_audio_delay()`: Applies channel-specific delays during generation + - `revert_audio_delay()`: Reverts delays for final audio reconstruction + +3. **Configuration System** (`dia/config.py`): Pydantic-based config management + - DataConfig: Text/audio lengths, padding, delay patterns + - ModelConfig: Architecture parameters (layers, heads, dimensions) + - DiaConfig: Master configuration combining all components + +4. **State Management** (`dia/state.py`): Handles inference state for encoder/decoder + - EncoderInferenceState: Manages text encoding and padding + - DecoderInferenceState: Manages KV caches and cross-attention during generation + - DecoderOutput: Tracks generated tokens and manages prefill/generation phases + +### Generation Process + +1. **Text Encoding**: Text is byte-encoded with special tokens [S1]/[S2] replaced by \x01/\x02 +2. **Audio Prompt Processing**: Optional audio prompts are encoded using DAC (Descript Audio Codec) +3. **Dual Path Generation**: Uses classifier-free guidance with conditional/unconditional paths +4. **Delay Pattern Application**: Multi-channel audio uses staggered generation with channel-specific delays +5. 
**Token Sampling**: Supports temperature, top-p, and top-k sampling with EOS handling +6. **Audio Reconstruction**: Generated codes are converted back to waveforms via DAC decoder + +### Key Design Patterns + +- **Batched Generation**: Supports generating multiple audio sequences simultaneously +- **Torch Compilation**: Optional torch.compile support for faster inference (use_torch_compile=True) +- **Device Flexibility**: Auto-detects CUDA, MPS (Apple Silicon), or CPU +- **Memory Efficiency**: Uses grouped-query attention and optional float16/bfloat16 precision + +## Text Format Requirements + +- Always start with `[S1]` and alternate between `[S1]` and `[S2]` speakers +- Keep input length moderate (5-20 seconds of audio equivalent) +- Supported nonverbals: `(laughs)`, `(coughs)`, `(clears throat)`, `(sighs)`, `(gasps)`, etc. +- For voice cloning: provide transcript of source audio before generation text + +## Model Loading Patterns + +```python +# From Hugging Face Hub (default) +model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16") + +# From local files +model = Dia.from_local("config.json", "model.pth", compute_dtype="float16") + +# With custom device +model = Dia.from_pretrained("nari-labs/Dia-1.6B", device=torch.device("cuda:0")) +``` + +## Environment Variables + +- `HF_TOKEN`: Required for downloading models from Hugging Face Hub +- `GRADIO_SERVER_NAME`: Override Gradio server host (use "0.0.0.0" for Docker) +- `GRADIO_SERVER_PORT`: Override Gradio server port + +## Hardware Requirements + +- GPU: ~10GB VRAM for float16/bfloat16, ~13GB for float32 +- CPU: Supported but significantly slower +- Apple Silicon: Use `use_torch_compile=False` (MPS doesn't support torch.compile) + +## FastAPI Server (SillyTavern Integration) + +The `fastapi_server.py` provides OpenAI-compatible TTS API with advanced features including worker queues, voice cloning, and job management. + +### Core Features +- **Worker Queue System** - Concurrent processing with up to 4 workers +- **Sync/Async Processing** - Choose immediate response or job-based processing +- **Voice Cloning** - Upload audio samples for custom voices +- **Model Parameters** - Full control over generation settings +- **Debug Logging** - Comprehensive monitoring and file management + +### Key Endpoints +- `POST /generate` - Main TTS generation (sync/async modes) +- `GET /models` - List available models +- `GET /voices` - List available voices +- `GET /jobs` - Job queue management +- `GET /config` - Server configuration +- `POST /audio_prompts/upload` - Upload voice samples + +### SillyTavern Configuration + +**Setup Instructions:** + +1. **Start the FastAPI Server:** + ```bash + # Recommended: Use startup script with environment check + python start_server.py + + # Development mode with full features + python start_server.py --dev + + # Production mode + python start_server.py --production + + # Custom configuration + python start_server.py --debug --save-outputs --workers 6 + ``` + +2. **Configure SillyTavern:** + - Navigate to Settings → Text-to-Speech + - Set TTS Provider: **Custom** + - Endpoint URL: **http://localhost:7860/generate** + - Voice: Choose from alloy, echo, fable, nova, onyx, shimmer + +3. 
**Available Features:** + - **Synchronous TTS** - Immediate audio response + - **Asynchronous TTS** - Job-based processing for concurrent requests + - **Custom Voices** - Upload audio samples for voice cloning + - **Model Parameters** - Configure temperature, cfg_scale, top_p + - **Queue Monitoring** - Track job status and worker utilization + +**Troubleshooting:** +- Ensure HF_TOKEN environment variable is set +- Check server logs for model loading status +- Test with: `curl http://localhost:7860/health` + +## Voice Cloning and Custom Voices + +The FastAPI server supports custom voice creation through audio prompts: + +### Upload Audio Prompt +```bash +curl -X POST "http://localhost:7860/audio_prompts/upload" \ + -F "prompt_id=my_voice" \ + -F "audio_file=@voice_sample.wav" +``` + +### Create Custom Voice Mapping +```bash +curl -X POST "http://localhost:7860/voice_mappings" \ + -H "Content-Type: application/json" \ + -d '{ + "voice_id": "custom_voice", + "style": "friendly", + "primary_speaker": "S1", + "audio_prompt": "my_voice" + }' +``` + +### Voice Management Endpoints +- `GET /voice_mappings` - List all voice configurations +- `PUT /voice_mappings/{voice_id}` - Update voice configuration +- `POST /voice_mappings` - Create new voice mapping +- `DELETE /voice_mappings/{voice_id}` - Delete custom voice +- `POST /audio_prompts/upload` - Upload audio prompt file +- `GET /audio_prompts` - List uploaded audio prompts +- `DELETE /audio_prompts/{prompt_id}` - Delete audio prompt + +### Audio Prompt Requirements +- Supported formats: WAV, MP3, M4A, etc. +- Automatically resampled to 44.1kHz mono +- Recommended: 3-10 seconds of clean speech +- For best results: Include the speaker's voice saying the target text or similar content + +## Debug and Logging Features + +The FastAPI server includes comprehensive logging and debugging capabilities: + +### Startup Script Options +```bash +# Basic startup with environment check +python start_server.py + +# Development mode (enables debug, save outputs, show prompts, reload) +python start_server.py --dev + +# Production mode (optimized settings) +python start_server.py --production + +# Custom configuration +python start_server.py --debug --save-outputs --show-prompts --workers 8 + +# Performance tuning +python start_server.py --workers 6 --no-torch-compile --retention-hours 48 + +# Environment check only +python start_server.py --check-only +``` + +### Direct Server Options +```bash +# Direct server launch (bypass startup script) +python fastapi_server.py --debug --save-outputs --show-prompts + +# Custom retention period +python fastapi_server.py --save-outputs --retention-hours 48 + +# Development mode with all features +python fastapi_server.py --debug --save-outputs --show-prompts --reload +``` + +### Configuration API +```bash +# Get current configuration +curl http://localhost:7860/config + +# Update configuration +curl -X PUT "http://localhost:7860/config" \ + -H "Content-Type: application/json" \ + -d '{ + "debug_mode": true, + "save_outputs": true, + "show_prompts": true, + "output_retention_hours": 24 + }' +``` + +### Generation Logs API +```bash +# Get recent generation logs +curl http://localhost:7860/logs + +# Get logs for specific voice +curl "http://localhost:7860/logs?voice=alloy&limit=10" + +# Get specific log details +curl http://localhost:7860/logs/{log_id} + +# Download generated audio file +curl http://localhost:7860/logs/{log_id}/download -o output.wav + +# Clear all logs +curl -X DELETE http://localhost:7860/logs + +# Manual 
cleanup of old files
+curl -X POST http://localhost:7860/cleanup
+```
+
+### Features
+- **Prompt Logging**: Shows original and processed text for each request
+- **Audio File Saving**: Saves all generated audio with timestamps
+- **Generation Metrics**: Tracks generation time and file sizes
+- **Automatic Cleanup**: Removes files older than the retention period (default 24h)
+- **Debug Headers**: Includes generation ID in response headers
+- **Voice Tracking**: Logs which voice and audio prompts were used
+- **Job Queue Monitoring**: Track async jobs and worker status
+- **Worker Pool Management**: Concurrent processing with configurable workers
+
+## Worker Queue System
+
+The FastAPI server includes a built-in worker queue for handling concurrent TTS requests:
+
+### Queue Features
+- **Concurrent Processing**: Up to 4 workers process jobs simultaneously
+- **Job Status Tracking**: Monitor jobs from pending → processing → completed
+- **Automatic Cleanup**: Jobs cleaned up after 1 hour
+- **Result Storage**: Audio results temporarily stored in memory
+- **Worker Monitoring**: Track active workers and queue statistics
+
+### Processing Modes
+```bash
+# Synchronous (immediate response)
+curl -X POST "http://localhost:7860/generate" \
+  -H "Content-Type: application/json" \
+  -d '{"text": "Hello!", "voice_id": "alloy"}' \
+  --output speech.wav
+
+# Asynchronous (job-based)
+curl -X POST "http://localhost:7860/generate?async_mode=true" \
+  -H "Content-Type: application/json" \
+  -d '{"text": "Hello!", "voice_id": "alloy"}'
+```
+
+### Queue Management
+```bash
+# Get queue statistics
+curl "http://localhost:7860/queue/stats"
+
+# List all jobs
+curl "http://localhost:7860/jobs"
+
+# Check specific job status
+curl "http://localhost:7860/jobs/{job_id}"
+
+# Download completed job result
+curl "http://localhost:7860/jobs/{job_id}/result" -o result.wav
+
+# Cancel pending job
+curl -X DELETE "http://localhost:7860/jobs/{job_id}"
+```
+
+## API Features Summary
+
+### Voice Management
+- **Built-in Voices**: alloy, echo, fable, nova, onyx, shimmer
+- **Custom Voices**: Upload audio samples for voice cloning
+- **Voice Mapping**: Configure speaker assignments and styles
+- **Audio Prompts**: 0.5-60 seconds, auto-resampled to 44.1kHz
+
+### Model Parameters
+- **temperature** (0.1-2.0): Controls randomness and creativity
+- **cfg_scale** (1.0-10.0): Classifier-free guidance strength
+- **top_p** (0.0-1.0): Nucleus sampling threshold
+- **max_tokens** (100-10000): Maximum generation length
+- **use_torch_compile** (boolean): Enable compilation optimization
+
+### API Request Format (OpenAI-compatible endpoint)
+```json
+{
+  "model": "dia",
+  "input": "[S1] Text to convert to speech",
+  "voice": "alloy",
+  "response_format": "wav",
+  "speed": 1.0,
+  "temperature": 1.2,
+  "cfg_scale": 3.0,
+  "top_p": 0.95
+}
+```
+
+The native `/generate` endpoint instead takes `text` and `voice_id` fields, as shown in the Processing Modes examples above.
+
+### Performance & Scaling
+- **Concurrent Workers**: 4 worker threads by default
+- **Queue Management**: Job-based async processing
+- **Memory Efficient**: Shared model instance across workers
+- **Auto-cleanup**: Jobs and files removed automatically
+- **Monitoring**: Full visibility into queue and worker status
+
+## Docker Support
+
+```bash
+# GPU build
+docker build -f docker/Dockerfile.gpu -t dia-gpu .
+
+# CPU build
+docker build -f docker/Dockerfile.cpu -t dia-cpu .
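+
+# Run (hypothetical flags — adjust the port and entrypoint to your Dockerfiles):
+# docker run --rm --gpus all -e HF_TOKEN=$HF_TOKEN -p 7860:7860 dia-gpu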
+```
\ No newline at end of file
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 00000000..de5bffda
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,248 @@
+import argparse
+import statistics
+import time
+from typing import Dict, List
+
+import torch
+
+from dia.model import Dia
+
+
+def verify_cuda_setup() -> bool:
+    """Verify CUDA is available and properly configured."""
+    if not torch.cuda.is_available():
+        print("ERROR: CUDA is not available on this system.")
+        return False
+
+    # Print CUDA device information
+    device_count = torch.cuda.device_count()
+    print("CUDA Information:")
+    print(f"- CUDA Available: {torch.cuda.is_available()}")
+    print(f"- Number of CUDA Devices: {device_count}")
+
+    for i in range(device_count):
+        props = torch.cuda.get_device_properties(i)
+        print(f"- Device {i}: {props.name}")
+        print(f"  - CUDA Capability: {props.major}.{props.minor}")
+        print(f"  - Total Memory: {props.total_memory / 1024**3:.2f} GB")
+
+    return True
+
+
+def run_benchmark(
+    model: Dia,
+    texts: List[str],
+    max_tokens: int,
+    cfg_scale: float,
+    temperature: float,
+    top_p: float,
+    num_runs: int,
+    warm_up: bool = True,
+    detailed_timing: bool = False,
+) -> Dict:
+    """Run benchmark for the model on CUDA.
+
+    Args:
+        model: The Dia model instance
+        texts: List of texts to generate audio from
+        max_tokens: Maximum number of tokens to generate
+        cfg_scale: Classifier-free guidance scale
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+        num_runs: Number of benchmark runs
+        warm_up: Whether to do a warm-up run (not counted in results)
+        detailed_timing: Whether to collect detailed timing per token
+
+    Returns:
+        Dictionary of performance metrics
+    """
+    # Verify model is on CUDA
+    if not next(model.model.parameters()).is_cuda:
+        print("WARNING: Model does not appear to be on CUDA!")
+
+    # Optionally do a warm-up run to ensure any initialization happens before benchmarking
+    if warm_up:
+        print("Warming up GPU...")
+        _ = model.generate(
+            texts[0],
+            max_tokens=max_tokens,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_p=top_p,
+            cfg_filter_top_k=45
+        )
+
+    # Benchmark metrics
+    generation_times = []
+    tokens_per_second = []
+    memory_usage = []
+
+    for i in range(num_runs):
+        # Rotate through the texts if we have more runs than texts
+        text = texts[i % len(texts)]
+
+        # Clear CUDA cache before each run
+        torch.cuda.empty_cache()
+        start_memory = torch.cuda.memory_allocated() / (1024**2)  # Convert to MB
+
+        # Time the generation
+        start_time = time.time()
+        output = model.generate(
+            text,
+            max_tokens=max_tokens,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_p=top_p,
+            cfg_filter_top_k=45
+        )
+        end_time = time.time()
+
+        # Calculate metrics
+        elapsed = end_time - start_time
+        # Assumes the full max_tokens budget was generated; if generation stops
+        # early at EOS, the reported tokens/second is optimistic.
+        tokens_sec = max_tokens / elapsed
+        peak_memory = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB
+        memory_used = peak_memory - start_memory
+
+        # Store results
+        generation_times.append(elapsed)
+        tokens_per_second.append(tokens_sec)
+        memory_usage.append(memory_used)
+
+        # Print run information
+        print(f"Run {i+1}/{num_runs}:")
+        print(f"  - Time: {elapsed:.2f} seconds")
+        print(f"  - Speed: {tokens_sec:.2f} tokens/second")
+        print(f"  - Memory: {memory_used:.2f} MB")
+
+    # Calculate aggregate metrics
+    results = {
+        "times": generation_times,
+        "avg_time": statistics.mean(generation_times),
+        "median_time": 
statistics.median(generation_times), + "min_time": min(generation_times), + "max_time": max(generation_times), + "tokens_per_second": statistics.mean(tokens_per_second), + "avg_memory_mb": statistics.mean(memory_usage), + } + + return results + + +def print_gpu_results(results: Dict, compute_dtype: str) -> None: + """Print detailed GPU benchmark results.""" + print("\n" + "="*60) + print(f"GPU BENCHMARK RESULTS (Precision: {compute_dtype})") + print("="*60) + + print(f"Performance Metrics:") + print(f"- Average Generation Time: {results['avg_time']:.2f} seconds") + print(f"- Median Generation Time: {results['median_time']:.2f} seconds") + print(f"- Best Time: {results['min_time']:.2f} seconds") + print(f"- Worst Time: {results['max_time']:.2f} seconds") + print(f"- Average Generation Speed: {results['tokens_per_second']:.2f} tokens/second") + print(f"- Average Memory Usage: {results['avg_memory_mb']:.2f} MB") + + if len(results['times']) > 1: + std_dev = statistics.stdev(results['times']) + variance_pct = (std_dev / results['avg_time']) * 100 + print(f"- Run Consistency: {variance_pct:.2f}% variance") + + print("="*60) + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Dia model performance on CUDA") + parser.add_argument("--num-runs", type=int, default=3, help="Number of benchmark runs") + parser.add_argument("--max-tokens", type=int, default=860, help="Maximum tokens to generate") + parser.add_argument("--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale") + parser.add_argument("--temperature", type=float, default=1.3, help="Sampling temperature") + parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling value") + parser.add_argument("--compare-precision", action="store_true", help="Compare FP16 vs BF16 precision") + + args = parser.parse_args() + + # Verify CUDA setup + if not verify_cuda_setup(): + print("CUDA verification failed. Cannot continue with GPU benchmarking.") + return + + # Test texts to use for benchmarking + benchmark_texts = [ + "[S1] This is a test of the generation speed. [S2] The quick brown fox jumps over the lazy dog.", + "[S1] Let's benchmark this model to see how it performs. [S2] Performance testing is important.", + "[S1] CUDA should be faster than CPU for neural networks. [S2] But by how much? Let's find out." 
+    ]
+
+    # CUDA device
+    cuda_device = torch.device("cuda")
+
+    # If comparing precision modes
+    if args.compare_precision and torch.cuda.is_bf16_supported():
+        # Float16 benchmarking
+        print("\nRunning GPU benchmark with FP16 precision...")
+        fp16_model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=cuda_device)
+        fp16_results = run_benchmark(
+            model=fp16_model,
+            texts=benchmark_texts,
+            max_tokens=args.max_tokens,
+            cfg_scale=args.cfg_scale,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            num_runs=args.num_runs
+        )
+        del fp16_model
+        torch.cuda.empty_cache()
+
+        # BFloat16 benchmarking
+        print("\nRunning GPU benchmark with BF16 precision...")
+        bf16_model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16", device=cuda_device)
+        bf16_results = run_benchmark(
+            model=bf16_model,
+            texts=benchmark_texts,
+            max_tokens=args.max_tokens,
+            cfg_scale=args.cfg_scale,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            num_runs=args.num_runs
+        )
+        del bf16_model
+        torch.cuda.empty_cache()
+
+        # Print results
+        print_gpu_results(fp16_results, "FP16")
+        print_gpu_results(bf16_results, "BF16")
+
+        # Compare the two
+        speed_diff = (bf16_results['tokens_per_second'] / fp16_results['tokens_per_second'] - 1) * 100
+        time_diff = (fp16_results['avg_time'] / bf16_results['avg_time'] - 1) * 100
+
+        print("\n" + "="*60)
+        print("PRECISION COMPARISON")
+        print("="*60)
+        print(f"BF16 vs FP16 Speed Difference: {speed_diff:.2f}% ({'faster' if speed_diff > 0 else 'slower'})")
+        print(f"BF16 vs FP16 Time Difference: {time_diff:.2f}% ({'faster' if time_diff > 0 else 'slower'})")
+        print("="*60)
+
+    else:
+        # Standard FP16 benchmarking
+        print("\nRunning GPU benchmark with FP16 precision...")
+        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=cuda_device)
+        results = run_benchmark(
+            model=model,
+            texts=benchmark_texts,
+            max_tokens=args.max_tokens,
+            cfg_scale=args.cfg_scale,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            num_runs=args.num_runs
+        )
+        print_gpu_results(results, "FP16")
+
+        if args.compare_precision and not torch.cuda.is_bf16_supported():
+            print("\nWarning: BF16 comparison requested but your GPU does not support BF16.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/example/fastapi_example.py b/example/fastapi_example.py
new file mode 100644
index 00000000..ee8a70b7
--- /dev/null
+++ b/example/fastapi_example.py
@@ -0,0 +1,114 @@
+"""
+Example usage of the Dia FastAPI server
+"""
+
+import requests
+import io
+import soundfile as sf
+
+# Server configuration (matches the default port used throughout these docs)
+SERVER_URL = "http://localhost:7860"
+
+def test_tts_api():
+    """Test the TTS API endpoint"""
+
+    # Test text
+    text = "Hello, this is a test of the Dia text-to-speech system. How does this sound?"
+
+    # Make request to TTS endpoint
+    response = requests.post(
+        f"{SERVER_URL}/v1/audio/speech",
+        json={
+            "model": "dia",
+            "input": text,
+            "voice": "alloy",
+            "response_format": "wav",
+            "speed": 1.0
+        }
+    )
+
+    if response.status_code == 200:
+        # Save the audio file
+        with open("fastapi_test_output.wav", "wb") as f:
+            f.write(response.content)
+        print("✅ TTS generation successful! 
Audio saved to 'fastapi_test_output.wav'") + + # Load and print audio info + audio_data, sample_rate = sf.read(io.BytesIO(response.content)) + duration = len(audio_data) / sample_rate + print(f"📊 Audio info: {duration:.2f}s duration, {sample_rate}Hz sample rate") + else: + print(f"❌ TTS generation failed: {response.status_code}") + print(f"Error: {response.text}") + +def test_voices_api(): + """Test the voices listing endpoint""" + + response = requests.get(f"{SERVER_URL}/v1/voices") + + if response.status_code == 200: + voices = response.json() + print("✅ Available voices:") + for voice in voices: + print(f" - {voice['name']} (ID: {voice['voice_id']})") + else: + print(f"❌ Failed to get voices: {response.status_code}") + +def test_health(): + """Test server health""" + + response = requests.get(f"{SERVER_URL}/health") + + if response.status_code == 200: + health = response.json() + print("✅ Server health:") + print(f" Status: {health['status']}") + print(f" Model loaded: {health['model_loaded']}") + else: + print(f"❌ Health check failed: {response.status_code}") + +def test_alternative_endpoint(): + """Test the alternative TTS endpoint (SillyTavern-Extras style)""" + + response = requests.post( + f"{SERVER_URL}/api/tts/generate", + json={ + "text": "This is a test using the alternative API endpoint.", + "speaker": "nova" + } + ) + + if response.status_code == 200: + with open("fastapi_alt_test_output.wav", "wb") as f: + f.write(response.content) + print("✅ Alternative API test successful! Audio saved to 'fastapi_alt_test_output.wav'") + else: + print(f"❌ Alternative API test failed: {response.status_code}") + +if __name__ == "__main__": + print("🧪 Testing Dia FastAPI Server") + print("=" * 40) + + # Test server health first + test_health() + print() + + # Test voices endpoint + test_voices_api() + print() + + # Test main TTS endpoint + test_tts_api() + print() + + # Test alternative endpoint + test_alternative_endpoint() + print() + + print("🎉 All tests completed!") + print() + print("💡 SillyTavern Configuration:") + print(" TTS Provider: OpenAI Compatible") + print(" Model: dia") + print(" API Key: not-needed") + print(f" Endpoint URL: {SERVER_URL}/v1/audio/speech") \ No newline at end of file diff --git a/fastapi_server.py b/fastapi_server.py new file mode 100644 index 00000000..b798c221 --- /dev/null +++ b/fastapi_server.py @@ -0,0 +1,1359 @@ +""" +FastAPI Server for Dia TTS Model - SillyTavern Compatible +""" + +import io +import os +import re +import tempfile +import time +import json +import uuid +import threading +import asyncio +import multiprocessing as mp +from datetime import datetime, timedelta +from typing import Optional, Dict, Any, List +from enum import Enum +from queue import Queue +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import soundfile as sf +import torch +import uvicorn +from fastapi import FastAPI, HTTPException, Response, UploadFile, File, Form, Query, BackgroundTasks +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, FileResponse +from pydantic import BaseModel, Field +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import Depends, Header + +from dia.model import Dia + + +# Request/Response Models +class TTSRequest(BaseModel): + text: str = Field(..., max_length=4096, description="Text to convert to speech") + voice_id: str = Field(default="alloy", description="Voice identifier") + response_format: str = Field(default="wav", description="Audio 
format (wav, mp3)") + speed: float = Field(default=1.0, ge=0.25, le=4.0, description="Speech speed") + role: Optional[str] = Field(default=None, description="Role of the speaker (user, assistant, system)") + + # Dia model parameters + temperature: Optional[float] = Field(default=None, ge=0.1, le=2.0, description="Sampling temperature") + cfg_scale: Optional[float] = Field(default=None, ge=1.0, le=10.0, description="Classifier-free guidance scale") + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Top-p sampling") + max_tokens: Optional[int] = Field(default=None, ge=100, le=10000, description="Maximum tokens to generate") + use_torch_compile: Optional[bool] = Field(default=None, description="Enable torch.compile optimization") + + +class VoiceInfo(BaseModel): + name: str + voice_id: str + preview_url: Optional[str] = None + + +class TTSGenerateRequest(BaseModel): + """Legacy format for backwards compatibility""" + text: str = Field(..., max_length=4096) + voice_id: str = Field(default="alloy") + + +class ModelInfo(BaseModel): + id: str + object: str = "model" + created: int + owned_by: str + + +class ModelsResponse(BaseModel): + object: str = "list" + data: list[ModelInfo] + + +class VoiceMapping(BaseModel): + style: str + primary_speaker: str + audio_prompt: Optional[str] = None + audio_prompt_transcript: Optional[str] = None + + +class VoiceMappingUpdate(BaseModel): + voice_id: str + style: Optional[str] = None + primary_speaker: Optional[str] = None + audio_prompt: Optional[str] = None + audio_prompt_transcript: Optional[str] = None + + +class ServerConfig(BaseModel): + debug_mode: bool = False + save_outputs: bool = False + show_prompts: bool = False + output_retention_hours: int = 24 + + +class JobStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class TTSJob(BaseModel): + id: str + status: JobStatus + created_at: datetime + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + text: str + processed_text: Optional[str] = None + voice_id: str + speed: float + role: Optional[str] = None + temperature: Optional[float] = None + cfg_scale: Optional[float] = None + top_p: Optional[float] = None + max_tokens: Optional[int] = None + use_torch_compile: Optional[bool] = None + audio_prompt_used: bool = False + generation_time: Optional[float] = None + file_path: Optional[str] = None + file_size: Optional[int] = None + error_message: Optional[str] = None + worker_id: Optional[str] = None + + +class GenerationLog(BaseModel): + id: str + timestamp: datetime + text: str + processed_text: str + voice: str + audio_prompt_used: bool + generation_time: float + file_path: Optional[str] = None + file_size: Optional[int] = None + + +class QueueStats(BaseModel): + pending_jobs: int + processing_jobs: int + completed_jobs: int + failed_jobs: int + total_workers: int + active_workers: int + + +# Initialize FastAPI app +app = FastAPI( + title="Dia TTS Server", + description="FastAPI server for Dia text-to-speech model, compatible with SillyTavern", + version="1.0.0" +) + +# Add CORS middleware for web compatibility +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global model instance +model: Optional[Dia] = None + +# Security (optional, accepts any bearer token) +security = HTTPBearer(auto_error=False) + +# Voice mapping (Dia uses speaker tags [S1]/[S2], we'll map common 
voice names) +VOICE_MAPPING: Dict[str, Dict[str, Any]] = { + "alloy": {"style": "neutral", "primary_speaker": "S1", "audio_prompt": None, "audio_prompt_transcript": None}, + "echo": {"style": "calm", "primary_speaker": "S1", "audio_prompt": None, "audio_prompt_transcript": None}, + "fable": {"style": "expressive", "primary_speaker": "S2", "audio_prompt": None, "audio_prompt_transcript": None}, + "nova": {"style": "friendly", "primary_speaker": "S1", "audio_prompt": None, "audio_prompt_transcript": None}, + "onyx": {"style": "deep", "primary_speaker": "S2", "audio_prompt": None, "audio_prompt_transcript": None}, + "shimmer": {"style": "bright", "primary_speaker": "S1", "audio_prompt": None, "audio_prompt_transcript": None}, +} + +# Store uploaded audio prompts (now stores file paths) +AUDIO_PROMPTS: Dict[str, str] = {} +AUDIO_PROMPT_DIR = "audio_prompts" + +# Server configuration +SERVER_CONFIG = ServerConfig() + +# Generation logs +GENERATION_LOGS: Dict[str, GenerationLog] = {} + +# Job queue and management +JOB_QUEUE: Dict[str, TTSJob] = {} +JOB_RESULTS: Dict[str, bytes] = {} # Store audio results in memory + +# Output directory for saved files +OUTPUT_DIR = "audio_outputs" + +# Worker management +WORKER_POOL: Optional[ThreadPoolExecutor] = None +MAX_WORKERS = int(os.getenv("DIA_MAX_WORKERS", min(4, mp.cpu_count()))) # Configurable via env var +ACTIVE_WORKERS: Dict[str, bool] = {} + + +def load_model(): + """Load the Dia model on startup""" + global model + + if model is not None: + return + + print("Loading Dia model...") + try: + # Determine device + if torch.cuda.is_available(): + device = torch.device("cuda") + # Use bfloat16 if available (faster based on benchmarks), fall back to float16 + if torch.cuda.is_bf16_supported(): + compute_dtype = "bfloat16" + print("Using BFloat16 precision for better performance") + else: + compute_dtype = "float16" + print("BFloat16 not supported, using Float16") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = torch.device("mps") + compute_dtype = "float16" + else: + device = torch.device("cpu") + compute_dtype = "float32" + + print(f"Using device: {device}, compute_dtype: {compute_dtype}") + + model = Dia.from_pretrained( + "nari-labs/Dia-1.6B", + compute_dtype=compute_dtype, + device=device + ) + print("Dia model loaded successfully!") + + except Exception as e: + print(f"Error loading Dia model: {e}") + raise RuntimeError(f"Failed to load Dia model: {e}") + + +def preprocess_text(text: str, voice_id: str, role: Optional[str] = None) -> str: + """Preprocess text for Dia model requirements""" + # Remove asterisks (common in chat applications) + text = re.sub(r'\*+', '', text) + + # Remove extra whitespace + text = re.sub(r'\s+', ' ', text).strip() + + # Ensure text has proper speaker tags for Dia + if not ('[S1]' in text or '[S2]' in text): + # Determine speaker based on role if provided + if role: + # Map roles to speakers: user -> S1, assistant/system -> S2 + if role.lower() == "user": + primary_speaker = "S1" + elif role.lower() in ["assistant", "system"]: + primary_speaker = "S2" + else: + # Unknown role, fall back to voice mapping + voice_config = VOICE_MAPPING.get(voice_id, VOICE_MAPPING["alloy"]) + primary_speaker = voice_config["primary_speaker"] + else: + # No role provided, use voice mapping to determine primary speaker + voice_config = VOICE_MAPPING.get(voice_id, VOICE_MAPPING["alloy"]) + primary_speaker = voice_config["primary_speaker"] + + # Wrap text with proper closing tags: [S1] text [S1] + text = 
f"[{primary_speaker}] {text} [{primary_speaker}]" + else: + # Ensure existing tags are properly closed + # Simple approach: if we find an opening tag without a closing tag, add it + if '[S1]' in text and not text.endswith('[S1]'): + if not text.endswith('[S2]'): + text += ' [S1]' + elif '[S2]' in text and not text.endswith('[S2]'): + if not text.endswith('[S1]'): + text += ' [S2]' + + return text + + +def can_use_torch_compile() -> bool: + """Check if torch.compile can be used safely""" + # Check if disabled via environment variable + if os.getenv("DIA_DISABLE_TORCH_COMPILE", "").lower() in ("1", "true", "yes"): + return False + + try: + # Only try torch.compile on CUDA with proper compiler setup + if not torch.cuda.is_available(): + return False + + # Check if we're on Windows and don't have proper compiler + import platform + if platform.system() == "Windows": + # On Windows, torch.compile often fails without proper MSVC setup + return False + + # Try a simple compilation test + @torch.compile + def test_fn(x): + return x + 1 + + test_tensor = torch.tensor([1.0]) + test_fn(test_tensor) + return True + except Exception: + return False + + +def ensure_output_dir(): + """Ensure output directory exists""" + if SERVER_CONFIG.save_outputs and not os.path.exists(OUTPUT_DIR): + os.makedirs(OUTPUT_DIR, exist_ok=True) + + +def ensure_audio_prompt_dir(): + """Ensure audio prompt directory exists""" + if not os.path.exists(AUDIO_PROMPT_DIR): + os.makedirs(AUDIO_PROMPT_DIR, exist_ok=True) + + +def cleanup_old_files(): + """Remove files older than retention period""" + if not SERVER_CONFIG.save_outputs or not os.path.exists(OUTPUT_DIR): + return + + cutoff_time = datetime.now() - timedelta(hours=SERVER_CONFIG.output_retention_hours) + + # Clean up files + for filename in os.listdir(OUTPUT_DIR): + file_path = os.path.join(OUTPUT_DIR, filename) + if os.path.isfile(file_path): + file_time = datetime.fromtimestamp(os.path.getctime(file_path)) + if file_time < cutoff_time: + try: + os.remove(file_path) + print(f"Cleaned up old file: {filename}") + except Exception as e: + print(f"Failed to clean up {filename}: {e}") + + # Clean up logs for deleted files + logs_to_remove = [] + for log_id, log in GENERATION_LOGS.items(): + if log.file_path and not os.path.exists(log.file_path): + logs_to_remove.append(log_id) + + for log_id in logs_to_remove: + del GENERATION_LOGS[log_id] + + +def process_tts_job(job_id: str) -> None: + """Process a TTS job in worker thread""" + worker_id = f"worker_{threading.current_thread().ident}" + ACTIVE_WORKERS[worker_id] = True + + try: + job = JOB_QUEUE.get(job_id) + if not job: + return + + # Update job status + job.status = JobStatus.PROCESSING + job.started_at = datetime.now() + job.worker_id = worker_id + + if SERVER_CONFIG.debug_mode: + print(f"Worker {worker_id} processing job {job_id}") + + # Generate audio + audio_data, log_id = generate_audio_from_text( + job.text, + job.voice_id, + job.speed, + job.temperature, + job.cfg_scale, + job.top_p, + job.max_tokens, + job.use_torch_compile, + job.role + ) + + # Convert to bytes for storage + with io.BytesIO() as buffer: + if audio_data.dtype != np.float32: + audio_data = audio_data.astype(np.float32) + sf.write(buffer, audio_data, 44100, format='WAV', subtype='PCM_16') + buffer.seek(0) + audio_bytes = buffer.getvalue() + + # Store result + JOB_RESULTS[job_id] = audio_bytes + + # Update job + job.status = JobStatus.COMPLETED + job.completed_at = datetime.now() + job.generation_time = (job.completed_at - 
job.started_at).total_seconds() + + if SERVER_CONFIG.debug_mode: + print(f"Job {job_id} completed in {job.generation_time:.2f}s") + + except Exception as e: + job = JOB_QUEUE.get(job_id) + if job: + job.status = JobStatus.FAILED + job.completed_at = datetime.now() + job.error_message = str(e) + + if SERVER_CONFIG.debug_mode: + print(f"Job {job_id} failed: {e}") + + finally: + ACTIVE_WORKERS[worker_id] = False + + +def generate_audio_from_text( + text: str, + voice_id: str = "alloy", + speed: float = 1.0, + temperature: Optional[float] = None, + cfg_scale: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + use_torch_compile: Optional[bool] = None, + role: Optional[str] = None +) -> tuple[np.ndarray, str]: + """Generate audio using Dia model and return (audio, log_id)""" + if model is None: + raise RuntimeError("Model not loaded") + + start_time = time.time() + log_id = str(uuid.uuid4()) + + # Preprocess text + processed_text = preprocess_text(text, voice_id, role) + + # Get voice configuration + voice_config = VOICE_MAPPING.get(voice_id, VOICE_MAPPING["alloy"]) + + # Add audio prompt transcript if available (for better voice cloning) + if voice_config.get("audio_prompt_transcript") and voice_config.get("audio_prompt"): + # Prepend the transcript for better voice cloning results + processed_text = voice_config["audio_prompt_transcript"] + " " + processed_text + + # Get audio prompt if available + audio_prompt = None + audio_prompt_used = False + if voice_config.get("audio_prompt"): + audio_prompt_path = AUDIO_PROMPTS.get(voice_config["audio_prompt"]) + if audio_prompt_path and os.path.exists(audio_prompt_path): + # Model expects file path for audio prompt + audio_prompt = audio_prompt_path + audio_prompt_used = True + else: + if audio_prompt_path: + print(f"Warning: Audio prompt file not found: {audio_prompt_path}") + audio_prompt_used = False + + # Set default parameters + generation_params = { + "temperature": temperature or 1.2, + "cfg_scale": cfg_scale or 3.0, + "top_p": top_p or 0.95, + "max_tokens": max_tokens, + "use_torch_compile": use_torch_compile if use_torch_compile is not None else can_use_torch_compile(), + "verbose": SERVER_CONFIG.debug_mode + } + + # Debug logging + if SERVER_CONFIG.debug_mode or SERVER_CONFIG.show_prompts: + print(f"\n=== Generation Request ===") + print(f"ID: {log_id}") + print(f"Original text: {text}") + print(f"Processed text: {processed_text}") + print(f"Voice ID: {voice_id}") + if role: + print(f"Role: {role}") + print(f"Audio prompt used: {audio_prompt_used}") + if audio_prompt_used and voice_config.get("audio_prompt_transcript"): + print(f"Audio prompt transcript included: Yes") + print(f"Speed: {speed}") + print(f"Parameters: {generation_params}") + + # Generate audio + try: + # Generate with proper audio prompt handling + audio_output = model.generate( + processed_text, + audio_prompt=audio_prompt, # Dia will handle the batching internally + **generation_params + ) + + # Apply speed adjustment if needed + if speed != 1.0 and audio_output is not None: + # Simple speed adjustment by resampling + original_len = len(audio_output) + target_len = int(original_len / speed) + if target_len > 0: + x_original = np.arange(original_len) + x_resampled = np.linspace(0, original_len - 1, target_len) + audio_output = np.interp(x_resampled, x_original, audio_output) + + generation_time = time.time() - start_time + + # Save output file if enabled + file_path = None + file_size = None + if SERVER_CONFIG.save_outputs and 
audio_output is not None: + ensure_output_dir() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}_{log_id[:8]}_{voice_id}.wav" + file_path = os.path.join(OUTPUT_DIR, filename) + + # Save audio file + sf.write(file_path, audio_output, 44100, format='WAV', subtype='PCM_16') + file_size = os.path.getsize(file_path) + + # Create log entry + log_entry = GenerationLog( + id=log_id, + timestamp=datetime.now(), + text=text, + processed_text=processed_text, + voice=voice_id, + audio_prompt_used=audio_prompt_used, + generation_time=generation_time, + file_path=file_path, + file_size=file_size + ) + GENERATION_LOGS[log_id] = log_entry + + if SERVER_CONFIG.debug_mode: + print(f"Generation completed in {generation_time:.2f}s") + if file_path: + print(f"Saved to: {file_path}") + + return audio_output, log_id + + except Exception as e: + # If torch.compile fails, try again without it + if "Compiler:" in str(e) and "not found" in str(e): + print("Torch compile failed, retrying without compilation...") + try: + # Retry with compilation disabled + retry_params = generation_params.copy() + retry_params["use_torch_compile"] = False + + audio_output = model.generate( + processed_text, + audio_prompt=audio_prompt, + **retry_params + ) + + # Apply speed adjustment if needed + if speed != 1.0 and audio_output is not None: + original_len = len(audio_output) + target_len = int(original_len / speed) + if target_len > 0: + x_original = np.arange(original_len) + x_resampled = np.linspace(0, original_len - 1, target_len) + audio_output = np.interp(x_resampled, x_original, audio_output) + + generation_time = time.time() - start_time + + # Save and log as above + file_path = None + file_size = None + if SERVER_CONFIG.save_outputs and audio_output is not None: + ensure_output_dir() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}_{log_id[:8]}_{voice_id}.wav" + file_path = os.path.join(OUTPUT_DIR, filename) + sf.write(file_path, audio_output, 44100, format='WAV', subtype='PCM_16') + file_size = os.path.getsize(file_path) + + log_entry = GenerationLog( + id=log_id, + timestamp=datetime.now(), + text=text, + processed_text=processed_text, + voice=voice_id, + audio_prompt_used=audio_prompt_used, + generation_time=generation_time, + file_path=file_path, + file_size=file_size + ) + GENERATION_LOGS[log_id] = log_entry + + return audio_output, log_id + except Exception as retry_e: + print(f"Retry without compilation also failed: {retry_e}") + raise HTTPException(status_code=500, detail=f"Audio generation failed: {retry_e}") + + print(f"Error generating audio: {e}") + print(f"Text: {processed_text}") + print(f"Voice: {voice_id}") + print(f"Audio prompt available: {audio_prompt is not None}") + raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") + + +def initialize_worker_pool(): + """Initialize the worker thread pool""" + global WORKER_POOL + WORKER_POOL = ThreadPoolExecutor(max_workers=MAX_WORKERS, thread_name_prefix="TTS-Worker") + print(f"Initialized worker pool with {MAX_WORKERS} workers") + + +@app.on_event("startup") +async def startup_event(): + """Load model on startup""" + load_model() + initialize_worker_pool() + + # Ensure audio prompt directory exists + ensure_audio_prompt_dir() + + # Load existing audio prompts from disk + if os.path.exists(AUDIO_PROMPT_DIR): + for filename in os.listdir(AUDIO_PROMPT_DIR): + if filename.endswith('.wav'): + prompt_id = filename[:-4] # Remove .wav extension + file_path = 
os.path.join(AUDIO_PROMPT_DIR, filename) + AUDIO_PROMPTS[prompt_id] = file_path + print(f"Loaded audio prompt: {prompt_id}") + + # Start cleanup task + def cleanup_task(): + while True: + cleanup_old_files() + # Clean up old job results (keep for 1 hour) + cleanup_old_jobs() + time.sleep(3600) # Run every hour + + cleanup_thread = threading.Thread(target=cleanup_task, daemon=True) + cleanup_thread.start() + + +def cleanup_old_jobs(): + """Clean up completed jobs and results older than 1 hour""" + cutoff_time = datetime.now() - timedelta(hours=1) + + jobs_to_remove = [] + for job_id, job in JOB_QUEUE.items(): + if job.completed_at and job.completed_at < cutoff_time: + jobs_to_remove.append(job_id) + + for job_id in jobs_to_remove: + JOB_QUEUE.pop(job_id, None) + JOB_RESULTS.pop(job_id, None) + if SERVER_CONFIG.debug_mode: + print(f"Cleaned up old job: {job_id}") + + +@app.on_event("shutdown") +async def shutdown_event(): + """Cleanup on shutdown""" + global WORKER_POOL + if WORKER_POOL: + WORKER_POOL.shutdown(wait=True) + print("Worker pool shut down") + + +@app.get("/") +async def root(): + """Health check endpoint""" + return {"message": "Dia TTS Server is running", "status": "healthy"} + + +@app.get("/models") +async def list_models(): + """List available models""" + return { + "models": [ + { + "id": "dia", + "name": "Dia TTS", + "description": "1.6B parameter text-to-speech model for dialogue generation" + } + ] + } + + +@app.get("/voices") +async def list_voices(): + """List available voices""" + voices = [] + for voice_id, config in VOICE_MAPPING.items(): + voices.append({ + "id": voice_id, + "name": voice_id, + "style": config["style"], + "primary_speaker": config["primary_speaker"], + "has_audio_prompt": config.get("audio_prompt") is not None, + "preview_url": f"/preview/{voice_id}" + }) + return {"voices": voices} + + +@app.post("/generate") +async def generate_speech( + request: TTSRequest, + async_mode: bool = Query(default=False, description="Return job ID for async processing") +): + """Main TTS generation endpoint - supports both sync and async modes""" + if not request.text.strip(): + raise HTTPException(status_code=400, detail="Text cannot be empty") + + if async_mode: + # Async mode: return job ID immediately + job_id = str(uuid.uuid4()) + job = TTSJob( + id=job_id, + status=JobStatus.PENDING, + created_at=datetime.now(), + text=request.text, + voice_id=request.voice_id, + speed=request.speed, + role=request.role, + temperature=request.temperature, + cfg_scale=request.cfg_scale, + top_p=request.top_p, + max_tokens=request.max_tokens, + use_torch_compile=request.use_torch_compile + ) + + JOB_QUEUE[job_id] = job + + # Submit to worker pool + if WORKER_POOL: + WORKER_POOL.submit(process_tts_job, job_id) + else: + raise HTTPException(status_code=503, detail="Worker pool not available") + + return {"job_id": job_id, "status": "pending", "message": "Job queued for processing"} + + else: + # Sync mode: traditional immediate response + try: + # Generate audio + audio_data, log_id = generate_audio_from_text( + request.text, + request.voice_id, + request.speed, + temperature=request.temperature, + cfg_scale=request.cfg_scale, + top_p=request.top_p, + max_tokens=request.max_tokens, + use_torch_compile=request.use_torch_compile, + role=request.role + ) + + if audio_data is None: + raise HTTPException(status_code=500, detail="Failed to generate audio") + + # Convert to bytes and create streaming response + def generate_audio_stream(): + with io.BytesIO() as buffer: + # Ensure audio is 
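in the right format
+                    # (A small note, not in the original: sf.write with the
+                    # PCM_16 subtype below converts the float32 array to
+                    # 16-bit PCM samples on the way into the buffer.)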
+                    if audio_data.dtype != np.float32:
+                        audio_data_processed = audio_data.astype(np.float32)
+                    else:
+                        audio_data_processed = audio_data
+
+                    # Write audio file
+                    sf.write(buffer, audio_data_processed, 44100, format='WAV', subtype='PCM_16')
+                    buffer.seek(0)
+
+                    # Stream in chunks
+                    chunk_size = 8192
+                    while True:
+                        chunk = buffer.read(chunk_size)
+                        if not chunk:
+                            break
+                        yield chunk
+
+            # MP3 output is not implemented yet, so every response is WAV
+            # regardless of the requested response_format.
+            media_type = "audio/wav"
+            filename = "speech.wav"
+
+            # Transfer-Encoding is applied automatically by the ASGI server
+            # for streaming responses and should not be set manually.
+            response_headers = {
+                "Content-Disposition": f"attachment; filename={filename}"
+            }
+
+            # Add log ID header if debug mode
+            if SERVER_CONFIG.debug_mode:
+                response_headers["X-Generation-ID"] = log_id
+
+            return StreamingResponse(
+                generate_audio_stream(),
+                media_type=media_type,
+                headers=response_headers
+            )
+
+        except Exception as e:
+            print(f"Error in generate_speech: {e}")
+            raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/tts/generate")
+async def generate_speech_alt(request: TTSGenerateRequest):
+    """Legacy TTS endpoint for backwards compatibility"""
+    tts_request = TTSRequest(
+        text=request.text,
+        voice_id=request.voice_id
+    )
+    # Pass async_mode explicitly: when this coroutine is called directly
+    # (bypassing FastAPI's dependency injection), the Query(...) default
+    # would be a truthy Query object and the request would be queued
+    # instead of returning audio.
+    return await generate_speech(tts_request, async_mode=False)
+
+
+@app.get("/api/tts/speakers")
+async def list_speakers():
+    """List speakers (alternative format)"""
+    return list(VOICE_MAPPING.keys())
+
+
+@app.get("/preview/{voice_id}")
+async def get_voice_preview(voice_id: str):
+    """Generate a preview sample for a voice"""
+    if voice_id not in VOICE_MAPPING:
+        raise HTTPException(status_code=404, detail="Voice not found")
+
+    preview_text = f"[S1] Hello, this is a preview of the {voice_id} voice. [S2] How does this sound to you?"
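+    # The preview line deliberately uses both speaker tags so callers can hear
+    # the S1/S2 contrast. Example request (assuming the default port 7860):
+    #   curl http://localhost:7860/preview/alloy --output preview_alloy.wav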
+ + try: + audio_data, log_id = generate_audio_from_text(preview_text, voice_id) + + with io.BytesIO() as buffer: + sf.write(buffer, audio_data, 44100, format='WAV', subtype='PCM_16') + buffer.seek(0) + audio_bytes = buffer.getvalue() + + return Response( + content=audio_bytes, + media_type="audio/wav", + headers={"Content-Disposition": f"attachment; filename=preview_{voice_id}.wav"} + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Preview generation failed: {e}") + + +@app.get("/health") +async def health_check(): + """Health check endpoint""" + return { + "status": "healthy", + "model_loaded": model is not None, + "timestamp": time.time() + } + + +# Voice Management Endpoints + +@app.get("/voice_mappings") +async def get_voice_mappings(): + """Get current voice mappings""" + return VOICE_MAPPING + + +@app.put("/voice_mappings/{voice_id}") +async def update_voice_mapping(voice_id: str, update: VoiceMappingUpdate): + """Update voice mapping configuration""" + if voice_id not in VOICE_MAPPING: + raise HTTPException(status_code=404, detail=f"Voice '{voice_id}' not found") + + # Update voice configuration + if update.style is not None: + VOICE_MAPPING[voice_id]["style"] = update.style + if update.primary_speaker is not None: + VOICE_MAPPING[voice_id]["primary_speaker"] = update.primary_speaker + if update.audio_prompt is not None: + VOICE_MAPPING[voice_id]["audio_prompt"] = update.audio_prompt + if update.audio_prompt_transcript is not None: + VOICE_MAPPING[voice_id]["audio_prompt_transcript"] = update.audio_prompt_transcript + + return {"message": f"Voice '{voice_id}' updated successfully", "voice_config": VOICE_MAPPING[voice_id]} + + +@app.post("/voice_mappings") +async def create_voice_mapping(mapping: VoiceMappingUpdate): + """Create new voice mapping""" + if not mapping.voice_id: + raise HTTPException(status_code=400, detail="voice_id is required") + + VOICE_MAPPING[mapping.voice_id] = { + "style": mapping.style or "neutral", + "primary_speaker": mapping.primary_speaker or "S1", + "audio_prompt": mapping.audio_prompt, + "audio_prompt_transcript": mapping.audio_prompt_transcript + } + + return {"message": f"Voice '{mapping.voice_id}' created successfully", "voice_config": VOICE_MAPPING[mapping.voice_id]} + + +@app.delete("/voice_mappings/{voice_id}") +async def delete_voice_mapping(voice_id: str): + """Delete voice mapping (only custom voices, not defaults)""" + default_voices = {"alloy", "echo", "fable", "nova", "onyx", "shimmer"} + + if voice_id in default_voices: + raise HTTPException(status_code=400, detail=f"Cannot delete default voice '{voice_id}'") + + if voice_id not in VOICE_MAPPING: + raise HTTPException(status_code=404, detail=f"Voice '{voice_id}' not found") + + # Remove associated audio prompt + if VOICE_MAPPING[voice_id].get("audio_prompt"): + audio_prompt_id = VOICE_MAPPING[voice_id]["audio_prompt"] + AUDIO_PROMPTS.pop(audio_prompt_id, None) + + del VOICE_MAPPING[voice_id] + return {"message": f"Voice '{voice_id}' deleted successfully"} + + +@app.post("/audio_prompts/upload") +async def upload_audio_prompt( + prompt_id: str = Form(...), + audio_file: UploadFile = File(...) 
+): + """Upload audio file to use as voice prompt""" + # Basic validation + if not prompt_id or not prompt_id.strip(): + raise HTTPException(status_code=400, detail="prompt_id cannot be empty") + + if not audio_file or not audio_file.filename: + raise HTTPException(status_code=400, detail="No audio file provided") + + # Content type validation (more flexible) + valid_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.aac'} + file_ext = os.path.splitext(audio_file.filename.lower())[1] + + if file_ext not in valid_extensions: + raise HTTPException( + status_code=400, + detail=f"Unsupported file format. Supported: {', '.join(valid_extensions)}" + ) + + temp_file_path = None + try: + # Read audio file data + audio_data = await audio_file.read() + + if len(audio_data) == 0: + raise HTTPException(status_code=400, detail="Audio file is empty") + + # Create temporary file with proper extension + temp_file_fd, temp_file_path = tempfile.mkstemp(suffix=file_ext) + + try: + # Write data to temp file + with os.fdopen(temp_file_fd, 'wb') as temp_file: + temp_file.write(audio_data) + temp_file.flush() + os.fsync(temp_file.fileno()) # Ensure data is written to disk + + # Load audio with error handling + try: + audio_array, sample_rate = sf.read(temp_file_path) + except Exception as sf_error: + raise HTTPException( + status_code=400, + detail=f"Cannot read audio file. Please check the file format: {str(sf_error)}" + ) + + # Validate audio data + if len(audio_array) == 0: + raise HTTPException(status_code=400, detail="Audio file contains no audio data") + + # Convert to mono if stereo + if len(audio_array.shape) > 1: + audio_array = np.mean(audio_array, axis=1) + + # Validate duration (3-30 seconds recommended) + duration = len(audio_array) / sample_rate + if duration < 0.5: + raise HTTPException(status_code=400, detail="Audio file too short (minimum 0.5 seconds)") + if duration > 60: + raise HTTPException(status_code=400, detail="Audio file too long (maximum 60 seconds)") + + # Resample to 44.1kHz if needed + if sample_rate != 44100: + resample_ratio = 44100 / sample_rate + new_length = int(len(audio_array) * resample_ratio) + if new_length > 0: + audio_array = np.interp( + np.linspace(0, len(audio_array) - 1, new_length), + np.arange(len(audio_array)), + audio_array + ) + + # Normalize audio to prevent clipping + if np.max(np.abs(audio_array)) > 0: + audio_array = audio_array / np.max(np.abs(audio_array)) * 0.95 + + # Save audio file to disk for the model to load + ensure_audio_prompt_dir() + audio_prompt_path = os.path.join(AUDIO_PROMPT_DIR, f"{prompt_id}.wav") + + # Save as WAV file at 44.1kHz + sf.write(audio_prompt_path, audio_array, 44100, format='WAV', subtype='PCM_16') + + # Store the file path + AUDIO_PROMPTS[prompt_id] = audio_prompt_path + + return { + "message": f"Audio prompt '{prompt_id}' uploaded successfully", + "duration": len(audio_array) / 44100, + "sample_rate": 44100, + "original_sample_rate": sample_rate, + "channels": "mono" + } + + except HTTPException: + raise + except Exception as process_error: + raise HTTPException( + status_code=500, + detail=f"Error processing audio file: {str(process_error)}" + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Unexpected error during file upload: {str(e)}" + ) + finally: + # Always clean up temp file + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + except Exception as cleanup_error: + print(f"Warning: Failed to clean up temp file 
{temp_file_path}: {cleanup_error}") + + +@app.get("/audio_prompts") +async def list_audio_prompts(): + """List available audio prompts""" + prompts = {} + for prompt_id, file_path in AUDIO_PROMPTS.items(): + if os.path.exists(file_path): + try: + # Read file info + audio_data, sr = sf.read(file_path) + prompts[prompt_id] = { + "file_path": file_path, + "duration": len(audio_data) / sr, + "sample_rate": sr, + "exists": True + } + except Exception as e: + prompts[prompt_id] = { + "file_path": file_path, + "exists": True, + "error": str(e) + } + else: + prompts[prompt_id] = { + "file_path": file_path, + "exists": False + } + return prompts + + +@app.delete("/audio_prompts/{prompt_id}") +async def delete_audio_prompt(prompt_id: str): + """Delete audio prompt""" + if prompt_id not in AUDIO_PROMPTS: + raise HTTPException(status_code=404, detail=f"Audio prompt '{prompt_id}' not found") + + # Check if any voices are using this prompt + using_voices = [voice_id for voice_id, config in VOICE_MAPPING.items() + if config.get("audio_prompt") == prompt_id] + + if using_voices: + return { + "warning": f"Audio prompt '{prompt_id}' is used by voices: {using_voices}", + "message": "Remove from voice mappings first before deleting" + } + + # Delete the file + file_path = AUDIO_PROMPTS[prompt_id] + if os.path.exists(file_path): + try: + os.unlink(file_path) + except Exception as e: + print(f"Warning: Failed to delete audio file {file_path}: {e}") + + del AUDIO_PROMPTS[prompt_id] + return {"message": f"Audio prompt '{prompt_id}' deleted successfully"} + + +# Debug and Configuration Endpoints + +@app.get("/config") +async def get_server_config(): + """Get current server configuration""" + return SERVER_CONFIG + + +@app.put("/config") +async def update_server_config(config: ServerConfig): + """Update server configuration""" + global SERVER_CONFIG + SERVER_CONFIG = config + return {"message": "Configuration updated successfully", "config": SERVER_CONFIG} + + +@app.get("/logs") +async def get_generation_logs( + limit: int = Query(default=50, le=500), + voice: Optional[str] = Query(default=None) +): + """Get generation logs""" + logs = list(GENERATION_LOGS.values()) + + # Filter by voice if specified + if voice: + logs = [log for log in logs if log.voice == voice] + + # Sort by timestamp (newest first) + logs.sort(key=lambda x: x.timestamp, reverse=True) + + # Limit results + logs = logs[:limit] + + return { + "logs": logs, + "total": len(GENERATION_LOGS), + "filtered": len(logs) + } + + +@app.get("/logs/{log_id}") +async def get_generation_log(log_id: str): + """Get specific generation log""" + if log_id not in GENERATION_LOGS: + raise HTTPException(status_code=404, detail=f"Log '{log_id}' not found") + + return GENERATION_LOGS[log_id] + + +@app.get("/logs/{log_id}/download") +async def download_generation_output(log_id: str): + """Download the audio file for a specific generation""" + if log_id not in GENERATION_LOGS: + raise HTTPException(status_code=404, detail=f"Log '{log_id}' not found") + + log = GENERATION_LOGS[log_id] + if not log.file_path or not os.path.exists(log.file_path): + raise HTTPException(status_code=404, detail="Audio file not found or has been cleaned up") + + filename = os.path.basename(log.file_path) + return FileResponse( + log.file_path, + media_type="audio/wav", + filename=filename, + headers={"Content-Disposition": f"attachment; filename={filename}"} + ) + + +@app.delete("/logs") +async def clear_generation_logs(): + """Clear all generation logs""" + global GENERATION_LOGS + 
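    # Rebinding (rather than mutating) the module-level dict requires the
+    # global declaration above. Saved audio files are not removed here; they
+    # age out later via cleanup_old_files() once past the retention window.
+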
GENERATION_LOGS = {} + return {"message": "All generation logs cleared"} + + +@app.delete("/logs/{log_id}") +async def delete_generation_log(log_id: str): + """Delete specific generation log and its file""" + if log_id not in GENERATION_LOGS: + raise HTTPException(status_code=404, detail=f"Log '{log_id}' not found") + + log = GENERATION_LOGS[log_id] + + # Delete file if it exists + if log.file_path and os.path.exists(log.file_path): + try: + os.remove(log.file_path) + except Exception as e: + print(f"Failed to delete file {log.file_path}: {e}") + + # Delete log + del GENERATION_LOGS[log_id] + + return {"message": f"Log '{log_id}' and associated file deleted"} + + +# Job Management Endpoints + +@app.get("/jobs") +async def list_jobs( + status: Optional[JobStatus] = Query(default=None), + limit: int = Query(default=50, le=500) +): + """List jobs with optional status filter""" + jobs = list(JOB_QUEUE.values()) + + # Filter by status if specified + if status: + jobs = [job for job in jobs if job.status == status] + + # Sort by creation time (newest first) + jobs.sort(key=lambda x: x.created_at, reverse=True) + + # Limit results + jobs = jobs[:limit] + + return { + "jobs": jobs, + "total": len(JOB_QUEUE), + "filtered": len(jobs) + } + + +@app.get("/jobs/{job_id}") +async def get_job_status(job_id: str): + """Get status of specific job""" + if job_id not in JOB_QUEUE: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + return JOB_QUEUE[job_id] + + +@app.get("/jobs/{job_id}/result") +async def get_job_result(job_id: str): + """Download result of completed job""" + if job_id not in JOB_QUEUE: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + job = JOB_QUEUE[job_id] + + if job.status != JobStatus.COMPLETED: + if job.status == JobStatus.FAILED: + raise HTTPException(status_code=500, detail=f"Job failed: {job.error_message}") + else: + raise HTTPException(status_code=425, detail=f"Job not completed (status: {job.status})") + + if job_id not in JOB_RESULTS: + raise HTTPException(status_code=404, detail="Job result not found or expired") + + audio_bytes = JOB_RESULTS[job_id] + + return Response( + content=audio_bytes, + media_type="audio/wav", + headers={ + "Content-Disposition": f"attachment; filename=speech_{job_id[:8]}.wav", + "Content-Length": str(len(audio_bytes)) + } + ) + + +@app.delete("/jobs/{job_id}") +async def cancel_job(job_id: str): + """Cancel a pending job""" + if job_id not in JOB_QUEUE: + raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found") + + job = JOB_QUEUE[job_id] + + if job.status == JobStatus.PENDING: + job.status = JobStatus.CANCELLED + job.completed_at = datetime.now() + return {"message": f"Job '{job_id}' cancelled"} + else: + raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}") + + +@app.get("/queue/stats") +async def get_queue_stats(): + """Get queue statistics""" + stats = { + "pending_jobs": len([j for j in JOB_QUEUE.values() if j.status == JobStatus.PENDING]), + "processing_jobs": len([j for j in JOB_QUEUE.values() if j.status == JobStatus.PROCESSING]), + "completed_jobs": len([j for j in JOB_QUEUE.values() if j.status == JobStatus.COMPLETED]), + "failed_jobs": len([j for j in JOB_QUEUE.values() if j.status == JobStatus.FAILED]), + "total_workers": MAX_WORKERS, + "active_workers": len([w for w in ACTIVE_WORKERS.values() if w]) + } + + return QueueStats(**stats) + + +@app.delete("/jobs") +async def clear_completed_jobs(): + """Clear all completed and failed 
jobs""" + jobs_to_remove = [] + for job_id, job in JOB_QUEUE.items(): + if job.status in [JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]: + jobs_to_remove.append(job_id) + + for job_id in jobs_to_remove: + JOB_QUEUE.pop(job_id, None) + JOB_RESULTS.pop(job_id, None) + + return {"message": f"Cleared {len(jobs_to_remove)} completed jobs"} + + +@app.post("/cleanup") +async def manual_cleanup(): + """Manually trigger cleanup of old files and jobs""" + cleanup_old_files() + cleanup_old_jobs() + return {"message": "Cleanup completed"} + + + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Start Dia TTS FastAPI server") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=7860, help="Port to bind to") + parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development") + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + parser.add_argument("--save-outputs", action="store_true", help="Save audio outputs to files") + parser.add_argument("--show-prompts", action="store_true", help="Show prompts in console") + parser.add_argument("--retention-hours", type=int, default=24, help="File retention hours") + + args = parser.parse_args() + + # Update server config from command line args + if args.debug: + SERVER_CONFIG.debug_mode = True + if args.save_outputs: + SERVER_CONFIG.save_outputs = True + if args.show_prompts: + SERVER_CONFIG.show_prompts = True + SERVER_CONFIG.output_retention_hours = args.retention_hours + + print(f"Starting Dia TTS Server on {args.host}:{args.port}") + print("Make sure you have set the HF_TOKEN environment variable!") + print(f"SillyTavern endpoint: http://{args.host}:{args.port}/v1/audio/speech") + print(f"Configuration API: http://{args.host}:{args.port}/v1/config") + print(f"Generation logs: http://{args.host}:{args.port}/v1/logs") + print(f"Debug mode: {SERVER_CONFIG.debug_mode}") + print(f"Save outputs: {SERVER_CONFIG.save_outputs}") + print(f"Show prompts: {SERVER_CONFIG.show_prompts}") + print(f"Retention: {SERVER_CONFIG.output_retention_hours} hours") + + uvicorn.run( + "fastapi_server:app", + host=args.host, + port=args.port, + reload=args.reload + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index dd844dd2..80eed611 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ authors = [ ] dependencies = [ "descript-audio-codec>=1.0.0", + "fastapi>=0.104.0", "gradio>=5.25.2", "huggingface-hub>=0.30.2", "numpy>=2.2.4", @@ -20,6 +21,7 @@ dependencies = [ "torchaudio==2.6.0", "triton==3.2.0 ; sys_platform == 'linux'", "triton-windows==3.2.0.post18 ; sys_platform == 'win32'", + "uvicorn>=0.24.0", ] [build-system] diff --git a/start_server.py b/start_server.py new file mode 100644 index 00000000..b49aedba --- /dev/null +++ b/start_server.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Simple startup script for Dia FastAPI TTS Server +""" + +import argparse +import os +import sys +import subprocess + +def check_environment(): + """Check if the environment is properly set up""" + issues = [] + + # Check if HF_TOKEN is set + if not os.getenv("HF_TOKEN"): + issues.append("❌ HF_TOKEN environment variable not set") + issues.append(" Set it with: export HF_TOKEN='your_token_here'") + issues.append(" Get token from: https://huggingface.co/settings/tokens") + else: + print("✅ HF_TOKEN environment variable is set") + + # Check if required packages 
+def check_environment():
+    """Check if the environment is properly set up"""
+    issues = []
+
+    # Check if HF_TOKEN is set
+    if not os.getenv("HF_TOKEN"):
+        issues.append("❌ HF_TOKEN environment variable not set")
+        issues.append("   Set it with: export HF_TOKEN='your_token_here'")
+        issues.append("   Get token from: https://huggingface.co/settings/tokens")
+    else:
+        print("✅ HF_TOKEN environment variable is set")
+
+    # Check if required packages are available
+    try:
+        import fastapi
+        import uvicorn
+        import torch
+        print("✅ Required packages are available")
+    except ImportError as e:
+        issues.append(f"❌ Missing required package: {e}")
+        issues.append("   Install with: pip install -e .")
+
+    # Check if CUDA is available
+    try:
+        import torch
+        if torch.cuda.is_available():
+            print(f"✅ CUDA available with {torch.cuda.device_count()} GPU(s)")
+        else:
+            print("ℹ️ CUDA not available, will use CPU (slower)")
+    except Exception:
+        # Torch may be missing entirely; the import check above already
+        # reports that, so a failed CUDA probe is not an extra issue.
+        pass
+
+    return issues
+
+def main():
+    parser = argparse.ArgumentParser(description="Start Dia TTS FastAPI Server")
+
+    # Server configuration
+    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
+    parser.add_argument("--port", type=int, default=7860, help="Port to bind to (default: 7860)")
+    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
+    parser.add_argument("--check-only", action="store_true", help="Only check environment, don't start server")
+
+    # Debug and logging options
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode with verbose logging")
+    parser.add_argument("--save-outputs", action="store_true", help="Save all generated audio files")
+    parser.add_argument("--show-prompts", action="store_true", help="Show prompts and processing details in console")
+    parser.add_argument("--retention-hours", type=int, default=24, help="File retention period in hours (default: 24)")
+
+    # Performance options
+    parser.add_argument("--workers", type=int, help="Number of worker threads (default: auto-detect)")
+    parser.add_argument("--no-torch-compile", action="store_true", help="Disable torch.compile optimization")
+
+    # Quick preset options
+    parser.add_argument("--dev", action="store_true", help="Development mode (debug + save outputs + show prompts + reload)")
+    parser.add_argument("--production", action="store_true", help="Production mode (optimized settings)")
+
+    args = parser.parse_args()
+
+    # Handle preset modes
+    if args.dev:
+        args.debug = True
+        args.save_outputs = True
+        args.show_prompts = True
+        args.reload = True
+        print("🔧 Development mode enabled")
+
+    if args.production:
+        args.debug = False
+        args.save_outputs = False
+        args.show_prompts = False
+        args.reload = False
+        print("🏭 Production mode enabled")
+
+    print("🚀 Dia TTS Server Startup")
+    print("=" * 40)
+
+    # Show configuration
+    print("📋 Server Configuration:")
+    print(f"   Debug mode: {'✅' if args.debug else '❌'}")
+    print(f"   Save outputs: {'✅' if args.save_outputs else '❌'}")
+    print(f"   Show prompts: {'✅' if args.show_prompts else '❌'}")
+    print(f"   Auto-reload: {'✅' if args.reload else '❌'}")
+    print(f"   Retention: {args.retention_hours} hours")
+    if args.workers:
+        print(f"   Workers: {args.workers}")
+    print()
+
+    # Check environment
+    issues = check_environment()
+
+    if issues:
+        print("\n⚠️ Environment Issues:")
+        for issue in issues:
+            print(issue)
+
+        if args.check_only:
+            sys.exit(1)
+
+        print("\nDo you want to continue anyway? (y/N): ", end="")
+        if input().lower() != 'y':
+            sys.exit(1)
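+
+    # Example: scripts and CI can gate startup on a clean environment first:
+    #   python start_server.py --check-only && python start_server.py --production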
(y/N): ", end="") + if input().lower() != 'y': + sys.exit(1) + + if args.check_only: + print("\n✅ Environment check passed!") + return + + print(f"\n🌐 Starting server on {args.host}:{args.port}") + print("📋 SillyTavern Configuration:") + print(" TTS Provider: OpenAI Compatible") + print(" Model: dia") + print(" API Key: sk-anything") + print(f" Endpoint URL: http://{args.host}:{args.port}/v1/audio/speech") + print() + print("🔗 Server endpoints:") + print(f" Health Check: http://{args.host}:{args.port}/health") + print(f" API Docs: http://{args.host}:{args.port}/docs") + print(f" Voice List: http://{args.host}:{args.port}/v1/voices") + print(f" Queue Stats: http://{args.host}:{args.port}/v1/queue/stats") + if args.debug: + print(f" Config API: http://{args.host}:{args.port}/v1/config") + print(f" Generation Logs: http://{args.host}:{args.port}/v1/logs") + print() + print("Press Ctrl+C to stop the server") + print("=" * 40) + + # Build command + cmd = [ + sys.executable, "fastapi_server.py", + "--host", args.host, + "--port", str(args.port) + ] + + # Add flags + if args.reload: + cmd.append("--reload") + if args.debug: + cmd.append("--debug") + if args.save_outputs: + cmd.append("--save-outputs") + if args.show_prompts: + cmd.append("--show-prompts") + if args.retention_hours != 24: + cmd.extend(["--retention-hours", str(args.retention_hours)]) + + # Environment variables for advanced options + env = os.environ.copy() + if args.workers: + env["DIA_MAX_WORKERS"] = str(args.workers) + if args.no_torch_compile: + env["DIA_DISABLE_TORCH_COMPILE"] = "1" + + # Start the server + try: + if args.debug: + print(f"🔧 Command: {' '.join(cmd)}") + if args.workers: + print(f"🔧 Workers: {args.workers}") + if args.no_torch_compile: + print("🔧 Torch compile: disabled") + print() + + subprocess.run(cmd, env=env) + + except KeyboardInterrupt: + print("\n👋 Server stopped") + except Exception as e: + print(f"\n❌ Error starting server: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file