Skip to content

Commit 3339258

Browse files
No public description
PiperOrigin-RevId: 569605914
1 parent 18b1f2e commit 3339258

File tree

1 file changed

+190
-0
lines changed

1 file changed

+190
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Extract properties from each object mask and detect its color."""
16+
from typing import Optional, Union
17+
import numpy as np
18+
import pandas as pd
19+
import skimage.measure
20+
from sklearn.cluster import KMeans
21+
import webcolors
22+
23+
PROPERTIES = [
24+
'area',
25+
'bbox',
26+
'convex_area',
27+
'bbox_area',
28+
'major_axis_length',
29+
'minor_axis_length',
30+
'eccentricity',
31+
'centroid',
32+
]
33+
34+
35+
def extract_properties_and_object_masks(
36+
final_result: dict[str, np.ndarray],
37+
height: int,
38+
width: int,
39+
original_image: np.ndarray,
40+
) -> tuple[list[pd.DataFrame], list[np.ndarray]]:
41+
"""Extract specific properties from given detection masks.
42+
43+
Properties that will be computed includes the area of the masks, bbox
44+
coordinates, area of that bbox, convex length, major_axis_length,
45+
minor_axis_length, eccentricity and centroid.
46+
47+
Args:
48+
final_result: A dictionary containing the num_detections, detection_classes,
49+
detection_scores,detection_boxes,detection_classes_names,
50+
detection_masks_reframed'
51+
height: The height of the original image.
52+
width: The width of the original image.
53+
original_image: The actual image on which the objects were detected.
54+
55+
Returns:
56+
A tuple containing two lists:
57+
1. List of dataframes where each dataframe contains properties for a
58+
detected object.
59+
2. List of ndarrays where each ndarray is a cropped portion of the
60+
original image
61+
corresponding to a detected object.
62+
"""
63+
list_of_df = []
64+
cropped_masks = []
65+
66+
for i, mask in enumerate(final_result['detection_masks_reframed']):
67+
mask = np.where(mask, 1, 0)
68+
df = pd.DataFrame(
69+
skimage.measure.regionprops_table(mask, properties=PROPERTIES)
70+
)
71+
list_of_df.append(df)
72+
73+
bb = final_result['detection_boxes'][0][i]
74+
ymin, xmin, ymax, xmax = (
75+
int(bb[0] * height),
76+
int(bb[1] * width),
77+
int(bb[2] * height),
78+
int(bb[3] * width),
79+
)
80+
mask = np.expand_dims(mask, axis=2)
81+
cropped_object = np.where(
82+
mask[ymin:ymax, xmin:xmax], original_image[ymin:ymax, xmin:xmax], 0
83+
)
84+
cropped_masks.append(cropped_object)
85+
86+
return list_of_df, cropped_masks
87+
88+
89+
def find_dominant_color(
90+
image: np.ndarray, black_threshold: int = 50
91+
) -> tuple[Union[int, str], Union[int, str], Union[int, str]]:
92+
"""Determines the dominant color in a given image.
93+
94+
The function performs the following steps:
95+
Filters out black or near-black pixels based on a threshold.
96+
Uses k-means clustering to identify the dominant color among the remaining
97+
pixels.
98+
99+
Args:
100+
image: An array representation of the image.
101+
black_threshold: pixel value of black color
102+
103+
shape is (height, width, 3) for RGB channels.
104+
black_threshold: The intensity threshold below which pixels
105+
are considered 'black' or near-black. Default is 50.
106+
107+
Returns:
108+
The dominant RGB color in the format (R, G, B). If no non-black
109+
pixels are found, returns ('Na', 'Na', 'Na').
110+
"""
111+
pixels = image.reshape(-1, 3)
112+
113+
# Filter out black pixels based on the threshold
114+
non_black_pixels = pixels[(pixels > black_threshold).any(axis=1)]
115+
116+
if non_black_pixels.size != 0:
117+
kmeans = KMeans(n_clusters=1, n_init=10, random_state=0).fit(
118+
non_black_pixels
119+
)
120+
dominant_color = kmeans.cluster_centers_[0].astype(int)
121+
122+
else:
123+
dominant_color = ['Na', 'Na', 'Na']
124+
return tuple(dominant_color)
125+
126+
127+
def color_difference(color1: int, color2: int) -> Union[float, int]:
128+
"""Computes the squared difference between two color components.
129+
130+
Args:
131+
color1: First color component.
132+
color2: Second color component.
133+
134+
Returns:
135+
The squared difference between the two color components.
136+
"""
137+
return (color1 - color2) ** 2
138+
139+
140+
def est_color(requested_color: tuple[int, int, int]) -> str:
141+
"""Estimates the closest named color for a given RGB color.
142+
143+
The function uses the Euclidean distance in the RGB space to find the closest
144+
match among the CSS3 colors.
145+
146+
Args:
147+
requested_color: The RGB color value for which to find the closest named
148+
color. Expected format is (R, G, B).
149+
150+
Returns:
151+
The name of the closest matching color from the CSS3 predefined colors.
152+
153+
Example: est_color((255, 0, 0))
154+
'red'
155+
"""
156+
min_colors = {}
157+
for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
158+
r_c, g_c, b_c = webcolors.hex_to_rgb(key)
159+
rd = color_difference(r_c, requested_color[0])
160+
gd = color_difference(g_c, requested_color[1])
161+
bd = color_difference(b_c, requested_color[2])
162+
min_colors[(rd + gd + bd)] = name
163+
return min_colors[min(min_colors.keys())]
164+
165+
166+
def get_color_name(rgb_color: tuple[int, int, int]) -> Optional[str]:
167+
"""Retrieves the name of a given RGB color.
168+
169+
If the RGB color exactly matches one of the CSS3 predefined colors, it returns
170+
the exact color name.
171+
Otherwise, it estimates the closest matching color name.
172+
173+
Args:
174+
rgb_color: The RGB color value for which to retrieve the name.
175+
176+
Returns:
177+
The name of the color if found, or None if the color is marked as 'Na' or
178+
not found.
179+
180+
Example: get_color_name((255, 0, 0))
181+
'red'
182+
"""
183+
if 'Na' not in rgb_color:
184+
try:
185+
closest_color_name = webcolors.rgb_to_name(rgb_color)
186+
except ValueError:
187+
closest_color_name = est_color(rgb_color)
188+
return closest_color_name
189+
else:
190+
return None

0 commit comments

Comments
 (0)