# Matches "voiceName" or "voiceName:weight". The name is restricted to
# [a-zA-Z0-9_], which also keeps path separators out of the file lookup below.
VOICE_WEIGHT_PATTERN = re.compile(r"^([a-zA-Z0-9_]+)(?::([\d.]+))?$")


def _parse_voice_weights(request: Union[str, List[str]]) -> List[Tuple[str, float]]:
    """Parse ``voiceName[:weight]`` entries into normalized ``(name, weight)`` pairs.

    Args:
        request: Either a '+'-separated string ("emma:0.7+isabella:0.3") or a
            list of entries (["emma:0.7", "isabella:0.3"]). An entry without
            an explicit weight gets 1.0.

    Returns:
        A list of ``(voice_name, weight)`` tuples whose weights sum to 1.

    Raises:
        ValueError: If no entries are given, an entry does not match
            ``VOICE_WEIGHT_PATTERN``, a weight is not a valid number, or the
            weights sum to zero.
    """
    raw_entries = request.split("+") if isinstance(request, str) else request
    entries = [entry.strip() for entry in raw_entries if entry.strip()]
    if not entries:
        raise ValueError("No voices provided")

    voice_weights: List[Tuple[str, float]] = []
    for entry in entries:
        match = VOICE_WEIGHT_PATTERN.match(entry)
        if not match:
            raise ValueError(
                f"Invalid format '{entry}'. Expected 'voiceName[:weight]'"
            )

        voice_name, weight_str = match.groups()
        if weight_str is None:
            weight = 1.0
        else:
            # The pattern accepts any run of digits and dots (e.g. "1.2.3"),
            # so float() can still fail; report that in request terms.
            try:
                weight = float(weight_str)
            except ValueError:
                raise ValueError(
                    f"Invalid weight '{weight_str}' for voice '{voice_name}'"
                ) from None

        # Defensive: the pattern has no sign character, but keep the guard in
        # case the pattern is ever relaxed. Zero is allowed here; an all-zero
        # request is rejected by the total check below.
        if weight < 0:
            raise ValueError(f"Weight must not be negative for voice '{voice_name}'")

        voice_weights.append((voice_name, weight))

    total_weight = sum(weight for _, weight in voice_weights)
    if total_weight <= 0:
        raise ValueError("Total weight must be greater than zero")

    return [(name, weight / total_weight) for name, weight in voice_weights]


@router.post("/audio/voices/combine_with_weights")
async def CombineVoicesWithWeights(request: Union[str, List[str]]):
    """Combine multiple voices using weighted blending and return the resulting .pt voice file.

    Args:
        request: Either a string where entries are separated by '+'
                 (e.g. "emma:0.7+isabella:0.3")
                 or a list of entries like ["emma:0.7", "isabella:0.3"].
                 Each entry can optionally include a weight (default = 1.0).
                 Weights are normalized so that they sum to 1.

    Returns:
        FileResponse containing the combined voice .pt file.

    Raises:
        HTTPException:
            - 400: Invalid input format, missing voices, invalid weights, or voice not found.
            - 500: Server error (file system issues, tensor loading errors, combination failure).
    """
    try:
        # 1. Parse, validate and normalize the requested voices.
        voice_weights = _parse_voice_weights(request)

        # 2. Load and blend voice tensors.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Voice files are looked up in "<parent of this file's dir>/voices/v1_0".
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        voice_dir = os.path.join(base_dir, "voices", "v1_0")

        combined_tensor = None
        for voice_name, weight in voice_weights:
            voice_path = os.path.join(voice_dir, f"{voice_name}.pt")
            if not os.path.exists(voice_path):
                # NOTE(review): this message is returned to the client and
                # exposes the server-side directory path - confirm that is
                # acceptable.
                raise ValueError(
                    f"Voice '{voice_name}' not found in directory {voice_dir}"
                )

            # weights_only=True restricts unpickling to tensor data.
            tensor = torch.load(voice_path, weights_only=True, map_location=device)

            weighted_tensor = tensor * weight
            if combined_tensor is None:
                combined_tensor = weighted_tensor
            else:
                # weighted_tensor is a fresh tensor, so accumulating in place
                # never touches the loaded voice data.
                combined_tensor.add_(weighted_tensor)

        if combined_tensor is None:
            raise RuntimeError("Voice combination failed")

        # 3. Save the result to a temporary file.
        # round() rather than int(): e.g. 0.29 * 100 is 28.999999999999996,
        # which int() would truncate to 28.
        combined_name = "_".join(
            f"{name}{round(weight * 100)}" for name, weight in voice_weights
        )
        # NOTE(review): the output file is never removed and its name is
        # shared by every request for the same blend - consider cleanup.
        output_path = os.path.join(tempfile.gettempdir(), f"{combined_name}.pt")

        buffer = io.BytesIO()
        torch.save(combined_tensor, buffer)
        async with aiofiles.open(output_path, "wb") as f:
            await f.write(buffer.getvalue())

        return FileResponse(
            output_path,
            media_type="application/octet-stream",
            filename=f"{combined_name}.pt",
            headers={
                "Content-Disposition": f"attachment; filename={combined_name}.pt",
                "Cache-Control": "no-cache",
            },
        )

    except ValueError as e:
        logger.warning(f"Invalid weighted voice combination request: {e}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error",
            },
        ) from e
    except RuntimeError as e:
        logger.error(f"Weighted voice combination error: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": "Failed to process weighted voice combination request",
                "type": "server_error",
            },
        ) from e
    except Exception as e:
        logger.error(f"Unexpected error in weighted voice combination: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "An unexpected error occurred",
                "type": "server_error",
            },
        ) from e