 import os
 import json
 import logging
+from collections import deque
 import sys
 
 import torch
 
 # Initialize CUDA before any other imports to prevent core dump.
 if torch.cuda.is_available():
     torch.cuda.init()
 
 
-import torch
-
-# Initialize CUDA before any other imports to prevent core dump.
-if torch.cuda.is_available():
-    torch.cuda.init()
-
-
 from twilio.rest import Client
 from aiohttp import web
 from aiortc import (
     RTCConfiguration,
     RTCIceServer,
     MediaStreamTrack,
-    RTCDataChannel,
 )
 from aiortc.rtcrtpsender import RTCRtpSender
 from aiortc.codecs import h264
 from pipeline import Pipeline
-from utils import patch_loop_datagram
+from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
+import time
 
 logger = logging.getLogger(__name__)
 logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
 
 
 class VideoStreamTrack(MediaStreamTrack):
+    """Video stream track that processes video frames using a pipeline.
+
+    Attributes:
+        kind (str): The kind of media, which is "video" for this class.
+        track (MediaStreamTrack): The underlying media stream track.
+        pipeline (Pipeline): The processing pipeline to apply to each video frame.
+    """
     kind = "video"
-    def __init__(self, track: MediaStreamTrack, pipeline):
+    def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
+        """Initialize the VideoStreamTrack.
+
+        Args:
+            track: The underlying media stream track.
+            pipeline: The processing pipeline to apply to each video frame.
+        """
         super().__init__()
         self.track = track
         self.pipeline = pipeline
+
+        self._lock = asyncio.Lock()
+        self._fps_interval_frame_count = 0
+        self._last_fps_calculation_time = None
+        self._fps_loop_start_time = time.monotonic()
+        self._fps = 0.0
+        self._fps_measurements = deque(maxlen=60)
+        self._running_event = asyncio.Event()
+
         asyncio.create_task(self.collect_frames())
 
+        # Start metrics collection tasks.
+        self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop())
+
     async def collect_frames(self):
         while True:
             try:
@@ -60,9 +79,83 @@ async def collect_frames(self):
                 await self.pipeline.cleanup()
                 raise Exception(f"Error collecting video frames: {str(e)}")
 
+    async def _calculate_fps_loop(self):
+        """Loop to calculate FPS periodically."""
+        await self._running_event.wait()
+        self._fps_loop_start_time = time.monotonic()
+        while self.readyState != "ended":
+            async with self._lock:
+                current_time = time.monotonic()
+                if self._last_fps_calculation_time is not None:
+                    time_diff = current_time - self._last_fps_calculation_time
+                    self._fps = self._fps_interval_frame_count / time_diff
+                    self._fps_measurements.append(
+                        {
+                            "timestamp": current_time - self._fps_loop_start_time,
+                            "fps": self._fps,
+                        }
+                    )  # Store the FPS measurement with timestamp
+
+                # Reset start_time and frame_count for the next interval.
+                self._last_fps_calculation_time = current_time
+                self._fps_interval_frame_count = 0
+            await asyncio.sleep(1)  # Calculate FPS every second.
+
+    @property
+    async def fps(self) -> float:
+        """Get the current output frames per second (FPS).
+
+        Returns:
+            The current output FPS.
+        """
+        async with self._lock:
+            return self._fps
+
+    @property
+    async def fps_measurements(self) -> list:
+        """Get the array of FPS measurements for the last minute.
+
+        Returns:
+            The array of FPS measurements for the last minute.
+        """
+        async with self._lock:
+            return list(self._fps_measurements)
+
+    @property
+    async def average_fps(self) -> float:
+        """Calculate the average FPS from the measurements taken in the last minute.
+
+        Returns:
+            The average FPS over the last minute.
+        """
+        async with self._lock:
+            if not self._fps_measurements:
+                return 0.0
+            return sum(
+                measurement["fps"] for measurement in self._fps_measurements
+            ) / len(self._fps_measurements)
+
+    @property
+    async def last_fps_calculation_time(self) -> float:
+        """Get the elapsed time since the last FPS calculation.
+
+        Returns:
+            The elapsed time in seconds since the last FPS calculation.
+        """
+        async with self._lock:
+            return self._last_fps_calculation_time - self._fps_loop_start_time
+
     async def recv(self):
-        return await self.pipeline.get_processed_video_frame()
-
+        processed_frame = await self.pipeline.get_processed_video_frame()
+
+        # Increment frame count for FPS calculation.
+        async with self._lock:
+            self._fps_interval_frame_count += 1
+            if not self._running_event.is_set():
+                self._running_event.set()
+
+        return processed_frame
+
 
 class AudioStreamTrack(MediaStreamTrack):
     kind = "audio"
@@ -168,7 +261,7 @@ def on_datachannel(channel):
         async def on_message(message):
             try:
                 params = json.loads(message)
-
+
                 if params.get("type") == "get_nodes":
                     nodes_info = await pipeline.get_nodes_info()
                     response = {
@@ -201,6 +294,10 @@ def on_track(track):
             tracks["video"] = videoTrack
             sender = pc.addTrack(videoTrack)
 
+            # Store video track in app for stats.
+            stream_id = track.id
+            request.app["video_tracks"][stream_id] = videoTrack
+
             codec = "video/H264"
             force_codec(pc, sender, codec)
         elif track.kind == "audio":
@@ -211,6 +308,7 @@ def on_track(track):
         @track.on("ended")
         async def on_ended():
             logger.info(f"{track.kind} track ended")
+            request.app["video_tracks"].pop(track.id, None)
 
     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
@@ -261,6 +359,7 @@ async def on_startup(app: web.Application):
         cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
     )
     app["pcs"] = set()
+    app["video_tracks"] = {}
 
 
 async def on_shutdown(app: web.Application):
@@ -301,11 +400,24 @@ async def on_shutdown(app: web.Application):
     app.on_startup.append(on_startup)
     app.on_shutdown.append(on_shutdown)
 
-    app.router.add_post("/offer", offer)
-    app.router.add_post("/prompt", set_prompt)
     app.router.add_get("/", health)
     app.router.add_get("/health", health)
 
+    # WebRTC signalling and control routes.
+    app.router.add_post("/offer", offer)
+    app.router.add_post("/prompt", set_prompt)
+
+    # Add routes for getting stream statistics.
+    stream_stats = StreamStats(app)
+    app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
+    app.router.add_get(
+        "/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
+    )
+
+    # Add hosted platform route prefix.
+    # NOTE: This ensures that the local and hosted experiences have consistent routes.
+    add_prefix_to_app_routes(app, "/live")
+
     def force_print(*args, **kwargs):
         print(*args, **kwargs, flush=True)
         sys.stdout.flush()
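
As a quick usage sketch, the stats routes added above could be queried like this once the server is running, assuming it listens on localhost:8889 and that add_prefix_to_app_routes() rewrites the registered paths so they are served under /live; the host, port, and stream id below are placeholders, and the response payload depends on StreamStats, which is not shown in this change.

import asyncio

import aiohttp


async def main():
    # Assumed host/port; adjust to wherever the server is actually bound.
    base_url = "http://localhost:8889"
    async with aiohttp.ClientSession() as session:
        # Aggregate metrics for all active streams.
        async with session.get(f"{base_url}/live/streams/stats") as resp:
            print(resp.status, await resp.json())

        # Per-stream metrics; the stream id is the WebRTC video track id that
        # the offer handler stores in app["video_tracks"] (placeholder value here).
        stream_id = "example-stream-id"
        async with session.get(f"{base_url}/live/stream/{stream_id}/stats") as resp:
            print(resp.status, await resp.json())


asyncio.run(main())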