diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 3c2d9c7bc..a69119f82 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -14,6 +14,7 @@ import { BookOpen, Workflow, Users, + Mic, } from "lucide-react" import { useState, useEffect } from "react" import HistoryModal from "@/components/HistoryModal" @@ -40,12 +41,12 @@ import { useUnreadCount } from "@/contexts/UnreadCountContext" export const Sidebar = ({ className = "", photoLink = "", - role = "", + role, isAgentMode = false, }: { className?: string photoLink?: string - role?: string + role?: UserRole isAgentMode?: boolean }) => { const location = useLocation() @@ -95,6 +96,7 @@ export const Sidebar = ({ const isInteractiveElement = target.closest(SELECTORS.INTERACTIVE_ELEMENT) if (isInteractiveElement) return + const isSidebarClick = target.closest(`.${CLASS_NAMES.SIDEBAR_CONTAINER}`) const isHistoryModalClick = target.closest( `.${CLASS_NAMES.HISTORY_MODAL_CONTAINER}`, @@ -104,6 +106,7 @@ export const Sidebar = ({ const isReferenceBox = target.closest(`.${CLASS_NAMES.REFERENCE_BOX}`) const isAtMentionArea = target.closest(SELECTORS.AT_MENTION_AREA) const isBookmarkButton = target.closest(`.${CLASS_NAMES.BOOKMARK_BUTTON}`) + if ( !isSidebarClick && !isHistoryModalClick && @@ -111,18 +114,15 @@ export const Sidebar = ({ !isSearchArea && !isReferenceBox && !isAtMentionArea && - !isBookmarkButton && - showHistory + !isBookmarkButton ) { - if (showHistory) setShowHistory(false) + setShowHistory(false) } } document.addEventListener("mousedown", handleClickOutside) return () => document.removeEventListener("mousedown", handleClickOutside) - }, [showHistory]) - - // toggleDarkMode is now toggleTheme from context (no separate function needed here) + }, []) return ( @@ -144,7 +144,8 @@ export const Sidebar = ({ Profile )} @@ -167,7 +168,15 @@ export const Sidebar = ({
setShowHistory((history) => !history)} + onKeyDown={(e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault() + setShowHistory((history) => !history) + } + }} className={cn( "flex w-8 h-8 rounded-lg items-center justify-center cursor-pointer hover:bg-[#D8DFE680] dark:hover:bg-gray-700 mt-[10px]", showHistory && "bg-[#D8DFE680] dark:bg-gray-700", @@ -233,7 +242,26 @@ export const Sidebar = ({ - {/* TODO: Add appropriate Link destination and Tooltip info for the Bot icon */} + + + + + + + + + {isAgentMode && ( - {" "} - {/* Placeholder: Update this tooltip info */} + )} @@ -297,6 +328,7 @@ export const Sidebar = ({ + {/* User Management - Admin only */} {role === UserRole.SuperAdmin && ( )} +
+
{ + if (e.key === "Enter" || e.key === " ") { + e.preventDefault() + toggleTheme() + } + }} className="flex w-8 h-8 rounded-lg items-center justify-center cursor-pointer hover:bg-[#D8DFE680] dark:hover:bg-gray-700 mb-4" > @@ -358,6 +400,7 @@ export const Sidebar = ({ />
+ Logo diff --git a/frontend/src/routes/_authenticated/transcription.tsx b/frontend/src/routes/_authenticated/transcription.tsx new file mode 100644 index 000000000..fc0ae99a1 --- /dev/null +++ b/frontend/src/routes/_authenticated/transcription.tsx @@ -0,0 +1,610 @@ +import { createFileRoute } from "@tanstack/react-router" +import { Sidebar } from "@/components/Sidebar" +import { useState, useCallback, useEffect, useRef } from "react" +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card" +import { Label } from "@/components/ui/label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select" +import { Progress } from "@/components/ui/progress" +import { + Upload, + FileAudio, + CheckCircle2, + XCircle, + Loader2, + AlertCircle, +} from "lucide-react" +import { api } from "@/api" + +export const Route = createFileRoute("/_authenticated/transcription")({ + component: TranscriptionPage, + errorComponent: () =>
Error loading transcription page
, + loader: async ({ context }) => { + return { + user: context.user, + workspace: context.workspace, + } + }, +}) + +type JobStatus = + | "idle" + | "uploading" + | "queued" + | "processing" + | "completed" + | "failed" + +type OutputFormat = "json" | "txt" | "srt" | "all" +type WhisperModel = "turbo" | "large" + +interface TranscriptionResult { + json?: { + text: string + segments: Array<{ + speaker: string + text: string + start: number + end: number + }> + speakers: string[] + language: string + } + txt?: string + srt?: string +} + +function TranscriptionPage() { + const { user } = Route.useLoaderData() + + // Form state + const [selectedFile, setSelectedFile] = useState(null) + const [whisperModel, setWhisperModel] = useState("turbo") + const [outputFormat, setOutputFormat] = useState("json") + const [numSpeakers, setNumSpeakers] = useState(undefined) + + // Job state + const [jobStatus, setJobStatus] = useState("idle") + const [jobId, setJobId] = useState(null) + const [uploadProgress, setUploadProgress] = useState(0) + const [error, setError] = useState(null) + const [result, setResult] = useState(null) + + // Polling refs + const pollIntervalRef = useRef | null>(null) + const pollTimeoutRef = useRef | null>(null) + + const clearPolling = useCallback(() => { + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current) + pollIntervalRef.current = null + } + if (pollTimeoutRef.current) { + clearTimeout(pollTimeoutRef.current) + pollTimeoutRef.current = null + } + }, []) + + useEffect( + () => () => { + // Cleanup on unmount + clearPolling() + }, + [clearPolling], + ) + + const handleFileSelect = useCallback( + (event: React.ChangeEvent) => { + const file = event.target.files?.[0] + if (file) { + const validTypes = ["audio/", "video/"] + if (validTypes.some((type) => file.type.startsWith(type))) { + setSelectedFile(file) + setError(null) + setResult(null) + setJobId(null) + setJobStatus("idle") + } else { + setError("Please select a valid audio or 
video file") + setSelectedFile(null) + } + } + }, + [], + ) + + const uploadFile = async (file: File): Promise => { + const formData = new FormData() + formData.append("file", file) + + const response = await fetch(`${api}/files/upload-simple`, { + method: "POST", + headers: { + Authorization: `Bearer ${localStorage.getItem("access_token")}`, + }, + body: formData, + }) + + if (!response.ok) { + let message = "File upload failed" + try { + const data = await response.json() + if (data?.message) message = data.message + } catch { + // ignore JSON parse errors + } + throw new Error(message) + } + + const data = await response.json() + return data.url + } + + const pollJobStatus = useCallback( + (jobId: string) => { + clearPolling() + + const intervalId = setInterval(async () => { + try { + const response = await fetch(`${api}/asr/job-status?jobId=${jobId}`, { + headers: { + Authorization: `Bearer ${localStorage.getItem("access_token")}`, + }, + }) + + if (!response.ok) { + throw new Error("Failed to fetch job status") + } + + const data = await response.json() + + if (data.status === "completed") { + clearPolling() + setJobStatus("completed") + setResult(data.outputs ?? null) + } else if (data.status === "failed") { + clearPolling() + setJobStatus("failed") + setError(data.error || "Transcription failed") + } else if (data.status === "active") { + setJobStatus("processing") + } + } catch (err) { + clearPolling() + setJobStatus("failed") + setError( + err instanceof Error + ? err.message + : "Failed to check job status", + ) + } + }, 3000) + + pollIntervalRef.current = intervalId + + // Hard timeout after 30 minutes + const timeoutId = setTimeout(() => { + clearPolling() + setJobStatus("failed") + setError("Transcription timed out. 
Please try again.") + }, 30 * 60 * 1000) + + pollTimeoutRef.current = timeoutId + }, + [clearPolling], + ) + + const handleSubmit = async () => { + if (!selectedFile) { + setError("Please select a file first") + return + } + + try { + clearPolling() + setJobStatus("uploading") + setError(null) + setResult(null) + setJobId(null) + setUploadProgress(0) + + // Upload file + const audioUrl = await uploadFile(selectedFile) + setUploadProgress(50) + + // Start transcription job + setJobStatus("queued") + const response = await fetch(`${api}/asr/transcribe`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${localStorage.getItem("access_token")}`, + }, + body: JSON.stringify({ + audioUrl, + whisperModel, + refineWithLLM: true, + outputFormat, + numSpeakers: numSpeakers || undefined, + multilingual: true, + }), + }) + + if (!response.ok) { + let message = "Failed to start transcription" + try { + const data = await response.json() + if (data?.message) message = data.message + } catch { + // ignore JSON parse errors + } + throw new Error(message) + } + + const data = await response.json() + setJobId(data.jobId) + setUploadProgress(100) + + // Start polling for job status + pollJobStatus(data.jobId) + } catch (err) { + clearPolling() + setJobStatus("failed") + setError(err instanceof Error ? 
err.message : "An error occurred") + } + } + + const formatTime = (seconds: number): string => { + const mins = Math.floor(seconds / 60) + const secs = Math.floor(seconds % 60) + return `${mins}:${secs.toString().padStart(2, "0")}` + } + + const downloadResult = (content: string, filename: string) => { + const blob = new Blob([content], { type: "text/plain" }) + const url = URL.createObjectURL(blob) + const a = document.createElement("a") + a.href = url + a.download = filename + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(url) + } + + const isJobIdleLike = + jobStatus === "idle" || jobStatus === "completed" || jobStatus === "failed" + + return ( +
+ + +
+
+
+ {/* Header */} +
+

+ Audio Transcription +

+

+ Upload audio files for automated transcription with speaker + diarization and AI refinement +

+
+ + {/* Upload Card */} + + + + Upload Audio File + + + Select an audio or video file to transcribe + + + + {/* File Upload */} +
+ + +
+ + {/* Configuration */} +
+ {/* Whisper Model */} +
+ + +
+ + {/* Output Format */} +
+ + +
+ + {/* Number of Speakers */} +
+ + +
+
+ + {/* Submit Button */} + +
+
+ + {/* Progress/Status Card */} + {jobStatus !== "idle" && ( + + + + {jobStatus === "completed" && ( + + )} + {jobStatus === "failed" && ( + + )} + {(jobStatus === "uploading" || + jobStatus === "queued" || + jobStatus === "processing") && ( + + )} + {jobStatus === "uploading" && "Uploading File"} + {jobStatus === "queued" && "Job Queued"} + {jobStatus === "processing" && "Transcribing Audio"} + {jobStatus === "completed" && "Transcription Complete"} + {jobStatus === "failed" && "Transcription Failed"} + + + + {jobStatus === "uploading" && ( + + )} + + {error && ( +
+ +

+ {error} +

+
+ )} + + {jobId && ( +

+ Job ID:{" "} + + {jobId} + +

+ )} +
+
+ )} + + {/* Results Card */} + {result && jobStatus === "completed" && ( + + + + Transcription Results + + + {result.json && + `Detected ${result.json.speakers.length} speaker(s): ${result.json.speakers.join( + ", ", + )}`} + + + + {/* Download Buttons */} +
+ {result.json && ( + + )} + {result.txt && ( + + )} + {result.srt && ( + + )} +
+ + {/* Transcript Preview */} + {result.json && ( +
+

+ Transcript Preview +

+ {result.json.segments.map((segment, idx) => ( +
+
+ + {segment.speaker} + + + {formatTime(segment.start)} -{" "} + {formatTime(segment.end)} + +
+

+ {segment.text} +

+
+ ))} +
+ )} +
+
+ )} +
+
+
+
+ ) +} diff --git a/server/api/asr.ts b/server/api/asr.ts new file mode 100644 index 000000000..22709e4e7 --- /dev/null +++ b/server/api/asr.ts @@ -0,0 +1,230 @@ +import { z } from "zod" +import { HTTPException } from "hono/http-exception" +import type { Context } from "hono" +import { getLogger } from "@/logger" +import { Subsystem } from "@/types" +import { promises as fs } from "fs" +import * as path from "path" +import { randomUUID } from "crypto" +import { boss, ASRQueue } from "@/queue" +import type { TranscribeJobData } from "@/queue/asrProcessor" +import { ASRJobType } from "@/queue/asrProcessor" +import config from "@/config" + +const Logger = getLogger(Subsystem.Api).child({ module: "asr" }) + +// Paths +const ASR_SD_DIR = path.join(process.cwd(), "asr-sd") + +// Schemas +export const transcribeAudioSchema = z.object({ + audioUrl: z.string({ message: "Invalid audio URL" }).url(), + whisperModel: z + .enum(["tiny", "base", "small", "medium", "large", "large-v2", "large-v3", "turbo"]) + .default("turbo") + .optional(), + language: z.string().optional(), + numSpeakers: z.number().int().positive().optional(), + minSpeakers: z.number().int().positive().optional(), + maxSpeakers: z.number().int().positive().optional(), + multilingual: z.boolean().default(true).optional(), + refineWithLLM: z.boolean().default(true).optional(), + outputFormat: z.enum(["txt", "json", "srt", "all"]).default("json").optional(), +}) + +export const getJobStatusSchema = z.object({ + jobId: z.string({ message: "Invalid job ID" }).uuid(), +}) + +// API: Transcribe audio with speaker diarization (enqueues job) +export const TranscribeAudioApi = async (c: Context) => { + try { + const body = c.req.valid("json" as never) as z.infer + const { + audioUrl, + whisperModel = "turbo", + language, + numSpeakers, + minSpeakers, + maxSpeakers, + outputFormat = "json", + } = body + + // We always run multilingual + LLM refinement in this pipeline. 
+ const multilingual = true + const refineWithLLM = true + + // Basic sanity check for speaker bounds + if (minSpeakers !== undefined && maxSpeakers !== undefined && minSpeakers > maxSpeakers) { + throw new HTTPException(400, { + message: "minSpeakers cannot be greater than maxSpeakers.", + }) + } + + // Restrict audio URL to http/https to avoid weird protocols (file://, ftp, etc.) + let parsedUrl: URL + try { + parsedUrl = new URL(audioUrl) + } catch { + throw new HTTPException(400, { message: "Invalid audio URL" }) + } + + if (!["http:", "https:"].includes(parsedUrl.protocol)) { + throw new HTTPException(400, { + message: "audioUrl must use http or https scheme", + }) + } + + // Check which LLM provider is configured (following same logic as config.ts) + const { defaultBestModel } = config + if (!defaultBestModel) { + throw new HTTPException(400, { + message: + "No LLM provider configured for ASR refinement. " + + "Please configure an AI provider (Vertex AI, OpenAI, AWS Bedrock, etc.) in your environment.", + }) + } + + // Use HF_TOKEN from environment (required for diarization) + const hfToken = process.env.HF_TOKEN + if (!hfToken) { + throw new HTTPException(400, { + message: "HF_TOKEN not configured. 
Required for speaker diarization.",
+      })
+    }
+
+    const jobId = randomUUID()
+    const tempDir = path.join(ASR_SD_DIR, "temp", jobId)
+
+    // Create temp directory
+    await fs.mkdir(tempDir, { recursive: true })
+
+    const audioExt = path.extname(parsedUrl.pathname) || ".mp3"
+    const audioPath = path.join(tempDir, `audio${audioExt}`)
+    const outputBase = path.join(tempDir, "transcription")
+
+    // Create job data for unified automated pipeline
+    const jobData: TranscribeJobData = {
+      type: ASRJobType.Transcribe,
+      jobId,
+      audioUrl,
+      audioPath,
+      outputPath: outputBase,
+      whisperModel,
+      language,
+      numSpeakers,
+      minSpeakers,
+      maxSpeakers,
+      multilingual, // always true in this pipeline
+      refineWithLLM, // always true in this pipeline
+      hfToken,
+      outputFormat,
+    }
+
+    // Enqueue job
+    Logger.info({ jobId, audioUrl }, "Enqueuing transcription job")
+    const queueJobId = await boss.send(ASRQueue, jobData, { id: jobId, // use our UUID as the pg-boss job id so GetJobStatusApi's getJobById(jobId) can find it
+      expireInHours: 24,
+      retryLimit: 2,
+      retryDelay: 60,
+      retryBackoff: true,
+    })
+
+    Logger.info({ jobId, queueJobId }, "Transcription job enqueued")
+
+    return c.json({
+      success: true,
+      jobId,
+      queueJobId,
+      status: "queued",
+      message: "Transcription job has been queued for processing",
+    })
+  } catch (error) {
+    Logger.error({ error }, "Error in TranscribeAudioApi")
+    if (error instanceof HTTPException) {
+      throw error
+    }
+    throw new HTTPException(500, {
+      message: error instanceof Error ? 
error.message : "Unknown error occurred", + }) + } +} + +// API: Get job status and results +export const GetJobStatusApi = async (c: Context) => { + try { + const { jobId } = c.req.valid("query" as never) as z.infer + + Logger.info({ jobId }, "Getting job status") + + // Get job status from PgBoss + const job = await boss.getJobById(ASRQueue, jobId) + + if (!job) { + throw new HTTPException(404, { + message: "Job not found", + }) + } + + const tempDir = path.join(ASR_SD_DIR, "temp", jobId) + let outputs: Record = {} + + // If job is completed, read output files + if (job.state === "completed") { + try { + const jobData = job.data as TranscribeJobData + const suffix = jobData.refineWithLLM ? "_refined" : "_raw" + const format = jobData.outputFormat || "json" + + if (format === "json" || format === "all") { + const jsonPath = path.join(tempDir, `transcription${suffix}.json`) + try { + const jsonContent = await fs.readFile(jsonPath, "utf-8") + outputs.json = JSON.parse(jsonContent) + } catch (error) { + Logger.warn({ error, jsonPath }, "Failed to read JSON output") + } + } + + if (format === "txt" || format === "all") { + const txtPath = path.join(tempDir, `transcription${suffix}.txt`) + try { + outputs.txt = await fs.readFile(txtPath, "utf-8") + } catch (error) { + Logger.warn({ error, txtPath }, "Failed to read TXT output") + } + } + + if (format === "srt" || format === "all") { + const srtPath = path.join(tempDir, `transcription${suffix}.srt`) + try { + outputs.srt = await fs.readFile(srtPath, "utf-8") + } catch (error) { + Logger.warn({ error, srtPath }, "Failed to read SRT output") + } + } + } catch (error) { + Logger.warn({ error }, "Error reading output files") + } + } + + return c.json({ + success: true, + jobId, + status: job.state, + createdOn: job.createdOn, + startedOn: job.startedOn, + completedOn: job.completedOn, + outputs: Object.keys(outputs).length > 0 ? 
outputs : undefined, + error: (job as any).output?.error, + }) + } catch (error) { + Logger.error({ error }, "Error in GetJobStatusApi") + if (error instanceof HTTPException) { + throw error + } + throw new HTTPException(500, { + message: error instanceof Error ? error.message : "Unknown error occurred", + }) + } +} diff --git a/server/api/files.ts b/server/api/files.ts index 12f573257..8c4e3fd62 100644 --- a/server/api/files.ts +++ b/server/api/files.ts @@ -294,7 +294,7 @@ export const handleAttachmentUpload = async (c: Context) => { vespaId = `${fileId}_sheet_${(processingResults[0] as SheetProcessingResult).totalSheets}` } // Handle multiple processing results (e.g., for spreadsheets with multiple sheets) - for (const [resultIndex, processingResult] of processingResults.entries()) { + for (const processingResult of processingResults) { let docId = fileId let fileName = file.name @@ -671,3 +671,114 @@ export const handleThumbnailServe = async (c: Context) => { throw error } } +// Simple file upload handler for ASR (stores file temporarily and returns URL) +export const handleSimpleFileUpload = async (c: Context) => { + const logger = getLogger(Subsystem.Api).child({ module: "simpleFileUpload" }) + + try { + const { sub } = c.get(JwtPayloadKey) + const email = sub + + const formData = await c.req.formData() + const file = formData.get("file") + + if (!file || !(file instanceof File)) { + throw new HTTPException(400, { + message: "No valid file uploaded", + }) + } + + // Create uploads directory if it doesn't exist + const uploadsDir = join(process.cwd(), "uploads", "asr") + await mkdir(uploadsDir, { recursive: true }) + + // Generate unique filename + const timestamp = Date.now() + const originalName = file.name || "upload" + const safeBaseName = originalName.replace(/[^a-zA-Z0-9.-]/g, "_") + const filename = `${timestamp}-${safeBaseName}` + const filePath = join(uploadsDir, filename) + + // Write file to disk + const arrayBuffer = await file.arrayBuffer() + await 
fs.writeFile(filePath, Buffer.from(arrayBuffer))
+
+    logger.info({ email, filename, size: file.size, type: file.type }, "File uploaded successfully")
+
+    // Return file URL (can be accessed via file serving endpoint)
+    const fileUrl = `/api/v1/files/asr/${filename}`
+
+    return c.json({
+      success: true,
+      url: fileUrl,
+      filename,
+      size: file.size,
+      type: file.type,
+    })
+  } catch (error) {
+    const err = error as Error
+    logger.error({ error: err.message, stack: err.stack }, "Error uploading file")
+    throw error
+  }
+}
+
+// Serve uploaded ASR files
+export const serveASRFile = async (c: Context) => {
+  const logger = getLogger(Subsystem.Api).child({ module: "serveASRFile" })
+
+  try {
+    const filename = c.req.param("filename")
+
+    if (!filename) {
+      throw new HTTPException(400, { message: "Filename is required" })
+    }
+
+    const uploadsDir = join(process.cwd(), "uploads", "asr")
+    const filePath = join(uploadsDir, filename)
+
+    // Normalize and ensure no path traversal outside uploadsDir
+    const normalizedPath = path.normalize(filePath)
+    const normalizedRoot = path.normalize(uploadsDir) + path.sep
+    if (!normalizedPath.startsWith(normalizedRoot)) {
+      logger.warn({ filename }, "Path traversal attempt detected")
+      throw new HTTPException(400, { message: "Invalid filename" })
+    }
+
+    // Check if file exists
+    try {
+      await fs.access(normalizedPath)
+    } catch {
+      throw new HTTPException(404, { message: "File not found" })
+    }
+
+    logger.info({ filename }, "Serving ASR file")
+
+    // Read and serve file
+    const fileBuffer = await fs.readFile(normalizedPath)
+
+    // Determine content type from extension
+    const ext = path.extname(filename).toLowerCase()
+    const contentTypeMap: Record<string, string> = {
+      ".mp3": "audio/mpeg",
+      ".wav": "audio/wav",
+      ".m4a": "audio/mp4",
+      ".ogg": "audio/ogg",
+      ".flac": "audio/flac",
+      ".mp4": "video/mp4",
+      ".webm": "video/webm",
+    }
+    const contentType = contentTypeMap[ext] || "application/octet-stream"
+
+    return c.newResponse(fileBuffer as any, 
200, { + "Content-Type": contentType, + "Content-Length": fileBuffer.length.toString(), + }) + } catch (error) { + const err = error as Error + logger.error( + { error: err.message, stack: err.stack, filename: c.req.param("filename") }, + "Error serving ASR file", + ) + throw error + } +} diff --git a/server/asr-sd/whisper_diarization.py b/server/asr-sd/whisper_diarization.py new file mode 100644 index 000000000..e81f9edd6 --- /dev/null +++ b/server/asr-sd/whisper_diarization.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +""" +Simplified ASR + Speaker Diarization Script +Only handles Whisper transcription and Pyannote speaker diarization in parallel. +All post-processing (merging, chunking, LLM refinement) is handled in TypeScript. +""" + +import argparse +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import whisper +from pyannote.audio import Pipeline +from pyannote.core import Annotation + + +class ParallelWhisperDiarization: + """Runs Whisper and Pyannote in parallel, merges results deterministically.""" + + def __init__( + self, + whisper_model: str = "turbo", + diarization_model: str = "pyannote/speaker-diarization-3.1", + device: Optional[str] = None, + hf_token: Optional[str] = None, + ): + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + + print(f"Using device: {self.device}") + + # Load Whisper model exactly as requested + print(f"Loading Whisper model: {whisper_model}...") + self.whisper_model = whisper.load_model(whisper_model, device=self.device) + + # Load pyannote diarization pipeline + print(f"Loading diarization model: {diarization_model}...") + if hf_token: + self.diarization_pipeline = Pipeline.from_pretrained( + diarization_model, token=hf_token + ) + else: + try: + self.diarization_pipeline = Pipeline.from_pretrained(diarization_model) + 
except Exception as e: + print("\nError: Pyannote models require HuggingFace authentication.") + print("Please provide a HuggingFace token using --hf-token") + raise e + + # Move diarization to the same device + if self.device == "cuda" and torch.cuda.is_available(): + try: + self.diarization_pipeline.to(torch.device("cuda")) + except Exception: + # Not all components may support .to("cuda"); silently fall back + pass + + def transcribe( + self, + audio_path: str, + language: Optional[str] = None, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + multilingual: bool = False, + **whisper_kwargs, + ) -> Dict: + """Run Whisper and Pyannote in parallel, merge results.""" + print(f"\n{'='*60}") + print(f"PARALLEL PROCESSING: {audio_path}") + print(f"{'='*60}") + + start_time = time.time() + + # Prepare parameters for both tasks + transcribe_options: Dict = {"word_timestamps": True, **whisper_kwargs} + + if multilingual: + print(" → Multilingual mode enabled (code-switching detection)") + transcribe_options["task"] = "transcribe" + else: + transcribe_options["language"] = language + + diarization_options: Dict = {} + if num_speakers is not None: + diarization_options["num_speakers"] = num_speakers + elif min_speakers is not None or max_speakers is not None: + diarization_options["min_speakers"] = min_speakers + diarization_options["max_speakers"] = max_speakers + + # Run Whisper and Pyannote in parallel + print("\nRunning Whisper and Pyannote in parallel...") + print(" Launching parallel workers...") + + whisper_result: Optional[Dict] = None + diarization_result: Optional[Annotation] = None + whisper_time = 0.0 + diarization_time = 0.0 + + executor = ThreadPoolExecutor(max_workers=2) + try: + # Submit both tasks + whisper_future = executor.submit( + self._run_whisper, audio_path, transcribe_options + ) + diarization_future = executor.submit( + self._run_diarization, audio_path, diarization_options + ) + + # 
Wait for completion and collect results + for future in as_completed([whisper_future, diarization_future]): + try: + if future is whisper_future: + whisper_result, whisper_time = future.result(timeout=None) + print(f" Whisper completed ({whisper_time:.2f}s)") + else: + diarization_result, diarization_time = future.result(timeout=None) + print(f" Pyannote completed ({diarization_time:.2f}s)") + except KeyboardInterrupt: + print("\n\n Interrupted by user! Cancelling workers...") + whisper_future.cancel() + diarization_future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + raise + except Exception as e: + print(f"\nError in parallel worker: {e}") + whisper_future.cancel() + diarization_future.cancel() + executor.shutdown(wait=False, cancel_futures=True) + raise + except KeyboardInterrupt: + print("\n Transcription cancelled by user") + raise + finally: + executor.shutdown(wait=True) + + if whisper_result is None or diarization_result is None: + raise RuntimeError("Whisper or diarization result missing; parallel execution failed.") + + parallel_time = time.time() - start_time + print(f"\n Total parallel time: {parallel_time:.2f}s") + print( + f" Time saved: {(whisper_time + diarization_time - parallel_time):.2f}s" + ) + + # Merge results deterministically + print("\nMerging Whisper transcription with Pyannote speakers...") + merge_start = time.time() + + result = self._merge_results(whisper_result, diarization_result) + + merge_time = time.time() - merge_start + total_time = time.time() - start_time + + print(f" Merge completed ({merge_time:.2f}s)") + print(f"\nProcessing complete! 
Total time: {total_time:.2f}s") + print( + f" Detected {len(result['speakers'])} speakers: {', '.join(result['speakers'])}" + ) + + # Add timing information + result["timing"] = { + "whisper_time": whisper_time, + "diarization_time": diarization_time, + "parallel_time": parallel_time, + "merge_time": merge_time, + "total_time": total_time, + "time_saved": whisper_time + diarization_time - parallel_time, + } + + return result + + def _run_whisper( + self, audio_path: str, options: Dict + ) -> Tuple[Dict, float]: + """Run Whisper transcription (designed to run in parallel).""" + print(" [Whisper] Starting transcription...") + start_time = time.time() + + # Inference doesn't need gradients + with torch.no_grad(): + result = self.whisper_model.transcribe(audio_path, **options) + + elapsed = time.time() - start_time + return result, elapsed + + def _run_diarization( + self, audio_path: str, options: Dict + ) -> Tuple[Annotation, float]: + """Run Pyannote diarization (designed to run in parallel).""" + print(" [Pyannote] Starting speaker diarization...") + start_time = time.time() + + result = self.diarization_pipeline(audio_path, **options) + + elapsed = time.time() - start_time + return result, elapsed + + def _merge_results( + self, whisper_result: Dict, diarization: Annotation + ) -> Dict: + """Merge Whisper and Pyannote results deterministically.""" + # Extract words with timestamps from Whisper result + word_segments: List[Dict] = [] + + for segment in whisper_result.get("segments", []): + segment_language = segment.get("language", None) + + if "words" not in segment: + # Fallback: if no word timestamps, use segment timestamps + speaker = self._get_speaker_at_timestamp( + diarization, (segment["start"] + segment["end"]) / 2 + ) + word_segments.append( + { + "word": segment["text"], + "start": segment["start"], + "end": segment["end"], + "speaker": speaker, + "probability": segment.get("probability", 1.0), + "language": segment_language, + } + ) + else: + for 
word_info in segment["words"]: + # Get speaker at the middle of the word (deterministic alignment) + word_middle = (word_info["start"] + word_info["end"]) / 2 + speaker = self._get_speaker_at_timestamp(diarization, word_middle) + + word_segments.append( + { + "word": word_info["word"], + "start": word_info["start"], + "end": word_info["end"], + "speaker": speaker, + "probability": word_info.get("probability", 1.0), + "language": segment_language, + } + ) + + # Group consecutive words by the same speaker into segments + speaker_segments = self._group_by_speaker(word_segments) + + # Get unique speakers + speakers = sorted( + set(ws["speaker"] for ws in word_segments if ws["speaker"]) + ) + + return { + "text": whisper_result["text"], + "segments": speaker_segments, + "word_segments": word_segments, + "language": whisper_result.get("language", "unknown"), + "speakers": speakers, + } + + def _get_speaker_at_timestamp( + self, diarization: Annotation, timestamp: float + ) -> Optional[str]: + """Get speaker at a specific timestamp (deterministic).""" + best_speaker = None + best_dist = None + + for segment, _, speaker in diarization.itertracks(yield_label=True): + if segment.start <= timestamp <= segment.end: + center = (segment.start + segment.end) / 2.0 + dist = abs(center - timestamp) + if best_dist is None or dist < best_dist: + best_dist = dist + best_speaker = speaker + + return best_speaker + + def _smart_join(self, prev_text: str, token: str) -> str: + """Smart spacing around punctuation when joining tokens.""" + if not prev_text: + return token + + # Punctuation that should not have space before them + closing_punct = {",", ".", "!", "?", ":", ";", ")", "]", "}", "'"} + # Punctuation that should not have space after them + opening_punct = {"(", "[", "{", "'"} + + # No space before closing punctuation + if token in closing_punct: + return prev_text + token + + # No space after opening punctuation + if prev_text[-1] in opening_punct: + return prev_text + token + + 
# Apostrophes within words + if token == "'" and prev_text and prev_text[-1].isalnum(): + return prev_text + token + + # Default: add a space + return prev_text + " " + token + + def _group_by_speaker(self, word_segments: List[Dict]) -> List[Dict]: + """Group consecutive words by the same speaker into segments.""" + if not word_segments: + return [] + + segments: List[Dict] = [] + current_segment = { + "speaker": word_segments[0]["speaker"], + "start": word_segments[0]["start"], + "end": word_segments[0]["end"], + "text": word_segments[0]["word"], + "words": [word_segments[0]], + } + + for word_info in word_segments[1:]: + if word_info["speaker"] == current_segment["speaker"]: + current_segment["text"] = self._smart_join( + current_segment["text"], word_info["word"] + ) + current_segment["end"] = word_info["end"] + current_segment["words"].append(word_info) + else: + segments.append(current_segment) + current_segment = { + "speaker": word_info["speaker"], + "start": word_info["start"], + "end": word_info["end"], + "text": word_info["word"], + "words": [word_info], + } + + segments.append(current_segment) + return segments + + +def main(): + """Command-line interface for Parallel Whisper + Diarization""" + parser = argparse.ArgumentParser( + description="Parallel transcription with Whisper + Pyannote speaker diarization", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("audio", type=str, help="Path to audio file") + + parser.add_argument( + "--whisper-model", + type=str, + default="turbo", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3", "turbo"], + help="Whisper model size", + ) + parser.add_argument( + "--diarization-model", + type=str, + default="pyannote/speaker-diarization-3.1", + help="HuggingFace diarization model", + ) + parser.add_argument( + "--hf-token", type=str, default=None, help="HuggingFace token for pyannote models" + ) + parser.add_argument( + "--language", + type=str, + 
default=None, + help="Language code (e.g., 'en', 'es', 'fr') or None for auto-detect", + ) + parser.add_argument( + "--num-speakers", type=int, default=None, help="Exact number of speakers (if known)" + ) + parser.add_argument( + "--min-speakers", type=int, default=None, help="Minimum number of speakers" + ) + parser.add_argument( + "--max-speakers", type=int, default=None, help="Maximum number of speakers" + ) + parser.add_argument( + "--multilingual", + action="store_true", + help="Enable multilingual/code-switching mode", + ) + parser.add_argument( + "--output", "-o", type=str, default=None, help="Output file path (JSON format)" + ) + parser.add_argument( + "--device", + type=str, + default=None, + choices=["cuda", "cpu"], + help="Device to use for inference", + ) + + args = parser.parse_args() + + if not os.path.exists(args.audio): + print(f"Error: Audio file not found: {args.audio}") + return + + # Initialize pipeline + try: + pipeline = ParallelWhisperDiarization( + whisper_model=args.whisper_model, + diarization_model=args.diarization_model, + device=args.device, + hf_token=args.hf_token, + ) + except Exception as e: + print(f"\nError initializing pipeline: {e}") + return + + # Transcribe with parallel processing + result = pipeline.transcribe( + args.audio, + language=args.language, + num_speakers=args.num_speakers, + min_speakers=args.min_speakers, + max_speakers=args.max_speakers, + multilingual=args.multilingual, + ) + + # Determine output path + if args.output is None: + audio_path = Path(args.audio) + output_path = audio_path.parent / f"{audio_path.stem}_raw.json" + else: + output_path = Path(args.output) + + # Save raw results as JSON + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, indent=2, ensure_ascii=False) + print(f"\nSaved raw transcript: {output_path}") + + # Print summary + print("\n" + "=" * 60) + print("TRANSCRIPTION SUMMARY") + print("=" * 60) + print(f"Language: {result['language']}") + print(f"Speakers 
detected: {len(result['speakers'])}") + print(f"Speakers: {', '.join(result['speakers'])}") + print(f"\nSegments: {len(result['segments'])}") + + # Print timing information + if "timing" in result: + timing = result["timing"] + print(f"\n{'=' * 60}") + print("PERFORMANCE METRICS") + print(f"{'=' * 60}") + print(f"Whisper time: {timing['whisper_time']:.2f}s") + print(f"Pyannote time: {timing['diarization_time']:.2f}s") + print( + f"Sequential would: {timing['whisper_time'] + timing['diarization_time']:.2f}s" + ) + print(f"Parallel actual: {timing['parallel_time']:.2f}s") + print( + f"Time saved: {timing['time_saved']:.2f}s ({timing['time_saved']/(timing['whisper_time'] + timing['diarization_time'])*100:.1f}%)" + ) + print(f"Merge time: {timing['merge_time']:.2f}s") + print(f"Total time: {timing['total_time']:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/server/queue/asrProcessor.ts b/server/queue/asrProcessor.ts new file mode 100644 index 000000000..f9f44b6fc --- /dev/null +++ b/server/queue/asrProcessor.ts @@ -0,0 +1,327 @@ +import { getLogger } from "@/logger" +import { Subsystem } from "@/types" +import { spawn } from "child_process" +import { promises as fs } from "fs" +import * as path from "path" +import type PgBoss from "pg-boss" +import config from "@/config" +import { + refineTranscript, + mergeConsecutiveSegments, + type TranscriptResult, +} from "@/services/transcriptRefinement" + +const Logger = getLogger(Subsystem.Queue).child({ module: "asrProcessor" }) + +// Paths +const ASR_SD_DIR = path.join(process.cwd(), "asr-sd") +const PYTHON_EXECUTABLE = process.env.PYTHON_PATH || "python3" + +// Helper to format timestamp as HH:MM:SS.mmm +function formatTimestamp(seconds: number): string { + const hours = Math.floor(seconds / 3600) + const minutes = Math.floor((seconds % 3600) / 60) + const secs = seconds % 60 + const secsStr = secs.toFixed(3).padStart(6, "0") // e.g. 
"01.234" + return `${hours.toString().padStart(2, "0")}:${minutes + .toString() + .padStart(2, "0")}:${secsStr}` +} + +export enum ASRJobType { + Transcribe = "transcribe", +} + +export interface TranscribeJobData { + type: ASRJobType.Transcribe + jobId: string + audioUrl: string + audioPath: string + outputPath: string + whisperModel?: string + language?: string + numSpeakers?: number + minSpeakers?: number + maxSpeakers?: number + multilingual?: boolean + refineWithLLM?: boolean + hfToken?: string + outputFormat?: string +} + +export type ASRJobData = TranscribeJobData + +// Helper function to run Python script +async function runPythonScript( + scriptName: string, + args: string[], +): Promise<{ stdout: string; stderr: string; exitCode: number }> { + return new Promise((resolve, reject) => { + const scriptPath = path.join(ASR_SD_DIR, scriptName) + const pythonProcess = spawn(PYTHON_EXECUTABLE, [scriptPath, ...args]) + + let stdout = "" + let stderr = "" + + pythonProcess.stdout.on("data", (data) => { + const output = data.toString() + stdout += output + Logger.info({ output }, "Python stdout") + }) + + pythonProcess.stderr.on("data", (data) => { + const output = data.toString() + stderr += output + Logger.warn({ output }, "Python stderr") + }) + + pythonProcess.on("close", (code, signal) => { + const exitCode = code ?? 
-1 // if null (killed by signal), treat as failure
+      Logger.info({ exitCode, signal }, "Python process exited")
+      resolve({
+        stdout,
+        stderr,
+        exitCode,
+      })
+    })
+
+    pythonProcess.on("error", (error) => {
+      Logger.error({ error }, "Failed to start Python process")
+      reject(error)
+    })
+  })
+}
+
+// Helper function to download file from URL.
+// Throws if the HTTP response is not OK; writes the body to outputPath.
+async function downloadFile(url: string, outputPath: string): Promise<void> {
+  const response = await fetch(url)
+  if (!response.ok) {
+    throw new Error(`Failed to download file: ${response.status} ${response.statusText}`)
+  }
+  const buffer = await response.arrayBuffer()
+  await fs.writeFile(outputPath, Buffer.from(buffer))
+}
+
+// Helper function to convert audio to optimal format for Whisper/Pyannote
+// Target: mono, 16 kHz, 16-bit PCM .wav
+async function convertAudioToOptimalFormat(
+  inputPath: string,
+  outputPath: string
+): Promise<void> {
+  Logger.info(
+    { inputPath, outputPath },
+    "Converting audio to optimal format (mono, 16kHz, 16-bit PCM WAV)",
+  )
+
+  const ffmpegArgs = [
+    "-i",
+    inputPath, // Input file
+    "-ac",
+    "1", // Convert to mono (1 audio channel)
+    "-ar",
+    "16000", // Sample rate: 16 kHz
+    "-sample_fmt",
+    "s16", // 16-bit PCM
+    "-acodec",
+    "pcm_s16le", // PCM 16-bit little-endian codec
+    "-y", // Overwrite output file if exists
+    outputPath,
+  ]
+
+  return new Promise((resolve, reject) => {
+    const ffmpegProcess = spawn("ffmpeg", ffmpegArgs)
+
+    let stderr = ""
+
+    ffmpegProcess.stderr.on("data", (data) => {
+      stderr += data.toString()
+    })
+
+    ffmpegProcess.on("close", (code) => {
+      if (code === 0) {
+        Logger.info({ outputPath }, "Audio conversion completed successfully")
+        resolve()
+      } else {
+        Logger.error({ exitCode: code, stderr }, "FFmpeg conversion failed")
+        reject(new Error(`FFmpeg conversion failed with code ${code}: ${stderr}`))
+      }
+    })
+
+    ffmpegProcess.on("error", (error) => {
+      Logger.error({ error }, "Failed to start FFmpeg process")
+      reject(new Error(`Failed to start FFmpeg: 
${error.message}`))
+    })
+  })
+}
+
+// Handle transcription jobs: download -> convert -> run Python ASR+diarization.
+async function handleTranscribeJob(
+  boss: PgBoss, // currently unused, but kept for future extensions
+  job: PgBoss.Job<TranscribeJobData>,
+): Promise<void> {
+  const data = job.data
+  Logger.info({ jobId: data.jobId }, "Starting transcription job")
+
+  try {
+    // Download audio file
+    Logger.info({ audioUrl: data.audioUrl, audioPath: data.audioPath }, "Downloading audio file")
+    await downloadFile(data.audioUrl, data.audioPath)
+
+    // Convert audio to optimal format (mono, 16kHz, 16-bit PCM WAV)
+    const convertedAudioPath = data.audioPath.replace(/\.[^.]+$/, "_converted.wav")
+    Logger.info({ convertedAudioPath }, "Using converted audio path")
+    await convertAudioToOptimalFormat(data.audioPath, convertedAudioPath)
+
+    // Build command arguments for simplified Python script (ASR + Diarization only)
+    const args = [
+      convertedAudioPath, // Use converted audio
+      "--whisper-model",
+      data.whisperModel || "turbo",
+      "--output",
+      data.outputPath + "_raw.json", // Output raw results
+    ]
+
+    if (data.language) args.push("--language", data.language)
+    if (data.numSpeakers) args.push("--num-speakers", data.numSpeakers.toString())
+    if (data.minSpeakers) args.push("--min-speakers", data.minSpeakers.toString())
+    if (data.maxSpeakers) args.push("--max-speakers", data.maxSpeakers.toString())
+
+    // Default: multilingual ON unless explicitly disabled (but in your pipeline: always true)
+    const multilingual = data.multilingual !== false
+    if (multilingual) {
+      args.push("--multilingual")
+    }
+
+    // Use HF_TOKEN from environment or from job data
+    const hfToken = process.env.HF_TOKEN || data.hfToken
+    if (hfToken) {
+      args.push("--hf-token", hfToken)
+    }
+
+    // Run simplified Python script (ASR + Diarization only)
+    Logger.info({ args }, "Running whisper_diarization.py (ASR + Diarization)")
+    const result = await runPythonScript("whisper_diarization.py", args)
+
+    if (result.exitCode !== 0) {
+      throw new Error(`Transcription failed 
(exit ${result.exitCode}): ${result.stderr}`) + } + + Logger.info( + { jobId: data.jobId, rawOutputPath: data.outputPath + "_raw.json" }, + "ASR + Diarization completed, starting TypeScript post-processing", + ) + + // Read the raw transcript JSON output from Python + const rawJsonPath = data.outputPath + "_raw.json" + const rawTranscriptData = await fs.readFile(rawJsonPath, "utf-8") + const rawTranscript: TranscriptResult = JSON.parse(rawTranscriptData) + + // Decide whether to refine with LLM (default true if undefined) + const shouldRefine = data.refineWithLLM !== false + + let finalTranscript: TranscriptResult = rawTranscript + + if (shouldRefine) { + Logger.info({ jobId: data.jobId }, "Starting LLM refinement in TypeScript") + + // Check which LLM provider is configured + const { defaultBestModel } = config + + if (!defaultBestModel) { + Logger.warn( + { + jobId: data.jobId, + }, + "No LLM provider configured for ASR refinement. Skipping refinement and using raw transcript.", + ) + } else { + try { + // Run TypeScript refinement (works with any configured LLM provider) + finalTranscript = await refineTranscript(rawTranscript, { + maxTokens: 200000, + }) + Logger.info({ jobId: data.jobId }, "LLM refinement completed successfully") + } catch (error) { + Logger.error( + { error, jobId: data.jobId }, + "LLM refinement failed, falling back to raw transcript", + ) + finalTranscript = rawTranscript + } + } + } else { + // Even without LLM refinement, merge consecutive segments deterministically + Logger.info( + { jobId: data.jobId }, + "LLM refinement disabled, merging consecutive segments without refinement", + ) + finalTranscript = { + ...rawTranscript, + segments: mergeConsecutiveSegments(rawTranscript.segments), + } + } + + // Save final results in requested format(s) + const outputFormat = data.outputFormat || "json" + const suffix = shouldRefine ? 
"_refined" : "_merged" + + if (outputFormat === "json" || outputFormat === "all") { + const jsonPath = data.outputPath + suffix + ".json" + await fs.writeFile(jsonPath, JSON.stringify(finalTranscript, null, 2), "utf-8") + Logger.info({ jobId: data.jobId, path: jsonPath }, "Saved JSON output") + } + + if (outputFormat === "txt" || outputFormat === "all") { + const txtPath = data.outputPath + suffix + ".txt" + const txtContent = finalTranscript.segments + .map((seg) => `[${seg.speaker || "UNKNOWN"}] ${seg.text.trim()}`) + .join("\n") + await fs.writeFile(txtPath, txtContent, "utf-8") + Logger.info({ jobId: data.jobId, path: txtPath }, "Saved TXT output") + } + + if (outputFormat === "srt" || outputFormat === "all") { + const srtPath = data.outputPath + suffix + ".srt" + const srtContent = finalTranscript.segments + .map((seg, idx) => { + const start = formatTimestamp(seg.start).replace(".", ",") + const end = formatTimestamp(seg.end).replace(".", ",") + const text = `[${seg.speaker || "UNKNOWN"}] ${seg.text.trim()}` + return `${idx + 1}\n${start} --> ${end}\n${text}\n` + }) + .join("\n") + await fs.writeFile(srtPath, srtContent, "utf-8") + Logger.info({ jobId: data.jobId, path: srtPath }, "Saved SRT output") + } + + Logger.info({ jobId: data.jobId }, "Transcription pipeline completed successfully") + } catch (error) { + Logger.error({ error, jobId: data.jobId }, "Transcription job failed") + throw error + } +} + +// Main job handler +export async function handleASRJob( + boss: PgBoss, + job: PgBoss.Job, +): Promise { + const data = job.data + Logger.info({ jobId: data.jobId, type: data.type }, "Processing ASR job") + + try { + switch (data.type) { + case ASRJobType.Transcribe: + await handleTranscribeJob(boss, job as PgBoss.Job) + break + default: + Logger.error({ jobId: data.jobId, type: data.type }, "Unknown ASR job type") + throw new Error(`Unknown ASR job type: ${data.type}`) + } + + Logger.info({ jobId: data.jobId }, "ASR job completed") + } catch (error) { + 
Logger.error({ error, jobId: data.jobId }, "ASR job failed") + throw error + } +} diff --git a/server/queue/index.ts b/server/queue/index.ts index a51d87925..caf0f14ab 100644 --- a/server/queue/index.ts +++ b/server/queue/index.ts @@ -5,6 +5,7 @@ import { } from "@/integrations/google" import { handleToolSync } from "./toolSync" import { handleAttachmentCleanup } from "./attachmentCleanup" +import { handleASRJob } from "./asrProcessor" import { Subsystem, type SaaSJob } from "@/types" // ConnectorType removed import { ConnectorType, SlackEntity } from "@/shared/types" // ConnectorType added import PgBoss from "pg-boss" @@ -55,6 +56,7 @@ export const SyncSlackPerUserQueue = `sync-${Apps.Slack}-${AuthType.OAuth}-per-u export const SyncSlackSchedulerQueue = `sync-${Apps.Slack}-${AuthType.OAuth}-scheduler` export const SyncToolsQueue = `sync-tools` export const CleanupAttachmentsQueue = `cleanup-attachments` +export const ASRQueue = `asr-processing` const TwiceWeekly = `0 0 * * 0,3` const Every10Minutes = `*/10 * * * *` @@ -81,6 +83,7 @@ export const init = async () => { await boss.createQueue(SyncSlackSchedulerQueue) await boss.createQueue(SyncToolsQueue) await boss.createQueue(CleanupAttachmentsQueue) + await boss.createQueue(ASRQueue) // Process ASR jobs sequentially await initWorkers() } @@ -746,6 +749,18 @@ const initWorkers = async () => { ) } }) + + // ASR Queue Worker - Process ASR jobs sequentially + await boss.work(ASRQueue, async ([job]) => { + Logger.info(`Processing ASR job ${job.id}`) + try { + await handleASRJob(boss, job as any) + Logger.info(`ASR job ${job.id} completed successfully`) + } catch (error) { + Logger.error({ error }, `ASR job ${job.id} failed`) + throw error + } + }) } export const ProgressEvent = "progress-event" diff --git a/server/server.ts b/server/server.ts index 3412fc6d4..23ac9e863 100644 --- a/server/server.ts +++ b/server/server.ts @@ -145,6 +145,12 @@ import { inviteToCallSchema, getCallHistorySchema, } from "@/api/calls" +import 
{ + TranscribeAudioApi, + GetJobStatusApi, + transcribeAudioSchema, + getJobStatusSchema, +} from "@/api/asr" import { SendMessageApi, GetConversationApi, @@ -340,6 +346,8 @@ import { handleAttachmentServe, handleThumbnailServe, handleAttachmentDeleteApi, + handleSimpleFileUpload, + serveASRFile, } from "@/api/files" import { z } from "zod" // Ensure z is imported if not already at the top for schemas import { @@ -1321,6 +1329,8 @@ export const AppRoutes = app ) .post("files/upload", handleFileUpload) .post("/files/upload-attachment", handleAttachmentUpload) + .post("/files/upload-simple", handleSimpleFileUpload) + .get("/files/asr/:filename", serveASRFile) .get("/attachments/:fileId", handleAttachmentServe) .get("/attachments/:fileId/thumbnail", handleThumbnailServe) .post( @@ -1554,6 +1564,17 @@ export const AppRoutes = app zValidator("query", getCallHistorySchema), GetCallHistoryApi, ) + // ASR (Automatic Speech Recognition) routes + .post( + "/asr/transcribe", + zValidator("json", transcribeAudioSchema), + TranscribeAudioApi, + ) + .get( + "/asr/job-status", + zValidator("query", getJobStatusSchema), + GetJobStatusApi, + ) // Direct message routes .post("/messages/send", zValidator("json", sendMessageSchema), SendMessageApi) .get( diff --git a/server/services/transcriptRefinement.ts b/server/services/transcriptRefinement.ts new file mode 100644 index 000000000..4c4ac6102 --- /dev/null +++ b/server/services/transcriptRefinement.ts @@ -0,0 +1,474 @@ +import { getLogger } from "@/logger" +import { Subsystem } from "@/types" +import { getProviderByModel } from "@/ai/provider" +import config from "@/config" +import type { LLMProvider } from "@/ai/types" +import { MessageRole } from "@/types" + +const Logger = getLogger(Subsystem.Queue).child({ module: "transcriptRefinement" }) + +// Types matching Python script output +export interface TranscriptSegment { + speaker: string + text: string + start: number + end: number + words?: Array<{ + word: string + start: number 
+    end: number
+    speaker: string
+    probability: number
+    language?: string
+  }>
+}
+
+export interface TranscriptResult {
+  text: string
+  segments: TranscriptSegment[]
+  word_segments?: Array<{
+    word: string
+    start: number
+    end: number
+    speaker: string
+    probability: number
+    language?: string
+  }>
+  language: string
+  speakers: string[]
+  // Performance metrics emitted by the Python pipeline (seconds per stage).
+  timing?: Record<string, number>
+  refinement_applied?: boolean
+}
+
+interface RefinementOptions {
+  maxTokens?: number
+  customPrompt?: string
+  chunkSize?: number
+}
+
+const DEFAULT_REFINEMENT_PROMPT = `You are a transcript refinement expert used in an automated speech pipeline.
+Your job is to CLEAN the text but NEVER break alignment.
+
+NON-NEGOTIABLE RULES (follow in this exact priority order):
+
+1. DO NOT change the number of segments. If input has N segments, output MUST have N segments.
+2. DO NOT change timestamps. Keep each segment's \`start\` and \`end\` exactly as in the input.
+3. DO NOT merge, split, reorder, or drop segments.
+4. Only change:
+   - \`speaker\`
+   - \`text\`
+   Keep everything else as-is.
+
+REFINEMENT RULES:
+
+1. Speaker assignment:
+   - Use stable, descriptive labels: "Person A", "Person B", "Person C", etc. Assign the same label for the same speaker across all segments in this chunk.
+   - Do NOT invent real names, even if mentioned in the text.
+
+2. Spelling & grammar:
+   - Fix obvious ASR mistakes and casing.
+   - Keep technical terms, product names, code, and IDs exactly if they look intentional.
+
+3. Punctuation:
+   - Add commas, periods, and question marks to make it readable.
+   - Do not add long stylistic rewrites.
+
+4. Multilingual / Hindi-English:
+   - When text is in Hindi or mixed Hindi-English, translate to clear conversational English.
+   - Preserve cultural/intent nuance ("yaar", "acha", "haan") by using lightweight equivalents ("hey", "okay", "yeah") when needed.
+   - If translation is ambiguous, keep the original phrase.
+
+5. 
Filler words: + - Remove only obvious fillers that don't change meaning ("um", "uh", "like" at the start). + - Keep hesitations that show intent ("I… I don't know", "let me think"). + +OUTPUT FORMAT: + +- Return ONLY a JSON array. +- Each item MUST have exactly these keys: \`speaker\`, \`text\`, \`start\`, \`end\`. +- \`start\` and \`end\` MUST be the original numeric values from input. +- Do NOT wrap the JSON in markdown fences. +- Do NOT add explanations, comments, or metadata.` + +/** + * Heuristically fix UNKNOWN "bridge" segments: + * Pattern: Speaker X -> UNKNOWN (very short) -> Speaker X + * We assume the UNKNOWN segment actually belongs to Speaker X. + */ +export function normalizeUnknownBridgeSegments( + segments: TranscriptSegment[] +): TranscriptSegment[] { + if (!segments || segments.length === 0) return [] + + // Shallow clone so we don't mutate the original array from the caller. + const normalized = segments.map(seg => ({ + ...seg, + text: seg.text ?? "", + words: seg.words ?? [], + })) + + for (let i = 0; i < normalized.length - 2; i++) { + const prev = normalized[i] + const mid = normalized[i + 1] + const next = normalized[i + 2] + + if (!prev || !mid || !next) continue + + const isUnknown = + !mid.speaker || + mid.speaker.toUpperCase() === "UNKNOWN" || + mid.speaker.toUpperCase().startsWith("SPEAKER_") + + const sameSpeaker = prev.speaker && prev.speaker === next.speaker + + if (!isUnknown || !sameSpeaker) continue + + const midText = (mid.text || "").trim() + const midWordCount = midText.length ? midText.split(/\s+/).filter(Boolean).length : 0 + const midDuration = (mid.end ?? 0) - (mid.start ?? 
0)
+
+    // Only treat as a bridge if it's very short (e.g., a 1–2 word interjection)
+    if (midWordCount <= 2 && midDuration <= 1.5) {
+      mid.speaker = prev.speaker
+    }
+  }
+
+  return normalized
+}
+
+/**
+ * Merge consecutive segments from the same speaker.
+ * Does not mutate the caller's segments: word arrays are copied before merging.
+ */
+export function mergeConsecutiveSegments(segments: TranscriptSegment[]): TranscriptSegment[] {
+  if (!segments || segments.length === 0) {
+    return []
+  }
+
+  const merged: TranscriptSegment[] = []
+  // Copy the words array so the push(...) below never mutates the input segment.
+  let current: TranscriptSegment = {
+    speaker: segments[0].speaker,
+    text: segments[0].text?.trim() || "",
+    start: segments[0].start,
+    end: segments[0].end,
+    words: segments[0].words ? [...segments[0].words] : [],
+  }
+
+  for (let i = 1; i < segments.length; i++) {
+    const segment = segments[i]
+
+    if (segment.speaker === current.speaker) {
+      // Same speaker - merge
+      const newText = segment.text?.trim() || ""
+      if (current.text && newText) {
+        current.text += " " + newText
+      } else if (newText) {
+        current.text = newText
+      }
+
+      current.end = segment.end
+      if (current.words && segment.words) {
+        current.words.push(...segment.words)
+      }
+    } else {
+      // Different speaker - save current and start new (copy words, see above)
+      merged.push(current)
+      current = {
+        speaker: segment.speaker,
+        text: segment.text?.trim() || "",
+        start: segment.start,
+        end: segment.end,
+        words: segment.words ? [...segment.words] : [],
+      }
+    }
+  }
+
+  // Don't forget the last segment
+  merged.push(current)
+
+  Logger.info(
+    `Merged ${segments.length} segments into ${merged.length} (reduction: ${
+      segments.length - merged.length
+    } segments, ${(((segments.length - merged.length) / segments.length) * 100).toFixed(1)}%)`
+  )
+
+  return merged
+}
+
+/**
+ * Estimate token count for a string (rough estimate: 4 chars per token)
+ */
+function estimateTokens(text: string): number {
+  return Math.ceil(text.length / 4)
+}
+
+/**
+ * Create smart chunks that respect segment boundaries and token limits
+ */
+export function createSmartChunks(
+  segments: TranscriptSegment[],
+  maxTokens: number = 200000
+): 
TranscriptSegment[][] { + const chunks: TranscriptSegment[][] = [] + let currentChunk: TranscriptSegment[] = [] + let currentTokens = 0 + + // Reserve tokens for system prompt and formatting overhead (~2000 tokens) + const usableTokens = Math.max(maxTokens - 2000, 1000) // ensure some sane minimum + + for (const segment of segments) { + // Estimate tokens for this segment + const segmentText = `${segment.speaker || "UNKNOWN"}: ${segment.text || ""}` + const segmentTokens = estimateTokens(segmentText) + 50 // +50 for JSON overhead + + // Check if adding this segment would exceed limit + if (currentTokens + segmentTokens > usableTokens && currentChunk.length > 0) { + // Current chunk is full, start new chunk + chunks.push(currentChunk) + currentChunk = [segment] + currentTokens = segmentTokens + } else { + // Add segment to current chunk + currentChunk.push(segment) + currentTokens += segmentTokens + } + } + + // Don't forget the last chunk + if (currentChunk.length > 0) { + chunks.push(currentChunk) + } + + Logger.info(`Created ${chunks.length} chunks from ${segments.length} segments`) + + return chunks +} + +/** + * Format segments for LLM input + */ +function formatSegmentsForLLM(segments: TranscriptSegment[]): string { + const lines: string[] = [] + for (const seg of segments) { + const speaker = seg.speaker || "UNKNOWN" + const text = seg.text?.trim() || "" + const start = seg.start || 0 + const end = seg.end || 0 + lines.push(`[${speaker}] (${start.toFixed(2)}-${end.toFixed(2)}s): ${text}`) + } + return lines.join("\n") +} + +/** + * Validate and merge refined segments with original timestamps + * LLM is NOT allowed to change structure: same length, same timestamps. 
+ */
+function validateAndMergeSegments(
+  original: TranscriptSegment[],
+  refined: unknown
+): TranscriptSegment[] {
+  if (!Array.isArray(refined)) {
+    Logger.warn("Refined output is not an array, using original segments")
+    return original
+  }
+
+  if (refined.length !== original.length) {
+    Logger.warn(
+      `Segment count mismatch (original: ${original.length}, refined: ${refined.length}), using original segments`
+    )
+    return original
+  }
+
+  const merged: TranscriptSegment[] = []
+
+  for (let i = 0; i < original.length; i++) {
+    const orig = original[i]
+    const ref = refined[i]
+
+    if (typeof ref !== "object" || ref === null) {
+      merged.push(orig)
+      continue
+    }
+
+    const refObj = ref as Record<string, unknown>
+    const speaker = typeof refObj.speaker === "string" ? refObj.speaker : orig.speaker
+    const text = typeof refObj.text === "string" ? refObj.text : orig.text
+
+    merged.push({
+      speaker,
+      text,
+      start: orig.start, // Always preserve original timestamps
+      end: orig.end,
+      words: orig.words,
+    })
+  }
+
+  return merged
+}
+
+/**
+ * Update word-level segments with refined speaker labels
+ */
+function updateWordSegments(
+  wordSegments: TranscriptSegment["words"],
+  refinedSegments: TranscriptSegment[]
+): TranscriptSegment["words"] {
+  if (!wordSegments || wordSegments.length === 0) {
+    return []
+  }
+
+  const updated = [...wordSegments]
+  const starts = refinedSegments.map(seg => seg.start)
+
+  for (const word of updated) {
+    const mid = (word.start + word.end) / 2
+
+    // Find the segment this word belongs to using a simple forward scan
+    let matchIndex = 0
+    for (let i = 0; i < starts.length; i++) {
+      if (starts[i] <= mid) {
+        matchIndex = i
+      } else {
+        break
+      }
+    }
+
+    // Check neighborhood (±1) for better match
+    let match: TranscriptSegment | null = null
+    for (const offset of [-1, 0, 1]) {
+      const idx = matchIndex + offset
+      if (idx >= 0 && idx < refinedSegments.length) {
+        const seg = refinedSegments[idx]
+        if (seg.start <= mid && mid <= seg.end) {
+          match = seg 
+ break
+        }
+      }
+    }
+
+    if (match) {
+      word.speaker = match.speaker
+    }
+  }
+
+  return updated
+}
+
+/**
+ * Refine a single chunk of segments using LLM.
+ * Falls back to the original segments on any parse/validation failure.
+ */
+async function refineChunk(
+  provider: LLMProvider,
+  segments: TranscriptSegment[],
+  customPrompt?: string
+): Promise<TranscriptSegment[]> {
+  const transcriptText = formatSegmentsForLLM(segments)
+  const systemPrompt = customPrompt || DEFAULT_REFINEMENT_PROMPT
+
+  const userPrompt = `Refine this transcript chunk:
+
+${transcriptText}
+
+Return ONLY a valid JSON array of segments. Each segment must have: speaker, text, start, end.
+Do not include any explanation or markdown formatting, just the JSON array.`
+
+  try {
+    const { defaultBestModel } = config
+    const response = await provider.converse(
+      [
+        {
+          role: MessageRole.User,
+          content: [{ text: userPrompt }],
+        },
+      ],
+      {
+        systemPrompt: systemPrompt,
+        max_new_tokens: 8000,
+        temperature: 0.3,
+        modelId: defaultBestModel,
+        stream: false,
+      }
+    )
+
+    let responseText = (response.text || "").trim()
+
+    // Remove markdown code blocks if present
+    if (responseText.startsWith("```")) {
+      const lines = responseText.split("\n")
+      if (lines.length >= 2 && lines[lines.length - 1].trim().startsWith("```")) {
+        responseText = lines.slice(1, -1).join("\n")
+      }
+      if (responseText.trim().toLowerCase().startsWith("json")) {
+        responseText = responseText.slice(4).trim()
+      }
+    }
+
+    const refinedSegments = JSON.parse(responseText)
+    return validateAndMergeSegments(segments, refinedSegments)
+  } catch (error) {
+    Logger.warn({ error }, "LLM refinement failed for chunk, using original segments")
+    return segments
+  }
+}
+
+/**
+ * Refine entire transcript with LLM chunking and processing
+ */
+export async function refineTranscript(
+  result: TranscriptResult,
+  options: RefinementOptions = {}
+): Promise<TranscriptResult> {
+  const { maxTokens = 200000, customPrompt } = options
+
+  Logger.info("Starting transcript refinement with LLM")
+
+  const segments = result.segments
+  const wordSegments = 
result.word_segments || [] + + // STEP 1: Normalize UNKNOWN bridge segments deterministically + Logger.info("Normalizing UNKNOWN bridge segments...") + const normalizedSegments = normalizeUnknownBridgeSegments(segments) + + // STEP 2: Merge consecutive segments from same speaker BEFORE LLM refinement + Logger.info("Merging consecutive speaker segments...") + const mergedInputSegments = mergeConsecutiveSegments(normalizedSegments) + + // STEP 3: Create chunks based on token limit + Logger.info(`Creating chunks (max ${maxTokens} tokens per chunk)...`) + const chunks = createSmartChunks(mergedInputSegments, maxTokens) + + // STEP 4: Get LLM provider + const { defaultBestModel } = config + const provider = getProviderByModel(defaultBestModel) + + // STEP 5: Process each chunk for LLM refinement + const refinedSegments: TranscriptSegment[] = [] + + for (let i = 0; i < chunks.length; i++) { + const chunk = chunks[i] + Logger.info(`Processing chunk ${i + 1}/${chunks.length} (${chunk.length} segments)...`) + + const refinedChunk = await refineChunk(provider, chunk, customPrompt) + refinedSegments.push(...refinedChunk) + } + + // STEP 6: Update word segments + const refinedWordSegments = updateWordSegments(wordSegments, refinedSegments) + + // STEP 7: Get unique speakers + const speakers = Array.from( + new Set(refinedSegments.map(seg => seg.speaker).filter(Boolean)) + ).sort() + + Logger.info(`Refinement complete! Detected speakers: ${speakers.join(", ")}`) + Logger.info(`Final segments: ${refinedSegments.length}`) + + return { + text: result.text, + segments: refinedSegments, + word_segments: refinedWordSegments, + language: result.language, + speakers, + refinement_applied: true, + timing: result.timing, + } +}