Skip to content

Commit 838e15f

Browse files
authored
Add source-aware progress reporting to CLI track command (#242)
* Add source-aware progress reporting to CLI track command Display a unified live progress line during frame processing that adapts to the source type. Bounded sources (video files, image directories) show a progress bar with percentage and ETA; unbounded sources (webcams, RTSP streams) show a frame counter with elapsed time. Final state renders as completed, interrupted, or source-lost depending on exit path. * fix(progress): change __exit__ return type from bool to None to satisfy mypy
1 parent 1549682 commit 838e15f

File tree

5 files changed

+611
-2
lines changed

5 files changed

+611
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ dependencies = [
3838
"numpy>=2.0.2",
3939
"supervision>=0.26.1",
4040
"scipy>=1.13.1",
41-
"opencv-python>=4.8.0"
41+
"opencv-python>=4.8.0",
42+
"rich>=13.0.0"
4243
]
4344

4445
[project.optional-dependencies]

test/scripts/test_progress.py

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
# ------------------------------------------------------------------------
2+
# Trackers
3+
# Copyright (c) 2026 Roboflow. All Rights Reserved.
4+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5+
# ------------------------------------------------------------------------
6+
7+
from __future__ import annotations
8+
9+
import time
10+
from io import StringIO
11+
from pathlib import Path
12+
from typing import Callable
13+
from unittest.mock import MagicMock, patch
14+
15+
import cv2
16+
import numpy as np
17+
import pytest
18+
from rich.console import Console
19+
20+
from trackers.scripts.progress import (
21+
_classify_source,
22+
_format_time,
23+
_SourceInfo,
24+
_TrackingProgress,
25+
)
26+
27+
FRAME_WIDTH = 64
28+
FRAME_HEIGHT = 64
29+
FRAME_SIZE = (FRAME_WIDTH, FRAME_HEIGHT)
30+
31+
32+
@pytest.fixture
33+
def video_factory(tmp_path: Path) -> Callable[[int], Path]:
34+
"""Create a small test video with *n* frames."""
35+
36+
def _create(n_frames: int) -> Path:
37+
video_path = tmp_path / f"video_{n_frames}.mp4"
38+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
39+
writer = cv2.VideoWriter(str(video_path), fourcc, 25.0, FRAME_SIZE)
40+
for _ in range(n_frames):
41+
writer.write(np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 3), dtype=np.uint8))
42+
writer.release()
43+
return video_path
44+
45+
return _create
46+
47+
48+
@pytest.fixture
49+
def image_directory_factory(tmp_path: Path) -> Callable[[int], Path]:
50+
"""Create a directory with *n* PNG images."""
51+
52+
def _create(n_frames: int) -> Path:
53+
directory = tmp_path / f"imgdir_{n_frames}"
54+
directory.mkdir(exist_ok=True)
55+
frame = np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 3), dtype=np.uint8)
56+
for i in range(n_frames):
57+
cv2.imwrite(str(directory / f"{i:04d}.png"), frame)
58+
return directory
59+
60+
return _create
61+
62+
63+
def _make_console() -> tuple[Console, StringIO]:
64+
"""Return a Console that writes to a StringIO buffer."""
65+
buf = StringIO()
66+
console = Console(file=buf, force_terminal=True, width=200)
67+
return console, buf
68+
69+
70+
class TestClassifySource:
71+
def test_video_file(self, video_factory: Callable[[int], Path]) -> None:
72+
video_path = video_factory(10)
73+
info = _classify_source(str(video_path))
74+
75+
assert info.source_type == "video"
76+
assert info.total_frames is not None
77+
assert info.total_frames > 0
78+
assert info.fps is not None
79+
assert info.fps > 0
80+
81+
def test_image_directory(
82+
self, image_directory_factory: Callable[[int], Path]
83+
) -> None:
84+
directory = image_directory_factory(7)
85+
info = _classify_source(str(directory))
86+
87+
assert info.source_type == "image_dir"
88+
assert info.total_frames == 7
89+
assert info.fps is None
90+
91+
def test_image_directory_path_object(
92+
self, image_directory_factory: Callable[[int], Path]
93+
) -> None:
94+
directory = image_directory_factory(3)
95+
info = _classify_source(directory)
96+
97+
assert info.source_type == "image_dir"
98+
assert info.total_frames == 3
99+
100+
def test_webcam_from_int(self) -> None:
101+
info = _classify_source(0)
102+
103+
assert info.source_type == "webcam"
104+
assert info.total_frames is None
105+
assert info.fps is None
106+
107+
def test_webcam_from_str(self) -> None:
108+
info = _classify_source("0")
109+
110+
assert info.source_type == "webcam"
111+
assert info.total_frames is None
112+
113+
@pytest.mark.parametrize(
114+
"url",
115+
[
116+
"rtsp://192.168.1.10:554/stream",
117+
"http://example.com/stream.mjpg",
118+
"https://example.com/stream.mjpg",
119+
],
120+
)
121+
def test_stream_url(self, url: str) -> None:
122+
info = _classify_source(url)
123+
124+
assert info.source_type == "stream"
125+
assert info.total_frames is None
126+
assert info.fps is None
127+
128+
def test_video_with_zero_frame_count(self) -> None:
129+
mock_cap = MagicMock()
130+
mock_cap.isOpened.return_value = True
131+
mock_cap.get.side_effect = lambda prop: {
132+
cv2.CAP_PROP_FRAME_COUNT: 0.0,
133+
cv2.CAP_PROP_FPS: 30.0,
134+
}.get(prop, 0.0)
135+
136+
with patch("trackers.scripts.progress.cv2.VideoCapture", return_value=mock_cap):
137+
info = _classify_source("some_video.mp4")
138+
139+
assert info.source_type == "video"
140+
assert info.total_frames is None
141+
mock_cap.release.assert_called_once()
142+
143+
def test_nonexistent_file(self) -> None:
144+
info = _classify_source("/nonexistent/video.mp4")
145+
146+
assert info.source_type == "video"
147+
assert info.total_frames is None
148+
149+
def test_empty_image_directory(self, tmp_path: Path) -> None:
150+
empty_dir = tmp_path / "empty"
151+
empty_dir.mkdir()
152+
153+
info = _classify_source(str(empty_dir))
154+
155+
assert info.source_type == "image_dir"
156+
assert info.total_frames is None
157+
158+
159+
class TestFormatTime:
160+
@pytest.mark.parametrize(
161+
"seconds,expected",
162+
[
163+
(0, "0:00"),
164+
(5, "0:05"),
165+
(65, "1:05"),
166+
(3661, "1:01:01"),
167+
(-1, "--"),
168+
],
169+
)
170+
def test_format_time(self, seconds: float, expected: str) -> None:
171+
assert _format_time(seconds) == expected
172+
173+
174+
class TestBuildLine:
175+
def test_bounded_format(self) -> None:
176+
console, _ = _make_console()
177+
source_info = _SourceInfo(source_type="video", total_frames=100)
178+
progress = _TrackingProgress(source_info, console=console)
179+
progress._start_time = time.monotonic() - 5.0
180+
progress._frames_processed = 50
181+
182+
line = progress._build_line("⠹")
183+
text = line.plain
184+
185+
assert "50 / 100" in text
186+
assert "frames" in text
187+
assert "50%" in text
188+
assert "fps" in text
189+
assert "elapsed" in text
190+
assert "eta" in text
191+
192+
def test_unbounded_format(self) -> None:
193+
console, _ = _make_console()
194+
source_info = _SourceInfo(source_type="webcam")
195+
progress = _TrackingProgress(source_info, console=console)
196+
progress._start_time = time.monotonic() - 5.0
197+
progress._frames_processed = 50
198+
199+
line = progress._build_line("⠹")
200+
text = line.plain
201+
202+
assert "50 / --" in text
203+
assert "frames" in text
204+
assert "--" in text
205+
assert "fps" in text
206+
assert "elapsed" in text
207+
assert "eta --" in text
208+
209+
def test_final_no_eta(self) -> None:
210+
console, _ = _make_console()
211+
source_info = _SourceInfo(source_type="video", total_frames=100)
212+
progress = _TrackingProgress(source_info, console=console)
213+
progress._start_time = time.monotonic() - 5.0
214+
progress._frames_processed = 100
215+
216+
line = progress._build_line("✓", show_eta=False)
217+
text = line.plain
218+
219+
assert "eta" not in text
220+
assert "✓" in text
221+
222+
def test_suffix_appended(self) -> None:
223+
console, _ = _make_console()
224+
source_info = _SourceInfo(source_type="video", total_frames=100)
225+
progress = _TrackingProgress(source_info, console=console)
226+
progress._start_time = time.monotonic() - 5.0
227+
progress._frames_processed = 50
228+
229+
line = progress._build_line("✗", show_eta=False, suffix="(interrupted)")
230+
text = line.plain
231+
232+
assert text.endswith("(interrupted)")
233+
234+
def test_zero_elapsed_no_crash(self) -> None:
235+
console, _ = _make_console()
236+
source_info = _SourceInfo(source_type="video", total_frames=100)
237+
progress = _TrackingProgress(source_info, console=console)
238+
progress._start_time = time.monotonic()
239+
progress._frames_processed = 0
240+
241+
# Should not raise ZeroDivisionError
242+
line = progress._build_line("⠹")
243+
text = line.plain
244+
245+
assert "fps" in text
246+
247+
248+
class TestTrackingProgressLifecycle:
249+
def test_bounded_completed(self) -> None:
250+
console, buf = _make_console()
251+
source_info = _SourceInfo(source_type="video", total_frames=5)
252+
253+
with _TrackingProgress(source_info, console=console) as progress:
254+
for _ in range(5):
255+
progress.update()
256+
progress.complete()
257+
258+
output = buf.getvalue()
259+
assert "✓" in output
260+
assert "(interrupted)" not in output
261+
262+
def test_bounded_interrupted_by_display_quit(self) -> None:
263+
console, buf = _make_console()
264+
source_info = _SourceInfo(source_type="video", total_frames=10)
265+
266+
with _TrackingProgress(source_info, console=console) as progress:
267+
for _ in range(5):
268+
progress.update()
269+
progress.complete(interrupted=True)
270+
271+
output = buf.getvalue()
272+
assert "✗" in output
273+
assert "(interrupted)" in output
274+
275+
def test_bounded_keyboard_interrupt(self) -> None:
276+
console, buf = _make_console()
277+
source_info = _SourceInfo(source_type="video", total_frames=10)
278+
279+
progress = _TrackingProgress(source_info, console=console)
280+
progress.__enter__()
281+
for _ in range(3):
282+
progress.update()
283+
284+
# Simulate KeyboardInterrupt in __exit__
285+
progress.__exit__(KeyboardInterrupt, KeyboardInterrupt(), None)
286+
287+
output = buf.getvalue()
288+
assert "✗" in output
289+
assert "(interrupted)" in output
290+
291+
def test_unbounded_completed(self) -> None:
292+
console, buf = _make_console()
293+
source_info = _SourceInfo(source_type="webcam")
294+
295+
with _TrackingProgress(source_info, console=console) as progress:
296+
for _ in range(20):
297+
progress.update()
298+
progress.complete()
299+
300+
output = buf.getvalue()
301+
assert "✓" in output
302+
303+
def test_unbounded_keyboard_interrupt(self) -> None:
304+
console, buf = _make_console()
305+
source_info = _SourceInfo(source_type="stream")
306+
307+
progress = _TrackingProgress(source_info, console=console)
308+
progress.__enter__()
309+
for _ in range(10):
310+
progress.update()
311+
312+
progress.__exit__(KeyboardInterrupt, KeyboardInterrupt(), None)
313+
314+
output = buf.getvalue()
315+
assert "✓" in output
316+
assert "(interrupted)" not in output
317+
318+
def test_error_shows_source_lost(self) -> None:
319+
console, buf = _make_console()
320+
source_info = _SourceInfo(source_type="stream")
321+
322+
progress = _TrackingProgress(source_info, console=console)
323+
progress.__enter__()
324+
for _ in range(5):
325+
progress.update()
326+
327+
err = RuntimeError("connection lost")
328+
progress.__exit__(RuntimeError, err, None)
329+
330+
output = buf.getvalue()
331+
assert "✗" in output
332+
assert "(source lost)" in output
333+
334+
def test_frames_count_in_output(self) -> None:
335+
console, buf = _make_console()
336+
source_info = _SourceInfo(source_type="image_dir", total_frames=30)
337+
338+
with _TrackingProgress(source_info, console=console) as progress:
339+
for _ in range(30):
340+
progress.update()
341+
progress.complete()
342+
343+
output = buf.getvalue()
344+
assert "30 / 30" in output
345+
346+
def test_unbounded_frames_count_in_output(self) -> None:
347+
console, buf = _make_console()
348+
source_info = _SourceInfo(source_type="webcam")
349+
350+
with _TrackingProgress(source_info, console=console) as progress:
351+
for _ in range(42):
352+
progress.update()
353+
progress.complete()
354+
355+
output = buf.getvalue()
356+
assert "42 / --" in output

0 commit comments

Comments
 (0)