Skip to content

Commit 66f9a3d

Browse files
fmassassnl
andauthored
Fix #2354 (#2355) (#2363)
Co-authored-by: Tongzhou Wang <[email protected]>
1 parent 6631b74 commit 66f9a3d

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,22 +288,35 @@ def _blend(img1, img2, ratio):
288288
def _rgb2hsv(img):
289289
r, g, b = img.unbind(0)
290290

291-
maxc, _ = torch.max(img, dim=0)
292-
minc, _ = torch.min(img, dim=0)
291+
maxc = torch.max(img, dim=0).values
292+
minc = torch.min(img, dim=0).values
293+
294+
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
295+
# from happening in the results, because
296+
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
297+
# + H channel has division by `(maxc - minc)`.
298+
#
299+
# Instead of overwriting NaN afterwards, we just prevent it from occuring so
300+
# we don't need to deal with it in case we save the NaN in a buffer in
301+
# backprop, if it is ever supported, but it doesn't hurt to do so.
302+
eqc = maxc == minc
293303

294304
cr = maxc - minc
295-
s = cr / maxc
296-
rc = (maxc - r) / cr
297-
gc = (maxc - g) / cr
298-
bc = (maxc - b) / cr
305+
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
306+
s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
307+
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
308+
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
309+
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
310+
# replacing denominator with 1 when `eqc` is fine.
311+
cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
312+
rc = (maxc - r) / cr_divisor
313+
gc = (maxc - g) / cr_divisor
314+
bc = (maxc - b) / cr_divisor
299315

300-
t = (maxc != minc)
301-
s = t * s
302316
hr = (maxc == r) * (bc - gc)
303317
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
304318
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
305319
h = (hr + hg + hb)
306-
h = t * h
307320
h = torch.fmod((h / 6.0 + 1.0), 1.0)
308321
return torch.stack((h, s, maxc))
309322

0 commit comments

Comments
 (0)