@@ -288,22 +288,35 @@ def _blend(img1, img2, ratio):
288
288
def _rgb2hsv (img ):
289
289
r , g , b = img .unbind (0 )
290
290
291
- maxc , _ = torch .max (img , dim = 0 )
292
- minc , _ = torch .min (img , dim = 0 )
291
+ maxc = torch .max (img , dim = 0 ).values
292
+ minc = torch .min (img , dim = 0 ).values
293
+
294
+ # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
295
+ # from happening in the results, because
296
+ # + S channel has division by `maxc`, which is zero only if `maxc = minc`
297
+ # + H channel has division by `(maxc - minc)`.
298
+ #
299
+ # Instead of overwriting NaN afterwards, we just prevent it from occuring so
300
+ # we don't need to deal with it in case we save the NaN in a buffer in
301
+ # backprop, if it is ever supported, but it doesn't hurt to do so.
302
+ eqc = maxc == minc
293
303
294
304
cr = maxc - minc
295
- s = cr / maxc
296
- rc = (maxc - r ) / cr
297
- gc = (maxc - g ) / cr
298
- bc = (maxc - b ) / cr
305
+ # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
306
+ s = cr / torch .where (eqc , maxc .new_ones (()), maxc )
307
+ # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
308
+ # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
309
+ # would not matter what values `rc`, `gc`, and `bc` have here, and thus
310
+ # replacing denominator with 1 when `eqc` is fine.
311
+ cr_divisor = torch .where (eqc , maxc .new_ones (()), cr )
312
+ rc = (maxc - r ) / cr_divisor
313
+ gc = (maxc - g ) / cr_divisor
314
+ bc = (maxc - b ) / cr_divisor
299
315
300
- t = (maxc != minc )
301
- s = t * s
302
316
hr = (maxc == r ) * (bc - gc )
303
317
hg = ((maxc == g ) & (maxc != r )) * (2.0 + rc - bc )
304
318
hb = ((maxc != g ) & (maxc != r )) * (4.0 + gc - rc )
305
319
h = (hr + hg + hb )
306
- h = t * h
307
320
h = torch .fmod ((h / 6.0 + 1.0 ), 1.0 )
308
321
return torch .stack ((h , s , maxc ))
309
322
0 commit comments