@@ -844,6 +844,8 @@ def color(image: tf.Tensor, factor: float) -> tf.Tensor:
844
844
845
845
def contrast (image : tf .Tensor , factor : float ) -> tf .Tensor :
846
846
"""Equivalent of PIL Contrast."""
847
+ image_height = tf .shape (image )[0 ]
848
+ image_width = tf .shape (image )[1 ]
847
849
degenerate = tf .image .rgb_to_grayscale (image )
848
850
# Cast before calling tf.histogram.
849
851
degenerate = tf .cast (degenerate , tf .int32 )
@@ -852,7 +854,8 @@ def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
852
854
# and create a constant image size of that value. Use that as the
853
855
# blending degenerate target of the original image.
854
856
hist = tf .histogram_fixed_width (degenerate , [0 , 255 ], nbins = 256 )
855
- mean = tf .reduce_sum (tf .cast (hist , tf .float32 )) / 256.0
857
+ mean = tf .reduce_sum (
858
+ tf .cast (hist , tf .float32 ) * tf .linspace (0. , 255. , 256 )) / float (image_height * image_width )
856
859
degenerate = tf .ones_like (degenerate , dtype = tf .float32 ) * mean
857
860
degenerate = tf .clip_by_value (degenerate , 0.0 , 255.0 )
858
861
degenerate = tf .image .grayscale_to_rgb (tf .cast (degenerate , tf .uint8 ))
0 commit comments