diff --git a/server/app.py b/server/app.py index 67bf549f..e15343f5 100644 --- a/server/app.py +++ b/server/app.py @@ -3,6 +3,7 @@ import os import json import logging +from collections import deque import sys import torch @@ -12,13 +13,6 @@ torch.cuda.init() -import torch - -# Initialize CUDA before any other imports to prevent core dump. -if torch.cuda.is_available(): - torch.cuda.init() - - from twilio.rest import Client from aiohttp import web from aiortc import ( @@ -27,12 +21,12 @@ RTCConfiguration, RTCIceServer, MediaStreamTrack, - RTCDataChannel, ) from aiortc.rtcrtpsender import RTCRtpSender from aiortc.codecs import h264 from pipeline import Pipeline -from utils import patch_loop_datagram +from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes +import time logger = logging.getLogger(__name__) logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING) @@ -44,13 +38,38 @@ class VideoStreamTrack(MediaStreamTrack): + """video stream track that processes video frames using a pipeline. + + Attributes: + kind (str): The kind of media, which is "video" for this class. + track (MediaStreamTrack): The underlying media stream track. + pipeline (Pipeline): The processing pipeline to apply to each video frame. + """ kind = "video" - def __init__(self, track: MediaStreamTrack, pipeline): + def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): + """Initialize the VideoStreamTrack. + + Args: + track: The underlying media stream track. + pipeline: The processing pipeline to apply to each video frame. + """ super().__init__() self.track = track self.pipeline = pipeline + + self._lock = asyncio.Lock() + self._fps_interval_frame_count = 0 + self._last_fps_calculation_time = None + self._fps_loop_start_time = time.monotonic() + self._fps = 0.0 + self._fps_measurements = deque(maxlen=60) + self._running_event = asyncio.Event() + asyncio.create_task(self.collect_frames()) + # Start metrics collection tasks. + self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop()) + async def collect_frames(self): while True: try: @@ -60,9 +79,83 @@ async def collect_frames(self): await self.pipeline.cleanup() raise Exception(f"Error collecting video frames: {str(e)}") + async def _calculate_fps_loop(self): + """Loop to calculate FPS periodically.""" + await self._running_event.wait() + self._fps_loop_start_time = time.monotonic() + while self.readyState != "ended": + async with self._lock: + current_time = time.monotonic() + if self._last_fps_calculation_time is not None: + time_diff = current_time - self._last_fps_calculation_time + self._fps = self._fps_interval_frame_count / time_diff + self._fps_measurements.append( + { + "timestamp": current_time - self._fps_loop_start_time, + "fps": self._fps, + } + ) # Store the FPS measurement with timestamp + + # Reset start_time and frame_count for the next interval. + self._last_fps_calculation_time = current_time + self._fps_interval_frame_count = 0 + await asyncio.sleep(1) # Calculate FPS every second. + + @property + async def fps(self) -> float: + """Get the current output frames per second (FPS). + + Returns: + The current output FPS. + """ + async with self._lock: + return self._fps + + @property + async def fps_measurements(self) -> list: + """Get the array of FPS measurements for the last minute. + + Returns: + The array of FPS measurements for the last minute. + """ + async with self._lock: + return list(self._fps_measurements) + + @property + async def average_fps(self) -> float: + """Calculate the average FPS from the measurements taken in the last minute. + + Returns: + The average FPS over the last minute. + """ + async with self._lock: + if not self._fps_measurements: + return 0.0 + return sum( + measurement["fps"] for measurement in self._fps_measurements + ) / len(self._fps_measurements) + + @property + async def last_fps_calculation_time(self) -> float: + """Get the elapsed time since the last FPS calculation. + + Returns: + The elapsed time in seconds since the last FPS calculation. + """ + async with self._lock: + return self._last_fps_calculation_time - self._fps_loop_start_time + async def recv(self): - return await self.pipeline.get_processed_video_frame() - + processed_frame = await self.pipeline.get_processed_video_frame() + + # Increment frame count for FPS calculation. + async with self._lock: + self._fps_interval_frame_count += 1 + if not self._running_event.is_set(): + self._running_event.set() + + return processed_frame + class AudioStreamTrack(MediaStreamTrack): kind = "audio" @@ -168,7 +261,7 @@ def on_datachannel(channel): async def on_message(message): try: params = json.loads(message) - + if params.get("type") == "get_nodes": nodes_info = await pipeline.get_nodes_info() response = { @@ -201,6 +294,10 @@ def on_track(track): tracks["video"] = videoTrack sender = pc.addTrack(videoTrack) + # Store video track in app for stats. + stream_id = track.id + request.app["video_tracks"][stream_id] = videoTrack + codec = "video/H264" force_codec(pc, sender, codec) elif track.kind == "audio": @@ -211,6 +308,7 @@ def on_track(track): @track.on("ended") async def on_ended(): logger.info(f"{track.kind} track ended") + request.app["video_tracks"].pop(track.id, None) @pc.on("connectionstatechange") async def on_connectionstatechange(): @@ -261,6 +359,7 @@ async def on_startup(app: web.Application): cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True ) app["pcs"] = set() + app["video_tracks"] = {} async def on_shutdown(app: web.Application): @@ -301,11 +400,24 @@ async def on_shutdown(app: web.Application): app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) - app.router.add_post("/offer", offer) - app.router.add_post("/prompt", set_prompt) app.router.add_get("/", health) app.router.add_get("/health", health) + # WebRTC signalling and control routes. + app.router.add_post("/offer", offer) + app.router.add_post("/prompt", set_prompt) + + # Add routes for getting stream statistics. + stream_stats = StreamStats(app) + app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics) + app.router.add_get( + "/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id + ) + + # Add hosted platform route prefix. + # NOTE: This ensures that the local and hosted experiences have consistent routes. + add_prefix_to_app_routes(app, "/live") + def force_print(*args, **kwargs): print(*args, **kwargs, flush=True) sys.stdout.flush() diff --git a/server/utils.py b/server/utils.py index db263f88..106b607b 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,9 +1,12 @@ +"""Utility functions for the server.""" import asyncio import random import types import logging - -from typing import List, Tuple +import json +from aiohttp import web +from aiortc import MediaStreamTrack +from typing import List, Tuple, Any, Dict logger = logging.getLogger(__name__) @@ -48,3 +51,84 @@ async def create_datagram_endpoint( loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop) loop._patch_done = True + + +def add_prefix_to_app_routes(app: web.Application, prefix: str): + """Add a prefix to all routes in the given application. + + Args: + app: The web application whose routes will be prefixed. + prefix: The prefix to add to all routes. + """ + prefix = prefix.rstrip("/") + for route in list(app.router.routes()): + new_path = prefix + route.resource.canonical + app.router.add_route(route.method, new_path, route.handler) + + +class StreamStats: + """Handles real-time video stream statistics collection.""" + + def __init__(self, app: web.Application): + """Initializes the StreamMetrics class. + + Args: + app: The web application instance storing video streams under the + "video_tracks" key. + """ + self._app = app + + async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]: + """Collects real-time statistics for a video track. + + Args: + video_track: The video stream track instance. + + Returns: + A dictionary containing FPS-related statistics. + """ + return { + "timestamp": await video_track.last_fps_calculation_time, + "fps": await video_track.fps, + "minute_avg_fps": await video_track.average_fps, + "minute_fps_array": await video_track.fps_measurements, + } + + async def collect_all_stream_metrics(self, _) -> web.Response: + """Retrieves real-time metrics for all active video streams. + + Returns: + A JSON response containing FPS statistics for all streams. + """ + video_tracks = self._app.get("video_tracks", {}) + all_stats = { + stream_id: await self.collect_video_metrics(track) + for stream_id, track in video_tracks.items() + } + + return web.Response( + content_type="application/json", + text=json.dumps(all_stats), + ) + + async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response: + """Retrieves real-time metrics for a specific video stream by ID. + + Args: + request: The HTTP request containing the stream ID. + + Returns: + A JSON response with stream metrics or an error message. + """ + stream_id = request.match_info.get("stream_id") + video_track = self._app["video_tracks"].get(stream_id) + + if video_track: + stats = await self.collect_video_metrics(video_track) + else: + stats = {"error": "Stream not found"} + + return web.Response( + content_type="application/json", + text=json.dumps(stats), + )