Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions api/src/routers/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,134 @@ async def combine_voices(request: Union[str, List[str]]):
"type": "server_error",
},
)

# Matches a single blend entry: "<voice>" or "<voice>:<weight>".
#   group 1 - voice name: ASCII letters, digits and underscores only, so the
#             name can never contain path separators or dots.
#   group 2 - optional weight: a plain unsigned decimal such as "1", "0.7",
#             "1." or ".5" (None when no ":<weight>" suffix is present).
# The weight grammar is deliberately strict so malformed numbers like "1.2.3"
# or "." are rejected by the pattern instead of blowing up later in float().
VOICE_WEIGHT_PATTERN = re.compile(
    r"^([a-zA-Z0-9_]+)(?::(\d+(?:\.\d*)?|\.\d+))?$"
)

def _parse_weighted_voices(
    request: Union[str, List[str]],
) -> List[Tuple[str, float]]:
    """Parse 'voice[:weight]' entries into (voice, normalized_weight) pairs.

    Raises:
        ValueError: If no entries are given, an entry is malformed, a weight
            is negative or not finite, or all weights sum to zero.
    """
    raw_entries = request.split("+") if isinstance(request, str) else request
    entries = [entry.strip() for entry in raw_entries if entry.strip()]
    if not entries:
        raise ValueError("No voices provided")

    voice_weights: List[Tuple[str, float]] = []
    for entry in entries:
        match = VOICE_WEIGHT_PATTERN.match(entry)
        if not match:
            raise ValueError(
                f"Invalid format '{entry}'. Expected 'voiceName[:weight]'"
            )

        voice_name, weight_str = match.groups()
        try:
            weight = 1.0 if weight_str is None else float(weight_str)
        except ValueError:
            # The pattern may let through digit/dot runs that are not numbers;
            # report them with the same message as any other malformed entry.
            raise ValueError(
                f"Invalid format '{entry}'. Expected 'voiceName[:weight]'"
            ) from None

        # Zero is allowed (the voice simply contributes nothing); negatives,
        # NaN and values that overflowed to infinity are not.
        if not 0 <= weight < float("inf"):
            raise ValueError(
                f"Weight must be a finite, non-negative number for voice "
                f"'{voice_name}'"
            )

        voice_weights.append((voice_name, weight))

    total_weight = sum(weight for _, weight in voice_weights)
    if total_weight <= 0:
        raise ValueError("Total weight must be greater than zero")

    # Normalize so the weights sum to 1.0.
    return [(name, weight / total_weight) for name, weight in voice_weights]


def _blend_voice_tensors(voice_weights: List[Tuple[str, float]]):
    """Load every listed voice tensor and return their weighted sum.

    Raises:
        ValueError: If a voice file does not exist.
        RuntimeError: If nothing was combined (or torch fails to add tensors).
    """
    # Loop-invariant: the voice directory lives two levels above this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    voice_dir = os.path.join(base_dir, "voices", "v1_0")

    combined_tensor = None
    for voice_name, weight in voice_weights:
        voice_path = os.path.join(voice_dir, f"{voice_name}.pt")
        if not os.path.exists(voice_path):
            # Keep the absolute server path in the log only; the ValueError
            # message is returned verbatim to the HTTP client.
            logger.warning(f"Voice file not found: {voice_path}")
            raise ValueError(f"Voice '{voice_name}' not found")

        # Always load on CPU: the result is serialized and handed to the
        # client, and a tensor saved from a CUDA device would be tagged with
        # that device and fail to load on CPU-only machines.
        tensor = torch.load(voice_path, weights_only=True, map_location="cpu")

        weighted_tensor = tensor * weight
        if combined_tensor is None:
            combined_tensor = weighted_tensor
        else:
            # In-place add is safe: combined_tensor is a fresh product tensor,
            # never the loaded voice itself.
            combined_tensor.add_(weighted_tensor)

    if combined_tensor is None:
        raise RuntimeError("Voice combination failed")
    return combined_tensor


# NOTE: the PascalCase name is kept as-is; the function name feeds the
# generated OpenAPI operation id and may be referenced elsewhere.
@router.post("/audio/voices/combine_with_weights")
async def CombineVoicesWithWeights(request: Union[str, List[str]]):
    """Combine multiple voices using weighted blending and return the resulting .pt voice file.

    Args:
        request: Either a string where entries are separated by '+'
            (e.g. "emma:0.7+isabella:0.3")
            or a list of entries like ["emma:0.7", "isabella:0.3"].
            Each entry can optionally include a weight (default = 1.0).
            Weights are normalized so they sum to 1.0 before blending.

    Returns:
        FileResponse containing the combined voice .pt file. The download is
        named after the voices and their rounded percentage weights
        (e.g. "emma70_isabella30.pt").

    Raises:
        HTTPException:
            - 400: Invalid input format, missing voices, negative or
              non-finite weights, zero total weight, or voice not found.
            - 500: Server error (file system issues, tensor loading errors,
              combination failure).
    """
    # Local import so this block does not depend on a new file-level import;
    # starlette is the framework underneath FastAPI's FileResponse.
    from starlette.background import BackgroundTask

    try:
        voice_weights = _parse_weighted_voices(request)
        combined_tensor = _blend_voice_tensors(voice_weights)

        # round() rather than int(): 0.29 * 100 is 28.999... in floating point
        # and would otherwise be reported as 28.
        combined_name = "_".join(
            f"{name}{round(weight * 100)}" for name, weight in voice_weights
        )

        buffer = io.BytesIO()
        torch.save(combined_tensor, buffer)

        # Use a unique temp file instead of a name derived from the request:
        # concurrent identical requests must not overwrite a file that is
        # still being streamed, and long voice lists must not exceed
        # filesystem name limits. The client-facing name is set below.
        fd, output_path = tempfile.mkstemp(prefix="voice_blend_", suffix=".pt")
        os.close(fd)
        try:
            async with aiofiles.open(output_path, "wb") as f:
                await f.write(buffer.getvalue())
        except Exception:
            # Do not leave a partial file behind if the write fails.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise

        return FileResponse(
            output_path,
            media_type="application/octet-stream",
            filename=f"{combined_name}.pt",
            headers={
                "Content-Disposition": f"attachment; filename={combined_name}.pt",
                "Cache-Control": "no-cache",
            },
            # Delete the temp file once the response has been sent.
            background=BackgroundTask(os.remove, output_path),
        )

    except ValueError as e:
        logger.warning(f"Invalid weighted voice combination request: {str(e)}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error",
            },
        )
    except RuntimeError as e:
        logger.error(f"Weighted voice combination error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": "Failed to process weighted voice combination request",
                "type": "server_error",
            },
        )
    except Exception as e:
        logger.error(f"Unexpected error in weighted voice combination: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "An unexpected error occurred",
                "type": "server_error",
            },
        )