Skip to content

Commit 0ed48f0

Browse files
NickHerrigpre-commit-ci[bot]BordaCopilot
authored
Add ImageAssets for download (#932)
* Add ImageAssets for download * Update directory for consistency with video assets directory and add soccer image * Fix unaligned table * Add support for `ImageAssets` and extend `download_assets` functionality to handle both image and video assets. * Refactor `Assets` enum initialization and improve `download_assets` logic with enhanced file checks and re-download handling. Add new tests for invalid asset cases and update existing test assertions. * Simplify `download_assets` type handling and refactor `Assets` enum with additional attributes for filename and md5_hash. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: jirka <6035284+Borda@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent acbaf4f commit 0ed48f0

File tree

6 files changed

+194
-123
lines changed

6 files changed

+194
-123
lines changed

docs/assets.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ comments: true
44

55
# Assets
66

7-
Supervision offers an assets download utility that allows you to download video files
7+
Supervision offers an assets download utility that allows you to download image and video files
88
that you can use in your demos.
99

1010
<div class="md-typeset">
@@ -18,3 +18,9 @@ that you can use in your demos.
1818
</div>
1919

2020
:::supervision.assets.list.VideoAssets
21+
22+
<div class="md-typeset">
23+
<h2><a href="#supervision.assets.list.ImageAssets">ImageAssets</a></h2>
24+
</div>
25+
26+
:::supervision.assets.list.ImageAssets

src/supervision/assets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from supervision.assets.downloader import download_assets
2-
from supervision.assets.list import VideoAssets
2+
from supervision.assets.list import ImageAssets, VideoAssets
33

4-
__all__ = ["VideoAssets", "download_assets"]
4+
__all__ = ["ImageAssets", "VideoAssets", "download_assets"]

src/supervision/assets/downloader.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from requests import get
99
from tqdm.auto import tqdm
1010

11-
from supervision.assets.list import VIDEO_ASSETS, VideoAssets
11+
from supervision.assets.list import MEDIA_ASSETS, Assets
1212

1313

1414
def is_md5_hash_matching(filename: str, original_md5_hash: str) -> bool:
@@ -35,7 +35,7 @@ def is_md5_hash_matching(filename: str, original_md5_hash: str) -> bool:
3535
return computed_md5_hash.hexdigest() == original_md5_hash
3636

3737

38-
def download_assets(asset_name: VideoAssets | str) -> str:
38+
def download_assets(asset_name: Assets | str) -> str:
3939
"""
4040
Download a specified asset if it doesn't already exist or is corrupted.
4141
@@ -47,42 +47,44 @@ def download_assets(asset_name: VideoAssets | str) -> str:
4747
4848
Example:
4949
```pycon
50-
>>> from supervision.assets import download_assets, VideoAssets
50+
>>> from supervision.assets import download_assets, ImageAssets, VideoAssets
5151
>>> download_assets(VideoAssets.VEHICLES) # doctest: +SKIP
5252
'vehicles.mp4'
5353
54+
>>> download_assets(ImageAssets.PEOPLE_WALKING) # doctest: +SKIP
55+
'people-walking.jpg'
56+
5457
```
5558
"""
5659

57-
filename = asset_name.value if isinstance(asset_name, VideoAssets) else asset_name
58-
59-
if not Path(filename).exists() and filename in VIDEO_ASSETS:
60-
print(f"Downloading {filename} assets \n")
61-
response = get(
62-
VIDEO_ASSETS[filename][0], stream=True, allow_redirects=True, timeout=30
63-
)
64-
response.raise_for_status()
65-
66-
file_size = int(response.headers.get("Content-Length", 0))
67-
folder_path = Path(filename).expanduser().resolve()
68-
folder_path.parent.mkdir(parents=True, exist_ok=True)
69-
70-
with tqdm.wrapattr(
71-
response.raw, "read", total=file_size, desc="", colour="#a351fb"
72-
) as raw_resp:
73-
with folder_path.open("wb") as file:
74-
copyfileobj(raw_resp, file)
75-
76-
elif Path(filename).exists():
77-
if not is_md5_hash_matching(filename, VIDEO_ASSETS[filename][1]):
78-
print("File corrupted. Re-downloading... \n")
79-
os.remove(filename)
80-
return download_assets(filename)
81-
82-
print(f"{filename} asset download complete. \n")
83-
60+
filename = asset_name.filename if isinstance(asset_name, Assets) else asset_name
61+
62+
if filename in MEDIA_ASSETS:
63+
if not Path(filename).exists():
64+
print(f"Downloading {filename} assets \n")
65+
response = get(
66+
MEDIA_ASSETS[filename][0], stream=True, allow_redirects=True, timeout=30
67+
)
68+
response.raise_for_status()
69+
70+
file_size = int(response.headers.get("Content-Length", 0))
71+
folder_path = Path(filename).expanduser().resolve()
72+
folder_path.parent.mkdir(parents=True, exist_ok=True)
73+
74+
with tqdm.wrapattr(
75+
response.raw, "read", total=file_size, desc="", colour="#a351fb"
76+
) as raw_resp:
77+
with folder_path.open("wb") as file:
78+
copyfileobj(raw_resp, file)
79+
else:
80+
if not is_md5_hash_matching(filename, MEDIA_ASSETS[filename][1]):
81+
print("File corrupted. Re-downloading... \n")
82+
os.remove(filename)
83+
return download_assets(filename)
84+
85+
print(f"{filename} asset download complete. \n")
8486
else:
85-
valid_assets = ", ".join(asset.value for asset in VideoAssets)
87+
valid_assets = ", ".join(filename for filename in MEDIA_ASSETS.keys())
8688
raise ValueError(
8789
f"Invalid asset. It should be one of the following: {valid_assets}."
8890
)

src/supervision/assets/list.py

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,31 @@
11
from enum import Enum
22

33
BASE_VIDEO_URL = "https://media.roboflow.com/supervision/video-examples/"
4+
BASE_IMAGE_URL = "https://media.roboflow.com/supervision/image-examples/"
45

56

6-
class VideoAssets(Enum):
7+
class Assets(Enum):
8+
filename: str
9+
md5_hash: str
10+
11+
def __new__(cls, filename: str, md5_hash: str) -> "Assets":
12+
obj = object.__new__(cls)
13+
obj._value_ = filename
14+
obj.filename = filename
15+
obj.md5_hash = md5_hash
16+
return obj
17+
18+
@classmethod
19+
def list(cls) -> list[str]:
20+
return [asset.filename for asset in cls]
21+
22+
23+
class VideoAssets(Assets):
724
"""
8-
Each member of this enum represents a video asset. The value associated with each
9-
member is the filename of the video.
25+
Each member of this class represents a video asset. The value associated with each
26+
member has a filename and hash of the video. File names and links can be seen below.
1027
11-
| Enum Member | Video Filename | Video URL |
28+
| Asset | Video Filename | Video URL |
1229
|------------------------|----------------------------|---------------------------------------------------------------------------------------|
1330
| `VEHICLES` | `vehicles.mp4` | [Link](https://media.roboflow.com/supervision/video-examples/vehicles.mp4) |
1431
| `MILK_BOTTLING_PLANT` | `milk-bottling-plant.mp4` | [Link](https://media.roboflow.com/supervision/video-examples/milk-bottling-plant.mp4) |
@@ -22,61 +39,44 @@ class VideoAssets(Enum):
2239
| `SKIING` | `skiing.mp4` | [Link](https://media.roboflow.com/supervision/video-examples/skiing.mp4) |
2340
""" # noqa: E501 // docs
2441

25-
VEHICLES = "vehicles.mp4"
26-
MILK_BOTTLING_PLANT = "milk-bottling-plant.mp4"
27-
VEHICLES_2 = "vehicles-2.mp4"
28-
GROCERY_STORE = "grocery-store.mp4"
29-
SUBWAY = "subway.mp4"
30-
MARKET_SQUARE = "market-square.mp4"
31-
PEOPLE_WALKING = "people-walking.mp4"
32-
BEACH = "beach-1.mp4"
33-
BASKETBALL = "basketball-1.mp4"
34-
SKIING = "skiing.mp4"
35-
36-
@classmethod
37-
def list(cls) -> list[str]:
38-
return list(map(lambda c: c.value, cls))
39-
40-
41-
VIDEO_ASSETS: dict[str, tuple[str, str]] = {
42-
VideoAssets.VEHICLES.value: (
43-
f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES.value}",
44-
"8155ff4e4de08cfa25f39de96483f918",
45-
),
46-
VideoAssets.VEHICLES_2.value: (
47-
f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES_2.value}",
48-
"830af6fba21ffbf14867a7fea595937b",
49-
),
50-
VideoAssets.MILK_BOTTLING_PLANT.value: (
51-
f"{BASE_VIDEO_URL}{VideoAssets.MILK_BOTTLING_PLANT.value}",
42+
VEHICLES = ("vehicles.mp4", "8155ff4e4de08cfa25f39de96483f918")
43+
MILK_BOTTLING_PLANT = (
44+
"milk-bottling-plant.mp4",
5245
"9e8fb6e883f842a38b3d34267290bdc7",
53-
),
54-
VideoAssets.GROCERY_STORE.value: (
55-
f"{BASE_VIDEO_URL}{VideoAssets.GROCERY_STORE.value}",
56-
"11402e7b861c1980527d3d74cbe3b366",
57-
),
58-
VideoAssets.SUBWAY.value: (
59-
f"{BASE_VIDEO_URL}{VideoAssets.SUBWAY.value}",
60-
"453475750691fb23c56a0cffef089194",
61-
),
62-
VideoAssets.MARKET_SQUARE.value: (
63-
f"{BASE_VIDEO_URL}{VideoAssets.MARKET_SQUARE.value}",
64-
"859179bf4a21f80a8baabfdb2ed716dc",
65-
),
66-
VideoAssets.PEOPLE_WALKING.value: (
67-
f"{BASE_VIDEO_URL}{VideoAssets.PEOPLE_WALKING.value}",
68-
"0574c053c8686c3f1dc0aa3743e45cb9",
69-
),
70-
VideoAssets.BEACH.value: (
71-
f"{BASE_VIDEO_URL}{VideoAssets.BEACH.value}",
72-
"4175d42fec4d450ed081523fd39e0cf8",
73-
),
74-
VideoAssets.BASKETBALL.value: (
75-
f"{BASE_VIDEO_URL}{VideoAssets.BASKETBALL.value}",
76-
"60d94a3c7c47d16f09d342b088012ecc",
77-
),
78-
VideoAssets.SKIING.value: (
79-
f"{BASE_VIDEO_URL}{VideoAssets.SKIING.value}",
80-
"d30987cbab1bbc5934199cdd1b293119",
81-
),
46+
)
47+
VEHICLES_2 = ("vehicles-2.mp4", "830af6fba21ffbf14867a7fea595937b")
48+
GROCERY_STORE = ("grocery-store.mp4", "48608fb4a8981f1c2469fa492adeec9c")
49+
SUBWAY = ("subway.mp4", "453475750691fb23c56a0cffef089194")
50+
MARKET_SQUARE = ("market-square.mp4", "859179bf4a21f80a8baabfdb2ed716dc")
51+
PEOPLE_WALKING = ("people-walking.mp4", "0574c053c8686c3f1dc0aa3743e45cb9")
52+
BEACH = ("beach-1.mp4", "4175d42fec4d450ed081523fd39e0cf8")
53+
BASKETBALL = ("basketball-1.mp4", "60d94a3c7c47d16f09d342b088012ecc")
54+
SKIING = ("skiing.mp4", "d30987cbab1bbc5934199cdd1b293119")
55+
56+
57+
class ImageAssets(Assets):
58+
"""
59+
Each member of this enum represents a image asset. The value associated with each
60+
member is the filename of the image.
61+
62+
| Asset | Image Filename | Video URL |
63+
|--------------------|------------------------|---------------------------------------------------------------------------------------|
64+
| `PEOPLE_WALKING` | `people-walking.jpg` | [Link](https://media.roboflow.com/supervision/image-examples/people-walking.jpg) |
65+
| `SOCCER` | `soccer.jpg` | [Link](https://media.roboflow.com/supervision/image-examples/soccer.jpg) |
66+
67+
""" # noqa: E501 // docs
68+
69+
PEOPLE_WALKING = ("people-walking.jpg", "e6bda00b47f2908eeae7df86ef995dcd")
70+
SOCCER = ("soccer.jpg", "0f5a4b98abf3e3973faf9e9260a7d876")
71+
72+
73+
MEDIA_ASSETS: dict[str, tuple[str, str]] = {
74+
**{
75+
asset.filename: (f"{BASE_VIDEO_URL}{asset.filename}", asset.md5_hash)
76+
for asset in VideoAssets
77+
},
78+
**{
79+
asset.filename: (f"{BASE_IMAGE_URL}{asset.filename}", asset.md5_hash)
80+
for asset in ImageAssets
81+
},
8282
}

tests/assets/test_downloader.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from supervision.assets.downloader import download_assets, is_md5_hash_matching
6-
from supervision.assets.list import VideoAssets
6+
from supervision.assets.list import ImageAssets, VideoAssets
77

88

99
class TestMD5HashMatching:
@@ -51,13 +51,13 @@ def test_already_exists_and_valid(self, mock_exists, mock_md5, mock_print):
5151
@patch("supervision.assets.downloader.is_md5_hash_matching", return_value=False)
5252
@patch("pathlib.Path.exists", return_value=True)
5353
def test_already_exists_but_corrupted(
54-
self, mock_exists, mock_md5, mock_remove, mock_recursive
54+
self, mock_exists, mock_md5, mock_remove, mock_download
5555
):
5656
"""Test download_assets when file exists but is corrupted (re-downloads)."""
5757
filename = "vehicles.mp4"
5858
result = download_assets(filename)
5959
assert result == filename
60-
mock_recursive.assert_called_with(filename)
60+
mock_download.assert_called_with(filename)
6161

6262
@patch("builtins.print")
6363
@patch("pathlib.Path.open", new_callable=mock_open)
@@ -108,14 +108,25 @@ def test_invalid_asset(self, mock_exists):
108108
assert "Invalid asset" in str(exc_info.value)
109109
assert "vehicles.mp4" in str(exc_info.value)
110110

111+
@patch("pathlib.Path.exists", return_value=True)
112+
def test_invalid_asset_when_file_exists(self, mock_exists):
113+
"""Test download_assets with invalid asset name that already exists."""
114+
invalid_filename = "invalid.mp4"
115+
116+
with pytest.raises(ValueError, match="Invalid asset") as exc_info:
117+
download_assets(invalid_filename)
118+
119+
assert "Invalid asset" in str(exc_info.value)
120+
assert "vehicles.mp4" in str(exc_info.value)
121+
111122
@patch("builtins.print")
112123
@patch("pathlib.Path.open", new_callable=mock_open)
113124
@patch("pathlib.Path.mkdir")
114125
@patch("supervision.assets.downloader.copyfileobj")
115126
@patch("supervision.assets.downloader.tqdm")
116127
@patch("supervision.assets.downloader.get")
117128
@patch("pathlib.Path.exists", return_value=False)
118-
def test_with_enum(
129+
def test_with_video_enum(
119130
self,
120131
mock_exists,
121132
mock_get,
@@ -138,4 +149,36 @@ def test_with_enum(
138149
mock_tqdm.wrapattr.return_value.__exit__ = MagicMock()
139150

140151
result = download_assets(asset)
141-
assert result == asset.value
152+
assert result == asset.filename
153+
154+
@patch("builtins.print")
155+
@patch("pathlib.Path.open", new_callable=mock_open)
156+
@patch("pathlib.Path.mkdir")
157+
@patch("supervision.assets.downloader.copyfileobj")
158+
@patch("supervision.assets.downloader.tqdm")
159+
@patch("supervision.assets.downloader.get")
160+
@patch("pathlib.Path.exists", return_value=False)
161+
def test_with_image_enum(
162+
self,
163+
mock_exists,
164+
mock_get,
165+
mock_tqdm,
166+
mock_copyfileobj,
167+
mock_mkdir,
168+
mock_open_file,
169+
mock_print,
170+
):
171+
"""Test download_assets with ImageAssets enum."""
172+
asset = ImageAssets.SOCCER
173+
174+
mock_response = MagicMock()
175+
mock_response.headers = {"Content-Length": "100"}
176+
mock_response.raw = MagicMock()
177+
mock_response.raise_for_status = MagicMock()
178+
mock_get.return_value = mock_response
179+
180+
mock_tqdm.wrapattr.return_value.__enter__ = MagicMock()
181+
mock_tqdm.wrapattr.return_value.__exit__ = MagicMock()
182+
183+
result = download_assets(asset)
184+
assert result == asset.filename

0 commit comments

Comments
 (0)