Skip to content

Commit 64fe914

Browse files
committed
rewrite color extraction with multiprocessing 💀
1 parent 4788295 commit 64fe914

File tree

1 file changed

+207
-55
lines changed

1 file changed

+207
-55
lines changed

swingmusic/lib/colorlib.py

Lines changed: 207 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22
Contains everything that deals with image color extraction.
33
"""
44

5-
from pathlib import Path
6-
5+
import os
76
import colorgram
7+
from pathlib import Path
8+
from typing import Callable, Generator
9+
from swingmusic.utils.progressbar import tqdm
10+
from concurrent.futures import ProcessPoolExecutor, as_completed
811

912
from swingmusic import settings
10-
11-
from swingmusic.db.userdata import LibDataTable
1213
from swingmusic.logger import log
1314
from swingmusic.store.albums import AlbumStore
15+
from swingmusic.db.userdata import LibDataTable
1416
from swingmusic.store.artists import ArtistStore
15-
from swingmusic.utils.progressbar import tqdm
17+
1618

1719
def get_image_colors(image: str, count=1) -> list[str]:
18-
"""Extracts n number of the most dominant colors from an image."""
20+
"""
21+
Extracts n number of the most dominant colors from an image.
22+
"""
1923
try:
2024
colors = sorted(colorgram.extract(image, count), key=lambda c: c.hsl.h)
2125
except OSError:
@@ -44,76 +48,224 @@ def process_color(item_hash: str, is_album=True):
4448
return get_image_colors(str(path))
4549

4650

47-
class ProcessAlbumColors:
51+
def extract_color_worker(item_data: dict) -> dict:
    """
    Extract the dominant color of a single album or artist image.

    Intended to be submitted to a process pool: it touches no database or
    in-memory store and only returns a plain dict, so the caller can do the
    bookkeeping in batches.

    Args:
        item_data: Job description with the keys
            "hash_field" - name of the key that holds the item hash,
            "path_func"  - zero-argument callable returning the image folder,
            <hash_field> - the item hash itself.

    Returns:
        A dict of the form {<hash_field>: hash, "color": ..., "error": ...}.
        "color" is the first extracted color and "error" is None on success;
        otherwise "color" is None and "error" describes what went wrong.
    """
    hash_key: str = item_data["hash_field"]
    image_dir_getter: Callable = item_data["path_func"]
    hash_value: str = item_data[hash_key]

    def outcome(color, error):
        # Every exit path reports the same three keys.
        return {hash_key: hash_value, "color": color, "error": error}

    # Images are looked up as "<hash>.webp" inside the configured folder.
    image_path = Path(image_dir_getter()) / (hash_value + ".webp")
    if not image_path.exists():
        return outcome(None, "Image not found")

    palette = get_image_colors(str(image_path))
    if palette:
        return outcome(palette[0], None)

    return outcome(None, "Color extraction failed")
6676

67-
album = AlbumStore.albummap.get(albumhash)
6877

69-
if album:
70-
album.set_color(colors[0])
78+
class ColorProcessor:
    """
    Generic color processor for extracting dominant colors from images.

    Color extraction is fanned out to worker processes; database writes and
    in-memory store updates are done in the calling process, in batches.
    Constructing an instance runs the whole pipeline.
    """

    # Number of successful results to accumulate before writing them out.
    BATCH_SIZE = 20

    def __init__(
        self,
        item_type: str,
        store: "type[AlbumStore] | type[ArtistStore]",
        path_func: Callable,
        hash_field: str,
    ):
        """
        Initialize the color processor and process all items missing a color.

        Args:
            item_type: Type of item ("album" or "artist")
            store: Store class (AlbumStore or ArtistStore); must expose
                get_flat_list() and a "<item_type>map" mapping
            path_func: Function returning the folder that holds the images
            hash_field: Name of the hash field ("albumhash" or "artisthash")
        """
        self.item_type = item_type
        self.store = store
        self.path_func = path_func
        self.hash_field = hash_field

        # Hashes that already have a color in the database.
        # NOTE(review): _process_batch inserts "itemhash" as item_type + hash.
        # If get_all_colors returns that prefixed form, the membership test in
        # _get_items_needing_colors never matches - confirm what it returns.
        existing_colors = {
            color_data["itemhash"]
            for color_data in LibDataTable.get_all_colors(item_type)
            if color_data["color"]
        }

        # A generator object is always truthy, so emptiness cannot be tested
        # here; _process_colors_parallel materialises it and returns early
        # when there is nothing to do.
        self._process_colors_parallel(
            self._get_items_needing_colors(existing_colors)
        )

    def _get_items_needing_colors(
        self, existing_colors: set
    ) -> Generator[dict, None, None]:
        """
        Yield one job dict per store item that still needs a color.

        Items are skipped when they already carry a color in the in-memory
        store or when their hash is present in ``existing_colors``.
        """
        for item in self.store.get_flat_list():
            # Skip if item already has color in memory store
            if item.color:
                continue

            # Skip if item already has color in database
            item_hash = getattr(item, self.hash_field)
            if item_hash in existing_colors:
                continue

            yield {
                self.hash_field: item_hash,
                "item_type": self.item_type,
                "path_func": self.path_func,
                "hash_field": self.hash_field,
            }

    def _process_colors_parallel(self, items: Generator[dict, None, None]) -> None:
        """
        Extract colors in a process pool and persist the results in batches.

        Failures of individual items are logged and skipped; they never
        prevent the remaining results from being saved.
        """
        items_list = list(items)

        if not items_list:
            return

        # Use half of the available cores, but always at least one.
        cpus = max(1, (os.cpu_count() or 1) // 2)
        batch: list[dict] = []

        with ProcessPoolExecutor(max_workers=cpus) as executor:
            # Submit all jobs, remembering which job each future belongs to
            future_to_item = {
                executor.submit(extract_color_worker, item): item
                for item in items_list
            }

            progress_bar = tqdm(
                as_completed(future_to_item),
                total=len(items_list),
                desc=f"Processing {self.item_type} colors",
            )

            for future in progress_bar:
                # Keep the try body to the one call that reports worker errors.
                try:
                    result = future.result()
                except Exception as e:
                    item_hash = future_to_item[future][self.hash_field]
                    log.error(f"Error processing {self.item_type} {item_hash}: {e}")
                    continue

                if result["color"] is None:
                    continue

                batch.append(result)

                if len(batch) >= self.BATCH_SIZE:
                    self._save_batch(batch)
                    batch = []

        # Flush whatever is left. Doing this after the loop (instead of
        # counting processed items) guarantees the trailing partial batch is
        # saved even when some futures raised.
        if batch:
            self._save_batch(batch)

    def _save_batch(self, batch: list[dict]) -> None:
        """
        Persist one batch, logging instead of raising if saving fails.
        """
        try:
            self._process_batch(batch)
        except Exception as e:
            log.error(f"Error saving {self.item_type} color batch: {e}")

    def _process_batch(self, batch: list[dict]) -> None:
        """
        Process a batch of color results - update database and memory stores.

        Args:
            batch: Worker results whose "color" is not None.
        """
        if not batch:
            return

        # Split results into new rows and rows that already exist
        db_inserts: list[dict] = []
        db_updates: list[tuple[str, str]] = []

        for result in batch:
            item_hash = result[self.hash_field]
            color = result["color"]

            # Check if record exists in database
            existing_record = LibDataTable.find_one(item_hash, type=self.item_type)

            if existing_record is None:
                db_inserts.append(
                    {
                        "itemhash": self.item_type + item_hash,
                        "color": color,
                        "itemtype": self.item_type,
                    }
                )
            else:
                # Keep the plain hash: update_one is called with the
                # un-prefixed hash, so there is nothing to strip later.
                db_updates.append((item_hash, color))

        # Batch database operations
        if db_inserts:
            LibDataTable.insert_many(db_inserts)

        for item_hash, color in db_updates:
            LibDataTable.update_one(item_hash, {"color": color})

        # Update in-memory store ("albummap" / "artistmap")
        store_map = getattr(self.store, f"{self.item_type}map")

        for result in batch:
            item = store_map.get(result[self.hash_field])
            if item:
                item.set_color(result["color"])
99242

100-
colors = process_color(artisthash, is_album=False)
101243

102-
if colors is None:
103-
continue
244+
class ProcessAlbumColors:
    """
    Finds the dominant color of every album thumbnail that has none yet
    and stores it in the database.

    All of the work is delegated to ColorProcessor, configured for albums.
    """

    def __init__(self) -> None:
        album_config = {
            "item_type": "album",
            "store": AlbumStore,
            "path_func": settings.Paths.get_sm_thumb_path,
            "hash_field": "albumhash",
        }

        # Constructing the processor runs the extraction.
        ColorProcessor(**album_config)
106257

107-
if artist:
108-
artist.set_color(colors[0])
109258

110-
if record is None:
111-
LibDataTable.insert_one(
112-
{
113-
"itemhash": "artist" + artisthash,
114-
"color": colors[0],
115-
"itemtype": "artist",
116-
}
117-
)
118-
else:
119-
LibDataTable.update_one("artist" + artisthash, {"color": colors[0]})
259+
class ProcessArtistColors:
    """
    Finds the dominant color of every artist image that has none yet
    and stores it in the database.

    All of the work is delegated to ColorProcessor, configured for artists.
    """

    def __init__(self) -> None:
        artist_config = {
            "item_type": "artist",
            "store": ArtistStore,
            "path_func": settings.Paths.get_sm_artist_img_path,
            "hash_field": "artisthash",
        }

        # Constructing the processor runs the extraction.
        ColorProcessor(**artist_config)

0 commit comments

Comments
 (0)