Skip to content

Commit 5c0617a

Browse files
No public description
PiperOrigin-RevId: 597369513
1 parent 0427687 commit 5c0617a

File tree

3 files changed

+240
-10
lines changed

3 files changed

+240
-10
lines changed

official/projects/waste_identification_ml/model_inference/Inference.ipynb

Lines changed: 62 additions & 4 deletions
Large diffs are not rendered by default.

official/projects/waste_identification_ml/model_inference/color_and_property_extractor.py

Lines changed: 168 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,26 @@
1313
# limitations under the License.
1414

1515
"""Extract properties from each object mask and detect its color."""
16-
from typing import Optional, Union
16+
17+
from typing import Annotated, Literal, TypeVar, Union
18+
1719
import numpy as np
20+
import numpy.typing as npt
1821
import pandas as pd
22+
from skimage import color as skimage_color
1923
import skimage.measure
20-
from sklearn.cluster import KMeans
24+
from sklearn import cluster as sklearn_cluster
25+
from sklearn import neighbors as sklearn_neighbors
2126
import webcolors
2227

28+
DType = TypeVar('DType', bound=np.generic)
29+
# Color representation as numpy array of 3 elements of float64
30+
# Those values could be in different scales like
31+
# RGB ([0.0,255.0], [0.0,255.0], [0.0 to 255.0])
32+
# LAB ([0.0,100], [-128,127], [-128,127])
33+
NColor = Annotated[npt.NDArray[DType], Literal[3]][np.float64]
34+
35+
2336
PROPERTIES = [
2437
'area',
2538
'bbox',
@@ -31,6 +44,68 @@
3144
'centroid',
3245
]
3346

47+
GENERIC_COLORS = [
48+
('black', '#000000'),
49+
('green', '#008000'),
50+
('green', '#00ff00'), # lime
51+
('green', '#3cb371'), # mediumseagreen
52+
('green', '#2E8B57'), # seagreen
53+
('green', '#8FBC8B'), # darkseagreen
54+
('green', '#adff2f'), # olive
55+
('green', '#008080'), # Teal
56+
('green', '#808000'),
57+
('blue', '#000080'), # navy
58+
('blue', '#00008b'), # darkblue
59+
('blue', '#4682b4'), # steelblue
60+
('blue', '#40E0D0'), # turquoise
61+
('blue', '#00FFFF'), # cyan
62+
('blue', '#00ffff'), # aqua
63+
('blue', '#6495ED'), # cornflowerBlue
64+
('blue', '#4169E1'), # royalBlue
65+
('blue', '#87CEFA'), # lightSkyBlue
66+
('blue', '#4682B4'), # steelBlue
67+
('blue', '#B0C4DE'), # lightSteelBlue
68+
('blue', '#87CEEB'), # skyblue
69+
('blue', '#0000CD'), # mediumBlue
70+
('blue', '#0000ff'),
71+
('purple', '#800080'),
72+
('purple', '#9370db'), # mediumpurple
73+
('purple', '#8B008B'), # darkMagenta
74+
('purple', '#4B0082'), # indigo
75+
('red', '#ff0000'),
76+
('red', '#B22222'), # fireBrick
77+
('red', '#DC143C'), # fireBrick
78+
('red', '#8B0000'), # crimson
79+
('red', '#CD5C5C'), # indianred
80+
('red', '#F08080'), # lightCoral
81+
('red', '#FA8072'), # salmon
82+
('red', '#E9967A'), # darkSalmon
83+
('red', '#FFA07A'), # lightSalmon
84+
('gray', '#c0c0c0'), # silver,
85+
('white', '#ffffff'),
86+
('white', '#F5F5DC'), # beige
87+
('white', '#FFFAFA'), # snow
88+
('white', '#F0F8FF'), # aliceBlue
89+
('white', '#FFE4E1'), # mistyRose
90+
('yellow', '#ffff00'),
91+
('yellow', '#ffffe0'), # lightyellow
92+
('yellow', '#8B8000'), # darkyellow,
93+
('orange', '#ffa500'),
94+
('orange', '#ff8c00'), # darkorange
95+
('pink', '#ffc0cb'),
96+
('pink', '#ff00ff'), # fuchsia
97+
('pink', '#C71585'), # mediumVioletRed
98+
('pink', '#DB7093'), # paleVioletRed
99+
('pink', '#FFB6C1'), # lightPink
100+
('pink', '#FF69B4'), # hotPink
101+
('pink', '#FF1493'), # deepPink
102+
('pink', '#BC8F8F'), # rosybrown
103+
('brown', '#a52a2a'),
104+
('brown', '#8b4513'), # saddlebrown
105+
('brown', '#f4a460'), # sandybrown
106+
('brown', '#800000'), # maroon
107+
]
108+
34109

35110
def extract_properties_and_object_masks(
36111
final_result: dict[str, np.ndarray],
@@ -114,9 +189,9 @@ def find_dominant_color(
114189
non_black_pixels = pixels[(pixels > black_threshold).any(axis=1)]
115190

116191
if non_black_pixels.size != 0:
117-
kmeans = KMeans(n_clusters=1, n_init=10, random_state=0).fit(
118-
non_black_pixels
119-
)
192+
kmeans = sklearn_cluster.KMeans(
193+
n_clusters=1, n_init=10, random_state=0
194+
).fit(non_black_pixels)
120195
dominant_color = kmeans.cluster_centers_[0].astype(int)
121196

122197
else:
@@ -163,7 +238,7 @@ def est_color(requested_color: tuple[int, int, int]) -> str:
163238
return min_colors[min(min_colors.keys())]
164239

165240

166-
def get_color_name(rgb_color: tuple[int, int, int]) -> Optional[str]:
241+
def get_color_name(rgb_color: tuple[int, int, int]) -> str | None:
167242
"""Retrieves the name of a given RGB color.
168243
169244
If the RGB color exactly matches one of the CSS3 predefined colors, it returns
@@ -188,3 +263,90 @@ def get_color_name(rgb_color: tuple[int, int, int]) -> Optional[str]:
188263
return closest_color_name
189264
else:
190265
return None
266+
267+
268+
def rgb_int_to_lab(rgb_int_color: tuple[int, int, int]) -> NColor:
269+
"""Convert RGB color to LAB color space.
270+
271+
Args:
272+
rgb_int_color: RGB tuple color e.g. (128,128,128)
273+
274+
Returns:
275+
Numpy array of 3 elements that contains LAB color space.
276+
"""
277+
return skimage_color.rgb2lab(
278+
(rgb_int_color[0] / 255, rgb_int_color[1] / 255, rgb_int_color[2] / 255)
279+
)
280+
281+
282+
def color_distance(
283+
a: tuple[int, int, int], b: tuple[int, int, int]
284+
) -> np.ndarray:
285+
"""The color distance following the ciede2000 formula.
286+
287+
See: https://en.wikipedia.org/wiki/Color_difference#CIEDE2000
288+
289+
Args:
290+
a: Color a
291+
b: Color b
292+
293+
Returns:
294+
The distance between color a and b
295+
"""
296+
return skimage_color.deltaE_ciede2000(a, b, kC=0.6)
297+
298+
299+
def build_color_lab_list(
300+
generic_colors: list[tuple[str, str]]
301+
) -> tuple[npt.NDArray[np.str_], list[NColor]]:
302+
"""Get Simple colors names and lab values.
303+
304+
Args:
305+
generic_colors: List of colors in this format (color_name, rgb_value in hex)
306+
e.g. [ ('black', '#000000'), ('green', '#008000'), ]
307+
308+
Returns:
309+
Numpy array of strings that contains color names
310+
['black', 'green']
311+
List of color lab values in the format of Numpy array of 3 elements
312+
e.g.
313+
[
314+
np.array([0., 0., 0.]),
315+
np.array([ 46.2276577 , -51.69868348, 49.89707556])
316+
]
317+
"""
318+
names: list[str] = []
319+
lab_values = []
320+
for color_name, color_hex in generic_colors:
321+
names.append(color_name)
322+
hex_color = webcolors.hex_to_rgb(color_hex)
323+
lab_values.append(rgb_int_to_lab(hex_color))
324+
color_names = np.array(names)
325+
return color_names, lab_values
326+
327+
328+
def get_generic_color_name(
329+
rgb_colors: list[tuple[int, int, int]],
330+
generic_colors: list[tuple[str, str]] | None = None,
331+
) -> list[str]:
332+
"""Retrieves generic names of given RGB colors.
333+
334+
Estimates the closest matching color name.
335+
336+
Args:
337+
rgb_colors: A list of RGB values for which to retrieve the name.
338+
generic_colors: A list of color names and their RGB values in hex.
339+
340+
Returns:
341+
The list of closest color names.
342+
343+
Example: get_generic_color_name([(255, 0, 0), (0,0,0)])
344+
['red','black']
345+
"""
346+
names, rgb_simple_colors = build_color_lab_list(
347+
generic_colors or GENERIC_COLORS
348+
)
349+
tree = sklearn_neighbors.BallTree(rgb_simple_colors, metric=color_distance)
350+
rgb_query = [*map(rgb_int_to_lab, rgb_colors)]
351+
_, index = tree.query(rgb_query)
352+
return [x[0] for x in names[index]]

official/projects/waste_identification_ml/model_inference/color_and_property_extractor_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ def test_est_color(self):
112112

113113
self.assertEqual(result, 'red')
114114

115+
def test_generic_color(self):
116+
test_colors = np.array(
117+
[(255, 0, 0), (55, 118, 171), (73, 128, 41), (231, 112, 13)]
118+
)
119+
expected_colors = ['red', 'blue', 'green', 'orange']
120+
121+
result = color_and_property_extractor.get_generic_color_name(test_colors)
122+
123+
self.assertEqual(result, expected_colors)
124+
115125

116126
if __name__ == '__main__':
117127
unittest.main()

0 commit comments

Comments
 (0)