diff --git a/app.py b/app.py
index 81a42f35..2924aa83 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,6 @@
 import argparse
 import contextlib
+import gc
 import io
 import random
 import tempfile
@@ -48,7 +49,21 @@
 
     dtype = dtype_map.get(device.type, "float16")
     print(f"Using device: {device}, attempting to load model with {dtype}")
-    model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype=dtype, device=device)
+    # Step 1: Load model normally
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype=dtype, device=device)
+
+    # Step 2: Apply dynamic quantization
+    quantized_model = torch.quantization.quantize_dynamic(
+        model.model, {torch.nn.Linear, torch.nn.LSTM}, dtype=torch.qint8
+    )
+
+    # Step 3: Dereference the original
+    model.model = None
+    torch.cuda.empty_cache()
+
+    # Step 4: Replace with quantized
+    model.model = quantized_model
+
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise
@@ -66,6 +81,84 @@ def set_seed(seed: int):
     torch.backends.cudnn.benchmark = False
 
 
+def count_effective_length(text):
+    """Counts effective length treating [S1] and [S2] as single characters."""
+    return len(text.replace("[S1]", "¤").replace("[S2]", "¤"))
+
+
+def auto_adjust_chunk_size(text, user_chunk_size):
+    """Auto-adjusts chunk size based on input length when no explicit size is given."""
+    effective_chars = count_effective_length(text)
+    if user_chunk_size > 0:
+        # If user explicitly sets a chunk size, respect it
+        return int(user_chunk_size)
+    else:
+        # Auto-tune based on input size
+        if effective_chars <= 1024:
+            return 48
+        elif effective_chars <= 4096:
+            return 64
+        else:
+            return 96
+
+
+def split_by_words_respecting_special_tokens(text, max_effective_chars=64):
+    """Splits text into chunks close to max_effective_chars, preserving full words and [S1]/[S2] markers."""
+    words = text.split()
+    chunks = []
+    current_chunk = ""
+
+    for word in words:
+        tentative_chunk = (current_chunk + " " + word).strip() if current_chunk else word
+        if count_effective_length(tentative_chunk) > max_effective_chars:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+                current_chunk = word
+            else:
+                chunks.append(word)
+                current_chunk = ""
+        else:
+            current_chunk = tentative_chunk
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
+
+def batch_chunks(chunks, batch_size):
+    """Yield successive batches of chunks."""
+    for i in range(0, len(chunks), batch_size):
+        yield chunks[i : i + batch_size]
+
+
+def split_lines_greedy(lines, chunk_size):
+    """Greedily split lines into chunks of up to chunk_size lines."""
+    chunks = []
+    i = 0
+    while i < len(lines):
+        remaining = len(lines) - i
+        if remaining <= chunk_size:
+            chunks.append("\n".join(lines[i:]))
+            break
+        else:
+            chunks.append("\n".join(lines[i : i + chunk_size]))
+            i += chunk_size
+    return chunks
+
+
+def set_seed(seed: int):
+    """Sets the random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
 def run_inference(
     text_input: str,
     audio_prompt_text_input: str,
@@ -76,154 +169,152 @@
     top_p: float,
     cfg_filter_top_k: int,
     speed_factor: float,
+    chunk_size: int,
     seed: Optional[int] = None,
 ):
     """
     Runs Nari inference using the globally loaded model and provided inputs.
-    Uses temporary files for text and audio prompt compatibility with inference.generate.
+    Supports dynamic chunking and token scaling.
     """
+    global model, device  # Access global model, config, device
     console_output_buffer = io.StringIO()
 
     with contextlib.redirect_stdout(console_output_buffer):
-        # Prepend transcript text if audio_prompt provided
-        if audio_prompt_input and audio_prompt_text_input and not audio_prompt_text_input.isspace():
-            text_input = audio_prompt_text_input + "\n" + text_input
-            text_input = text_input.strip()
+        # Validation
+        if not text_input or text_input.isspace():
+            raise gr.Error("Text input cannot be empty.")
 
         if audio_prompt_input and (not audio_prompt_text_input or audio_prompt_text_input.isspace()):
            raise gr.Error("Audio Prompt Text input cannot be empty.")
 
-        if not text_input or text_input.isspace():
-            raise gr.Error("Text input cannot be empty.")
+        # Set and Display Generation Seed
+        if seed is None or seed < 0:
+            seed = random.randint(0, 2**32 - 1)
+            print(f"\nNo seed provided, generated random seed: {seed}\n")
+        else:
+            print(f"\nUsing user-selected seed: {seed}\n")
+        set_seed(seed)
 
-        # Preprocess Audio
-        temp_txt_file_path = None
+        # Preprocess audio prompt
         temp_audio_prompt_path = None
         output_audio = (44100, np.zeros(1, dtype=np.float32))
+        prompt_path_for_generate = None
 
         try:
-            prompt_path_for_generate = None
             if audio_prompt_input is not None:
                 sr, audio_data = audio_prompt_input
-                # Check if audio_data is valid
-                if audio_data is None or audio_data.size == 0 or audio_data.max() == 0:  # Check for silence/empty
+                if audio_data is None or audio_data.size == 0 or audio_data.max() == 0:
                     gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
                 else:
-                    # Save prompt audio to a temporary WAV file
                     with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f_audio:
-                        temp_audio_prompt_path = f_audio.name  # Store path for cleanup
+                        temp_audio_prompt_path = f_audio.name
 
-                        # Basic audio preprocessing for consistency
-                        # Convert to float32 in [-1, 1] range if integer type
                         if np.issubdtype(audio_data.dtype, np.integer):
                             max_val = np.iinfo(audio_data.dtype).max
                             audio_data = audio_data.astype(np.float32) / max_val
                         elif not np.issubdtype(audio_data.dtype, np.floating):
-                            gr.Warning(f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion.")
-                            # Attempt conversion, might fail for complex types
                             try:
                                 audio_data = audio_data.astype(np.float32)
                             except Exception as conv_e:
-                                raise gr.Error(f"Failed to convert audio prompt to float32: {conv_e}")
+                                raise gr.Error(f"Failed to convert audio prompt: {conv_e}")
 
-                        # Ensure mono (average channels if stereo)
                         if audio_data.ndim > 1:
-                            if audio_data.shape[0] == 2:  # Assume (2, N)
-                                audio_data = np.mean(audio_data, axis=0)
-                            elif audio_data.shape[1] == 2:  # Assume (N, 2)
-                                audio_data = np.mean(audio_data, axis=1)
-                            else:
-                                gr.Warning(
-                                    f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
-                                )
-                                audio_data = (
-                                    audio_data[0] if audio_data.shape[0] < audio_data.shape[1] else audio_data[:, 0]
-                                )
-                            audio_data = np.ascontiguousarray(audio_data)  # Ensure contiguous after slicing/mean
-
-                        # Write using soundfile
-                        try:
-                            sf.write(
-                                temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
-                            )  # Explicitly use FLOAT subtype
-                            prompt_path_for_generate = temp_audio_prompt_path
-                            print(f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})")
-                        except Exception as write_e:
-                            print(f"Error writing temporary audio file: {write_e}")
-                            raise gr.Error(f"Failed to save audio prompt: {write_e}")
-
-            # Set and Display Generation Seed
-            if seed is None or seed < 0:
-                seed = random.randint(0, 2**32 - 1)
-                print(f"\nNo seed provided, generated random seed: {seed}\n")
-            else:
-                print(f"\nUsing user-selected seed: {seed}\n")
-            set_seed(seed)
+                            audio_data = np.mean(audio_data, axis=-1)
+                            audio_data = np.ascontiguousarray(audio_data)
+
+                        sf.write(temp_audio_prompt_path, audio_data, sr, subtype="FLOAT")
+                        prompt_path_for_generate = temp_audio_prompt_path
+                        print(f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})")
+
+            # --- Chunking ---
+            chunk_size = auto_adjust_chunk_size(text_input, chunk_size)
+            print(f"Auto-selected chunk size: {chunk_size} effective characters per chunk.")
+            # New: Split by effective character count (~64 chars per chunk)
+            chunks = split_by_words_respecting_special_tokens(text_input, max_effective_chars=chunk_size)
+
+            print(f"Chunked into {len(chunks)} chunks (based on effective character count).")
 
-            # Run Generation
-            print(f'Generating speech: \n"{text_input}"\n')
+            audio_segments = []
             start_time = time.time()
 
-            # Use torch.inference_mode() context manager for the generation call
-            with torch.inference_mode():
-                output_audio_np = model.generate(
-                    text_input,
-                    max_tokens=max_new_tokens,
-                    cfg_scale=cfg_scale,
-                    temperature=temperature,
-                    top_p=top_p,
-                    cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
-                    use_torch_compile=False,  # Keep False for Gradio stability
-                    audio_prompt=prompt_path_for_generate,
-                    verbose=True,
+            batch_size = 4  # Adjust based on your GPU VRAM (e.g., 2–8)
+
+            for batch_idx, chunk_batch in enumerate(batch_chunks(chunks, batch_size)):
+                print(
+                    f"Generating batch {batch_idx + 1}/{(len(chunks) + batch_size - 1) // batch_size} with {len(chunk_batch)} chunks..."
                 )
+                batch_input_text = "\n".join(chunk.strip() for chunk in chunk_batch).strip()
+
+                if not batch_input_text:
Cannot generate.") + effective_chars = count_effective_length(batch_input_text) + scaling_factor = effective_chars / chunk_size + adjusted_tokens = int(max_new_tokens * scaling_factor) + adjusted_tokens = max(256, adjusted_tokens) + + with torch.inference_mode(), torch.amp.autocast(device_type="cuda", dtype=torch.float16): + generated_batch_audio = model.generate( + batch_input_text, + max_tokens=adjusted_tokens, + cfg_scale=cfg_scale, + temperature=temperature, + top_p=top_p, + cfg_filter_top_k=cfg_filter_top_k, + use_torch_compile=False, + audio_prompt=prompt_path_for_generate, + audio_prompt_text=audio_prompt_text_input, + ) + + if generated_batch_audio is not None: + audio_segments.append(generated_batch_audio) + + # Add a small silence buffer **after the batch** (but NOT after the last batch) + if batch_idx < (len(chunks) + batch_size - 1) // batch_size - 1: + silence_duration_sec = 0.2 + silence_samples = int(44100 * silence_duration_sec) + silence = np.zeros(silence_samples, dtype=np.float32) + audio_segments.append(silence) + + if not audio_segments: + output_audio_np = None + else: + output_audio_np = np.concatenate(audio_segments) + end_time = time.time() print(f"Generation finished in {end_time - start_time:.2f} seconds.\n") - # 4. Convert Codes to Audio + # --- Postprocessing --- if output_audio_np is not None: - # Get sample rate from the loaded DAC model output_sr = 44100 - # --- Slow down audio --- + # Slowdown if needed original_len = len(output_audio_np) - # Ensure speed_factor is positive and not excessively small/large to avoid issues speed_factor = max(0.1, min(speed_factor, 5.0)) - target_len = int(original_len / speed_factor) # Target length based on speed_factor - if target_len != original_len and target_len > 0: # Only interpolate if length changes and is valid + target_len = int(original_len / speed_factor) + + if target_len != original_len and target_len > 0: x_original = np.arange(original_len) x_resampled = np.linspace(0, original_len - 1, target_len) resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np) - output_audio = ( - output_sr, - resampled_audio_np.astype(np.float32), - ) # Use resampled audio + output_audio = (output_sr, resampled_audio_np.astype(np.float32)) print( f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed." ) else: - output_audio = ( - output_sr, - output_audio_np, - ) # Keep original if calculation fails or no change + output_audio = (output_sr, output_audio_np) print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).") - # --- End slowdown --- - - print(f"Audio conversion successful. 
-                print(f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}")
 
-                # Explicitly convert to int16 to prevent Gradio warning
-                if output_audio[1].dtype == np.float32 or output_audio[1].dtype == np.float64:
+                # Final output conversion
+                if output_audio[1].dtype in (np.float32, np.float64):
                     audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
                     audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
                     output_audio = (output_sr, audio_for_gradio)
                     print("Converted audio to int16 for Gradio output.")
-
             else:
                 print("\nGeneration finished, but no valid tokens were produced.")
-                # Return default silence
                 gr.Warning("Generation produced no output.")
 
         except Exception as e:
@@ -231,27 +322,29 @@ def run_inference(
             import traceback
 
             traceback.print_exc()
-            # Re-raise as Gradio error to display nicely in the UI
             raise gr.Error(f"Inference failed: {e}")
 
         finally:
-            # Cleanup Temporary Files defensively
-            if temp_txt_file_path and Path(temp_txt_file_path).exists():
-                try:
-                    Path(temp_txt_file_path).unlink()
-                    print(f"Deleted temporary text file: {temp_txt_file_path}")
-                except OSError as e:
-                    print(f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}")
+            # Clean up temp files
             if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
                 try:
                     Path(temp_audio_prompt_path).unlink()
                     print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
-                except OSError as e:
-                    print(f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}")
+                except Exception as cleanup_e:
+                    print(f"Warning: Error deleting temporary audio prompt file: {cleanup_e}")
 
-    # After generation, capture the printed output
     console_output = console_output_buffer.getvalue()
 
+    try:
+        torch.cuda.empty_cache()
+        gc.collect()
+    except RuntimeError as e:
+        print(f"CUDA cache clear failed: {e}")
+    except Exception as e:
+        print(f"Garbage collection failed: {e}")
+    finally:
+        print("Generation completed")
+
     return output_audio, seed, console_output
 
 
@@ -297,6 +390,14 @@ def run_inference(
                 lines=5,  # Increased lines
             )
             with gr.Accordion("Generation Parameters", open=False):
+                chunk_size = gr.Number(
+                    label="Chunk Size (Effective Characters)",
+                    minimum=0,
+                    value=0,
+                    precision=0,
+                    step=1,
+                    info="If 0, auto-selects chunk size for optimal speed. Otherwise, set number of effective characters per generation chunk.",
+                )
                 max_new_tokens = gr.Slider(
                     label="Max New Tokens (Audio Length)",
                     minimum=860,
@@ -378,6 +479,7 @@ def run_inference(
             top_p,
             cfg_filter_top_k,
             speed_factor_slider,
+            chunk_size,
             seed_input,
         ],
         outputs=[
@@ -398,8 +500,10 @@ def run_inference(
         3.0,
         1.8,
         0.95,
-        45,
-        1.0,
+        35,
+        0.94,
+        4,
+        -1,
     ],
     [
         "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
@@ -408,8 +512,10 @@ def run_inference(
         3.0,
         1.8,
         0.95,
-        45,
-        1.0,
+        35,
+        0.94,
+        4,
+        -1,
     ],
 ]
 
@@ -425,6 +531,7 @@ def run_inference(
             top_p,
             cfg_filter_top_k,
             speed_factor_slider,
+            chunk_size,
             seed_input,
         ],
         outputs=[audio_output],
diff --git a/dia/model.py b/dia/model.py
index 56a34ef6..92eb4d2a 100644
--- a/dia/model.py
+++ b/dia/model.py
@@ -1,6 +1,6 @@
 import time
 from enum import Enum
-from typing import Callable
+from typing import Callable, Optional
 
 import numpy as np
 import torch
@@ -600,8 +600,9 @@ def generate(
         top_p: float = 0.95,
         use_torch_compile: bool = False,
         cfg_filter_top_k: int = 45,
-        audio_prompt: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
-        audio_prompt_path: list[str | torch.Tensor | None] | str | torch.Tensor | None = None,
+        audio_prompt: str | torch.Tensor | None = None,
+        audio_prompt_path: str | None = None,
+        audio_prompt_text: Optional[str] = None,
         use_cfg_filter: bool | None = None,
         verbose: bool = False,
     ) -> np.ndarray | list[np.ndarray]:
@@ -647,6 +648,10 @@
         if audio_prompt_path:
             print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
             audio_prompt = audio_prompt_path
+        if audio_prompt_text:
+            full_text = f"{audio_prompt_text.strip()}\n{text.strip()}"
+        else:
+            full_text = text.strip()
         if use_cfg_filter is not None:
             print("Warning: use_cfg_filter is deprecated.")
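A minimal caller-side sketch of how the updated Dia.generate signature from this patch might be exercised. It is illustrative only and not part of the diff; the script text, the prompt WAV path, the sampling values, and the 44.1 kHz output rate are assumptions carried over from app.py rather than anything the patch defines.

import soundfile as sf
import torch

from dia.model import Dia

# Load the checkpoint the same way app.py does (name and compute dtype taken from the patch).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=device)

# audio_prompt_text is the parameter added in this patch: it carries the transcript of the
# reference audio, which generate() combines with the target script into full_text.
audio = model.generate(
    "[S1] Hello there. [S2] Hi, welcome back.",  # hypothetical script
    max_tokens=1024,
    cfg_scale=3.0,
    temperature=1.8,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,
    audio_prompt="voice_prompt.wav",  # hypothetical reference clip
    audio_prompt_text="[S1] Transcript of the reference clip.",
)

sf.write("output.wav", audio, 44100)  # app.py treats the generated audio as 44.1 kHz mono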