|
2 | 2 | Contains everything that deals with image color extraction. |
3 | 3 | """ |
4 | 4 |
|
5 | | -from pathlib import Path |
6 | | - |
| 5 | +import os |
7 | 6 | import colorgram |
| 7 | +from pathlib import Path |
| 8 | +from typing import Callable, Generator |
| 9 | +from swingmusic.utils.progressbar import tqdm |
| 10 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
8 | 11 |
|
9 | 12 | from swingmusic import settings |
10 | | - |
11 | | -from swingmusic.db.userdata import LibDataTable |
12 | 13 | from swingmusic.logger import log |
13 | 14 | from swingmusic.store.albums import AlbumStore |
| 15 | +from swingmusic.db.userdata import LibDataTable |
14 | 16 | from swingmusic.store.artists import ArtistStore |
15 | | -from swingmusic.utils.progressbar import tqdm |
| 17 | + |
16 | 18 |
|
17 | 19 | def get_image_colors(image: str, count=1) -> list[str]: |
18 | | - """Extracts n number of the most dominant colors from an image.""" |
| 20 | + """ |
| 21 | + Extracts n number of the most dominant colors from an image. |
| 22 | + """ |
19 | 23 | try: |
20 | 24 | colors = sorted(colorgram.extract(image, count), key=lambda c: c.hsl.h) |
21 | 25 | except OSError: |
@@ -44,76 +48,224 @@ def process_color(item_hash: str, is_album=True): |
44 | 48 | return get_image_colors(str(path)) |
45 | 49 |
|
46 | 50 |
|
47 | | -class ProcessAlbumColors: |
| 51 | +def extract_color_worker(item_data: dict) -> dict: |
48 | 52 | """ |
49 | | - Extracts the most dominant color from the album art and saves it to the database. |
| 53 | + Generic worker function for extracting colors in parallel. |
| 54 | + Returns data to main process for batch database operations. |
| 55 | + Works for both albums and artists based on item_data configuration. |
50 | 56 | """ |
| 57 | + hash_field: str = item_data["hash_field"] |
| 58 | + path_func: Callable = item_data["path_func"] |
| 59 | + item_hash: str = item_data[hash_field] |
51 | 60 |
|
52 | | - def __init__(self) -> None: |
53 | | - albums = [a for a in AlbumStore.get_flat_list() if not a.color] |
| 61 | + path = Path(path_func()) / (item_hash + ".webp") |
54 | 62 |
|
55 | | - for album in tqdm(albums, desc="Processing missing album colors"): |
56 | | - albumhash = album.albumhash |
| 63 | + if not path.exists(): |
| 64 | + return {hash_field: item_hash, "color": None, "error": "Image not found"} |
57 | 65 |
|
58 | | - albumrecord = LibDataTable.find_one(albumhash, type="album") |
59 | | - if albumrecord is not None and albumrecord.color is not None: |
60 | | - continue |
| 66 | + colors = get_image_colors(str(path)) |
61 | 67 |
|
62 | | - colors = process_color(albumhash) |
| 68 | + if not colors: |
| 69 | + return { |
| 70 | + hash_field: item_hash, |
| 71 | + "color": None, |
| 72 | + "error": "Color extraction failed", |
| 73 | + } |
63 | 74 |
|
64 | | - if colors is None: |
65 | | - continue |
| 75 | + return {hash_field: item_hash, "color": colors[0], "error": None} |
66 | 76 |
|
67 | | - album = AlbumStore.albummap.get(albumhash) |
68 | 77 |
|
69 | | - if album: |
70 | | - album.set_color(colors[0]) |
| 78 | +class ColorProcessor: |
| 79 | + """ |
| 80 | + Generic color processor for extracting dominant colors from images. |
| 81 | + Uses multiprocessing for parallel color extraction and batch database operations. |
| 82 | + """ |
| 83 | + |
def __init__(
    self,
    item_type: str,
    store: AlbumStore | ArtistStore,
    path_func: Callable,
    hash_field: str,
):
    """
    Store the configuration and immediately run color processing.

    Constructing a ColorProcessor is what triggers the work: the
    constructor collects every item that still lacks a color and hands
    them to ``_process_colors_parallel``.

    Args:
        item_type: Type of item ("album" or "artist")
        store: Store object (AlbumStore or ArtistStore)
        path_func: Function to get the image path
        hash_field: Name of the hash field ("albumhash" or "artisthash")
    """
    self.item_type = item_type
    self.store = store
    self.path_func = path_func
    self.hash_field = hash_field

    # Hashes that already have a non-empty color stored in the database.
    # NOTE(review): rows are inserted elsewhere with
    # itemhash = item_type + hash, while the lookup in
    # _get_items_needing_colors uses the bare hash. Confirm that
    # get_all_colors returns bare hashes, otherwise this filter never
    # matches.
    existing_colors = {
        color_data["itemhash"]
        for color_data in LibDataTable.get_all_colors(item_type)
        if color_data["color"]
    }

    # The generator below is lazy and always truthy, so an emptiness test
    # here would be meaningless; _process_colors_parallel materialises it
    # and returns early when there is nothing to do.
    self._process_colors_parallel(self._get_items_needing_colors(existing_colors))
| 118 | + |
| 119 | + def _get_items_needing_colors( |
| 120 | + self, existing_colors: set |
| 121 | + ) -> Generator[dict, None, None]: |
| 122 | + """ |
| 123 | + Generator that yields items needing color processing. |
| 124 | + """ |
| 125 | + for item in self.store.get_flat_list(): |
| 126 | + # Skip if item already has color in memory store |
| 127 | + if item.color: |
| 128 | + continue |
| 129 | + |
| 130 | + # Skip if item already has color in database |
| 131 | + item_hash = getattr(item, self.hash_field) |
| 132 | + if item_hash in existing_colors: |
| 133 | + continue |
71 | 134 |
|
72 | | - # INFO: Write to the database. |
73 | | - if albumrecord is None: |
74 | | - LibDataTable.insert_one( |
| 135 | + yield { |
| 136 | + self.hash_field: item_hash, |
| 137 | + "item_type": self.item_type, |
| 138 | + "path_func": self.path_func, |
| 139 | + "hash_field": self.hash_field, |
| 140 | + } |
| 141 | + |
def _process_colors_parallel(self, items: Generator[dict, None, None]) -> None:
    """
    Extract colors in worker processes and persist the results in batches.

    Each item is submitted to ``extract_color_worker`` on a process pool.
    Results are consumed as they complete; those carrying a color are
    accumulated and passed to ``_process_batch`` in groups of
    ``batch_size``. Whatever is left after the last result is flushed as
    well, regardless of how many workers failed.

    Args:
        items: Item descriptors as yielded by ``_get_items_needing_colors``.
    """
    items_list = list(items)

    if not items_list:
        return

    # Use half of the reported cores, but never fewer than one worker.
    cpus = max(1, (os.cpu_count() or 1) // 2)
    batch_size = 20  # Successful results handed to _process_batch at once

    batch: list[dict] = []

    def flush() -> None:
        """Persist the pending batch, then empty it even if saving failed."""
        if not batch:
            return

        try:
            self._process_batch(batch)
        except Exception as e:
            # Best effort: a failed write must not abort the remaining work,
            # and must not be blamed on whichever item completed last.
            log.error(f"Error saving {self.item_type} color batch: {e}")
        finally:
            # Always drop the batch so a failing write is not retried on
            # every subsequent result.
            batch.clear()

    with ProcessPoolExecutor(max_workers=cpus) as executor:
        # Submit all jobs, remembering which item each future belongs to
        future_to_item = {
            executor.submit(extract_color_worker, item): item for item in items_list
        }

        # Process results as they complete
        progress_bar = tqdm(
            as_completed(future_to_item),
            total=len(items_list),
            desc=f"Processing {self.item_type} colors",
        )

        for future in progress_bar:
            try:
                result = future.result()
            except Exception as e:
                item_hash = future_to_item[future][self.hash_field]
                log.error(f"Error processing {self.item_type} {item_hash}: {e}")
                continue

            # The worker reports a missing image or failed extraction as
            # color=None; there is nothing to store for those.
            if result["color"] is None:
                continue

            batch.append(result)

            if len(batch) >= batch_size:
                flush()

        # Write the trailing partial batch. This does not depend on any
        # success counter, so worker failures cannot cause it to be skipped.
        flush()
| 191 | + |
| 192 | + def _process_batch(self, batch: list[dict]) -> None: |
| 193 | + """ |
| 194 | + Process a batch of color results - update database and memory stores. |
| 195 | + """ |
| 196 | + if not batch: |
| 197 | + return |
| 198 | + |
| 199 | + # Prepare database records |
| 200 | + db_inserts = [] |
| 201 | + db_updates = [] |
| 202 | + |
| 203 | + for result in batch: |
| 204 | + item_hash = result[self.hash_field] |
| 205 | + color = result["color"] |
| 206 | + |
| 207 | + # Check if record exists in database |
| 208 | + existing_record = LibDataTable.find_one(item_hash, type=self.item_type) |
| 209 | + |
| 210 | + if existing_record is None: |
| 211 | + db_inserts.append( |
75 | 212 | { |
76 | | - "itemhash": "album" + albumhash, |
77 | | - "color": colors[0], |
78 | | - "itemtype": "album", |
| 213 | + "itemhash": self.item_type + item_hash, |
| 214 | + "color": color, |
| 215 | + "itemtype": self.item_type, |
79 | 216 | } |
80 | 217 | ) |
81 | 218 | else: |
82 | | - LibDataTable.update_one(albumhash, {"color": colors[0]}) |
| 219 | + db_updates.append( |
| 220 | + {"itemhash": self.item_type + item_hash, "color": color} |
| 221 | + ) |
83 | 222 |
|
| 223 | + # Batch database operations |
| 224 | + if db_inserts: |
| 225 | + LibDataTable.insert_many(db_inserts) |
84 | 226 |
|
85 | | -class ProcessArtistColors: |
86 | | - """ |
87 | | - Extracts the most dominant color from the artist art and saves it to the database. |
88 | | - """ |
| 227 | + if db_updates: |
| 228 | + for update_data in db_updates: |
| 229 | + clean_hash = update_data["itemhash"].replace(self.item_type, "") |
| 230 | + LibDataTable.update_one(clean_hash, {"color": update_data["color"]}) |
89 | 231 |
|
90 | | - def __init__(self) -> None: |
91 | | - all_artists = [a for a in ArtistStore.get_flat_list() if not a.color] |
| 232 | + # Update in-memory store |
| 233 | + store_map = getattr(self.store, f"{self.item_type}map") |
92 | 234 |
|
93 | | - for artist in tqdm(all_artists, desc="Processing missing artist colors"): |
94 | | - artisthash = artist.artisthash |
| 235 | + for result in batch: |
| 236 | + item_hash = result[self.hash_field] |
| 237 | + color = result["color"] |
95 | 238 |
|
96 | | - record = LibDataTable.find_one(artisthash, "artist") |
97 | | - if (record is not None) and (record.color is not None): |
98 | | - continue |
| 239 | + item = store_map.get(item_hash) |
| 240 | + if item: |
| 241 | + item.set_color(color) |
99 | 242 |
|
100 | | - colors = process_color(artisthash, is_album=False) |
101 | 243 |
|
102 | | - if colors is None: |
103 | | - continue |
| 244 | +class ProcessAlbumColors: |
| 245 | + """ |
| 246 | + Extracts the most dominant color from the album art and saves it to the database. |
| 247 | + Uses multiprocessing for parallel color extraction and batch database operations. |
| 248 | + """ |
104 | 249 |
|
105 | | - artist = ArtistStore.artistmap.get(artisthash) |
| 250 | + def __init__(self) -> None: |
| 251 | + ColorProcessor( |
| 252 | + item_type="album", |
| 253 | + store=AlbumStore, |
| 254 | + path_func=settings.Paths.get_sm_thumb_path, |
| 255 | + hash_field="albumhash", |
| 256 | + ) |
106 | 257 |
|
107 | | - if artist: |
108 | | - artist.set_color(colors[0]) |
109 | 258 |
|
110 | | - if record is None: |
111 | | - LibDataTable.insert_one( |
112 | | - { |
113 | | - "itemhash": "artist" + artisthash, |
114 | | - "color": colors[0], |
115 | | - "itemtype": "artist", |
116 | | - } |
117 | | - ) |
118 | | - else: |
119 | | - LibDataTable.update_one("artist" + artisthash, {"color": colors[0]}) |
class ProcessArtistColors:
    """
    Finds the dominant color of each artist image and stores it.

    Instantiating this class runs the whole job: it delegates to
    ColorProcessor configured for artists.
    """

    def __init__(self) -> None:
        # ColorProcessor does all of its work inside its constructor, so the
        # instance itself does not need to be kept.
        artist_image_dir = settings.Paths.get_sm_artist_img_path

        ColorProcessor(
            item_type="artist",
            store=ArtistStore,
            path_func=artist_image_dir,
            hash_field="artisthash",
        )
0 commit comments