Skip to content

Commit 7d892a7

Browse files
committed
fix contrast()
The mean pixel value should be weighted average of the histogram.
1 parent 3afb839 commit 7d892a7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

official/vision/ops/augment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,8 @@ def color(image: tf.Tensor, factor: float) -> tf.Tensor:
844844

845845
def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
846846
"""Equivalent of PIL Contrast."""
847+
image_height = tf.shape(image)[0]
848+
image_width = tf.shape(image)[1]
847849
degenerate = tf.image.rgb_to_grayscale(image)
848850
# Cast before calling tf.histogram.
849851
degenerate = tf.cast(degenerate, tf.int32)
@@ -852,7 +854,8 @@ def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
852854
# and create a constant image size of that value. Use that as the
853855
# blending degenerate target of the original image.
854856
hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
855-
mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
857+
mean = tf.reduce_sum(
858+
tf.cast(hist, tf.float32) * tf.linspace(0., 255., 256)) / float(image_height * image_width)
856859
degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
857860
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
858861
degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))

0 commit comments

Comments
 (0)