Skip to content

Commit 3f69ed3

Browse files
committed
feat: add stream stats endpoint
This commit adds a new stream stats endpoint which can be used to retrieve the fps metrics in a way that doesn't affect performance.
1 parent 89a7e21 commit 3f69ed3

File tree

2 files changed

+154
-9
lines changed

2 files changed

+154
-9
lines changed

server/app.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
RTCConfiguration,
1313
RTCIceServer,
1414
MediaStreamTrack,
15-
RTCDataChannel,
1615
)
16+
import threading
17+
import av
1718
from aiortc.rtcrtpsender import RTCRtpSender
1819
from aiortc.codecs import h264
1920
from pipeline import Pipeline
20-
from utils import patch_loop_datagram
21+
from utils import patch_loop_datagram, StreamStats
22+
import time
2123

2224
logger = logging.getLogger(__name__)
2325
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
@@ -29,16 +31,82 @@
2931

3032

3133
class VideoStreamTrack(MediaStreamTrack):
34+
"""video stream track that processes video frames using a pipeline.
35+
36+
Attributes:
37+
kind (str): The kind of media, which is "video" for this class.
38+
track (MediaStreamTrack): The underlying media stream track.
39+
pipeline (Pipeline): The processing pipeline to apply to each video frame.
40+
"""
41+
3242
kind = "video"
3343

34-
def __init__(self, track: MediaStreamTrack, pipeline):
44+
def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
45+
"""Initialize the VideoStreamTrack.
46+
47+
Args:
48+
track: The underlying media stream track.
49+
pipeline: The processing pipeline to apply to each video frame.
50+
"""
3551
super().__init__()
3652
self.track = track
3753
self.pipeline = pipeline
38-
39-
async def recv(self):
40-
frame = await self.track.recv()
41-
return await self.pipeline(frame)
54+
self._frame_count = 0
55+
self._start_time = time.monotonic()
56+
self._lock = threading.Lock()
57+
self._fps = 0.0
58+
self._running = True
59+
self._start_fps_thread()
60+
61+
def _start_fps_thread(self):
62+
"""Start a separate thread to calculate FPS periodically."""
63+
self.fps_thread = threading.Thread(target=self._calculate_fps_loop, daemon=True)
64+
self.fps_thread.start()
65+
66+
def _calculate_fps_loop(self):
67+
"""Loop to calculate FPS periodically."""
68+
while self._running:
69+
time.sleep(1) # Calculate FPS every second.
70+
with self._lock:
71+
current_time = time.monotonic()
72+
time_diff = current_time - self._start_time
73+
if time_diff > 0:
74+
self._fps = self._frame_count / time_diff
75+
76+
# Reset start_time and frame_count for the next interval.
77+
self._start_time = current_time
78+
self._frame_count = 0
79+
80+
def stop(self):
81+
"""Stop the FPS calculation thread."""
82+
self._running = False
83+
self.fps_thread.join()
84+
85+
@property
86+
def fps(self) -> float:
87+
"""Get the current output frames per second (FPS).
88+
89+
Returns:
90+
The current output FPS.
91+
"""
92+
with self._lock:
93+
return self._fps
94+
95+
async def recv(self) -> av.VideoFrame:
96+
"""Receive and process a video frame. Called by the WebRTC library when a frame
97+
is received.
98+
99+
Returns:
100+
The processed video frame.
101+
"""
102+
input_frame = await self.track.recv()
103+
processed_frame = await self.pipeline(input_frame)
104+
105+
# Increment frame count for FPS calculation.
106+
with self._lock:
107+
self._frame_count += 1
108+
109+
return processed_frame
42110

43111

44112
def force_codec(pc, sender, forced_codec):
@@ -156,6 +224,10 @@ def on_track(track):
156224
tracks["video"] = videoTrack
157225
sender = pc.addTrack(videoTrack)
158226

227+
# Store video track in app for stats.
228+
stream_id = track.id
229+
request.app["video_tracks"][stream_id] = videoTrack
230+
159231
codec = "video/H264"
160232
force_codec(pc, sender, codec)
161233

@@ -207,6 +279,7 @@ async def on_startup(app: web.Application):
207279
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
208280
)
209281
app["pcs"] = set()
282+
app["video_tracks"] = {}
210283

211284

212285
async def on_shutdown(app: web.Application):
@@ -251,4 +324,9 @@ async def on_shutdown(app: web.Application):
251324
app.router.add_post("/prompt", set_prompt)
252325
app.router.add_get("/", health)
253326

327+
# Add routes for getting stream statistics.
328+
stream_stats = StreamStats(app)
329+
app.router.add_get("/stats", stream_stats.get_stats)
330+
app.router.add_get("/stats/{stream_id}", stream_stats.get_stats_by_id)
331+
254332
web.run_app(app, host=args.host, port=int(args.port))

server/utils.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
"""Utility functions for the server."""
2+
13
import asyncio
24
import random
35
import types
46
import logging
5-
6-
from typing import List, Tuple
7+
import json
8+
from aiohttp import web
9+
from aiortc import MediaStreamTrack
10+
from typing import List, Tuple, Any, Dict
711

812
logger = logging.getLogger(__name__)
913

@@ -48,3 +52,66 @@ async def create_datagram_endpoint(
4852

4953
loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop)
5054
loop._patch_done = True
55+
56+
57+
class StreamStats:
58+
"""Class to get stream statistics."""
59+
60+
def __init__(self, app: web.Application):
61+
"""Initialize the StreamStats class."""
62+
self._app = app
63+
64+
def get_video_track_stats(self, video_track: MediaStreamTrack) -> Dict[str, Any]:
65+
"""Get statistics for a video track.
66+
67+
Args:
68+
video_track: The VideoStreamTrack instance.
69+
70+
Returns:
71+
A dictionary containing the statistics.
72+
"""
73+
return {
74+
"fps": video_track.fps,
75+
}
76+
77+
async def get_stats(self, _) -> web.Response:
78+
"""Get the current stream statistics for all streams.
79+
80+
Args:
81+
request: The HTTP GET request.
82+
83+
Returns:
84+
The HTTP response containing the statistics.
85+
"""
86+
video_tracks = self._app.get("video_tracks", {})
87+
all_stats = {
88+
stream_id: self.get_video_track_stats(track)
89+
for stream_id, track in video_tracks.items()
90+
}
91+
92+
return web.Response(
93+
content_type="application/json",
94+
text=json.dumps(all_stats),
95+
)
96+
97+
async def get_stats_by_id(self, request: web.Request) -> web.Response:
98+
"""Get the statistics for a specific stream by ID.
99+
100+
Args:
101+
request: The HTTP GET request.
102+
103+
Returns:
104+
The HTTP response containing the statistics.
105+
"""
106+
stream_id = request.match_info.get("stream_id")
107+
video_track = self._app["video_tracks"].get(stream_id)
108+
109+
if video_track:
110+
stats = self.get_video_track_stats(video_track)
111+
else:
112+
stats = {"error": "Stream not found"}
113+
114+
return web.Response(
115+
content_type="application/json",
116+
text=json.dumps(stats),
117+
)

0 commit comments

Comments
 (0)