Skip to content

Commit bc5c5a3

Browse files
No public description
PiperOrigin-RevId: 570225377
1 parent 6a99212 commit bc5c5a3

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""This module provides utilities to normalize image tensors.
16+
"""
17+
from typing import Sequence
18+
import tensorflow as tf
19+
20+
MEAN_NORM = (0.485, 0.456, 0.406)
21+
STDDEV_NORM = (0.229, 0.224, 0.225)
22+
23+
24+
def normalize_image(
25+
image: tf.Tensor,
26+
offset: Sequence[float] = MEAN_NORM,
27+
scale: Sequence[float] = STDDEV_NORM,
28+
) -> tf.Tensor:
29+
"""Normalizes the image to zero mean and unit variance.
30+
31+
If the input image dtype is float, it is expected to either have values in
32+
[0, 1) and offset is MEAN_NORM, or have values in [0, 255] and offset is
33+
MEAN_RGB.
34+
35+
Args:
36+
image: A tf.Tensor in either (1) float dtype with values in range [0, 1) or
37+
[0, 255], or (2) int type with values in range [0, 255].
38+
offset: A tuple of mean values to be subtracted from the image.
39+
scale: A tuple of normalization factors.
40+
41+
Returns:
42+
A normalized image tensor.
43+
"""
44+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
45+
return normalize_scaled_float_image(image, offset, scale)
46+
47+
48+
def normalize_scaled_float_image(
49+
image: tf.Tensor,
50+
offset: Sequence[float] = MEAN_NORM,
51+
scale: Sequence[float] = STDDEV_NORM,
52+
):
53+
"""Normalizes a scaled float image to zero mean and unit variance.
54+
55+
It assumes the input image is float dtype with values in [0, 1) if offset is
56+
MEAN_NORM, values in [0, 255] if offset is MEAN_RGB.
57+
58+
Args:
59+
image: A tf.Tensor in float32 dtype with values in range [0, 1) or [0, 255].
60+
offset: A tuple of mean values to be subtracted from the image.
61+
scale: A tuple of normalization factors.
62+
63+
Returns:
64+
A normalized image tensor.
65+
"""
66+
offset = tf.constant(offset)
67+
offset = tf.expand_dims(offset, axis=0)
68+
offset = tf.expand_dims(offset, axis=0)
69+
image -= offset
70+
71+
scale = tf.constant(scale)
72+
scale = tf.expand_dims(scale, axis=0)
73+
scale = tf.expand_dims(scale, axis=0)
74+
image /= scale
75+
return image

0 commit comments

Comments
 (0)