Skip to content

Commit 963aa2d

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds mixed precision support to example models.
- Also fixes a missing cast in continuous_base.py and the corresponding unit test.
PiperOrigin-RevId: 429133533
Change-Id: If1f70b684b0d63540a4f3100e1a1074195ec4527
1 parent e4263cd commit 963aa2d

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

models/bls2017.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ def call(self, x, training):
107107
"""Computes rate and distortion losses."""
108108
entropy_model = tfc.ContinuousBatchedEntropyModel(
109109
self.prior, coding_rank=3, compression=False)
110+
x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary?
110111
y = self.analysis_transform(x)
111112
y_hat, bits = entropy_model(y, training=training)
112113
x_hat = self.synthesis_transform(y_hat)
@@ -115,6 +116,7 @@ def call(self, x, training):
115116
bpp = tf.reduce_sum(bits) / num_pixels
116117
# Mean squared error across pixels.
117118
mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
119+
mse = tf.cast(mse, bpp.dtype)
118120
# The rate-distortion Lagrangian.
119121
loss = bpp + self.lmbda * mse
120122
return loss, bpp, mse
@@ -166,7 +168,7 @@ def compress(self, x):
166168
"""Compresses an image."""
167169
# Add batch dimension and cast to float.
168170
x = tf.expand_dims(x, 0)
169-
x = tf.cast(x, dtype=tf.float32)
171+
x = tf.cast(x, dtype=self.compute_dtype)
170172
y = self.analysis_transform(x)
171173
# Preserve spatial shapes of both image and latents.
172174
x_shape = tf.shape(x)[1:-1]
@@ -195,7 +197,7 @@ def check_image_size(image, patchsize):
195197

196198
def crop_image(image, patchsize):
197199
image = tf.image.random_crop(image, (patchsize, patchsize, 3))
198-
return tf.cast(image, tf.float32)
200+
return tf.cast(image, tf.keras.mixed_precision.global_policy().compute_dtype)
199201

200202

201203
def get_dataset(name, split, args):
@@ -232,6 +234,8 @@ def get_custom_dataset(split, args):
232234

233235
def train(args):
234236
"""Instantiates and trains the model."""
237+
if args.precision_policy:
238+
tf.keras.mixed_precision.set_global_policy(args.precision_policy)
235239
if args.check_numerics:
236240
tf.debugging.enable_check_numerics()
237241

@@ -391,6 +395,9 @@ def parse_args(argv):
391395
"--preprocess_threads", type=int, default=16,
392396
help="Number of CPU threads to use for parallel decoding of training "
393397
"images.")
398+
train_cmd.add_argument(
399+
"--precision_policy", type=str, default=None,
400+
help="Policy for `tf.keras.mixed_precision` training.")
394401
train_cmd.add_argument(
395402
"--check_numerics", action="store_true",
396403
help="Enable TF support for catching NaN and Inf in tensors.")

models/bmshj2018.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -162,6 +162,7 @@ def call(self, x, training):
162162
side_entropy_model = tfc.ContinuousBatchedEntropyModel(
163163
self.hyperprior, coding_rank=3, compression=False)
164164

165+
x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary?
165166
y = self.analysis_transform(x)
166167
z = self.hyper_analysis_transform(abs(y))
167168
z_hat, side_bits = side_entropy_model(z, training=training)
@@ -174,6 +175,7 @@ def call(self, x, training):
174175
bpp = (tf.reduce_sum(bits) + tf.reduce_sum(side_bits)) / num_pixels
175176
# Mean squared error across pixels.
176177
mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
178+
mse = tf.cast(mse, bpp.dtype)
177179
# The rate-distortion Lagrangian.
178180
loss = bpp + self.lmbda * mse
179181
return loss, bpp, mse
@@ -228,7 +230,7 @@ def compress(self, x):
228230
"""Compresses an image."""
229231
# Add batch dimension and cast to float.
230232
x = tf.expand_dims(x, 0)
231-
x = tf.cast(x, dtype=tf.float32)
233+
x = tf.cast(x, dtype=self.compute_dtype)
232234
y = self.analysis_transform(x)
233235
z = self.hyper_analysis_transform(abs(y))
234236
# Preserve spatial shapes of image and latents.
@@ -269,7 +271,7 @@ def check_image_size(image, patchsize):
269271

270272
def crop_image(image, patchsize):
271273
image = tf.image.random_crop(image, (patchsize, patchsize, 3))
272-
return tf.cast(image, tf.float32)
274+
return tf.cast(image, tf.keras.mixed_precision.global_policy().compute_dtype)
273275

274276

275277
def get_dataset(name, split, args):
@@ -306,6 +308,8 @@ def get_custom_dataset(split, args):
306308

307309
def train(args):
308310
"""Instantiates and trains the model."""
311+
if args.precision_policy:
312+
tf.keras.mixed_precision.set_global_policy(args.precision_policy)
309313
if args.check_numerics:
310314
tf.debugging.enable_check_numerics()
311315

@@ -476,6 +480,9 @@ def parse_args(argv):
476480
"--preprocess_threads", type=int, default=16,
477481
help="Number of CPU threads to use for parallel decoding of training "
478482
"images.")
483+
train_cmd.add_argument(
484+
"--precision_policy", type=str, default=None,
485+
help="Policy for `tf.keras.mixed_precision` training.")
479486
train_cmd.add_argument(
480487
"--check_numerics", action="store_true",
481488
help="Enable TF support for catching NaN and Inf in tensors.")

models/ms2020.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ def __init__(self, lmbda,
199199

200200
def call(self, x, training):
201201
"""Computes rate and distortion losses."""
202+
x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary?
202203
# Build the encoder (analysis) half of the hierarchical autoencoder.
203204
y = self.analysis_transform(x)
204205
y_shape = tf.shape(y)[1:-1]
@@ -276,6 +277,7 @@ def call(self, x, training):
276277
# Mean squared error across pixels.
277278
# Don't clip or round pixel values while training.
278279
mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
280+
mse = tf.cast(mse, total_bpp.dtype)
279281

280282
# Calculate and return the rate-distortion loss: R + lambda * D.
281283
loss = total_bpp + self.lmbda * mse
@@ -333,7 +335,7 @@ def compress(self, x):
333335
"""Compresses an image."""
334336
# Add batch dimension and cast to float.
335337
x = tf.expand_dims(x, 0)
336-
x = tf.cast(x, dtype=tf.float32)
338+
x = tf.cast(x, dtype=self.compute_dtype)
337339

338340
y_strings = []
339341
x_shape = tf.shape(x)[1:-1]
@@ -439,7 +441,7 @@ def check_image_size(image, patchsize):
439441

440442
def crop_image(image, patchsize):
441443
image = tf.image.random_crop(image, (patchsize, patchsize, 3))
442-
return tf.cast(image, tf.float32)
444+
return tf.cast(image, tf.keras.mixed_precision.global_policy().compute_dtype)
443445

444446

445447
def get_dataset(name, split, args):
@@ -476,6 +478,8 @@ def get_custom_dataset(split, args):
476478

477479
def train(args):
478480
"""Instantiates and trains the model."""
481+
if args.precision_policy:
482+
tf.keras.mixed_precision.set_global_policy(args.precision_policy)
479483
if args.check_numerics:
480484
tf.debugging.enable_check_numerics()
481485

@@ -661,6 +665,9 @@ def parse_args(argv):
661665
"--preprocess_threads", type=int, default=16,
662666
help="Number of CPU threads to use for parallel decoding of training "
663667
"images.")
668+
train_cmd.add_argument(
669+
"--precision_policy", type=str, default=None,
670+
help="Policy for `tf.keras.mixed_precision` training.")
664671
train_cmd.add_argument(
665672
"--check_numerics", action="store_true",
666673
help="Enable TF support for catching NaN and Inf in tensors.")

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -233,8 +233,7 @@ def _build_tables(self, prior, precision, offset=None):
233233
CDF table, CDF offsets, CDF lengths.
234234
"""
235235
precision = int(precision)
236-
if offset is None:
237-
offset = 0.
236+
offset = tf.cast(0 if offset is None else offset, prior.dtype)
238237
# Subclasses should have already caught this, but better be safe.
239238
assert not prior.event_shape.rank
240239

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
198198
tf.keras.mixed_precision.set_global_policy("mixed_float16")
199199
try:
200200
noisy = uniform_noise.NoisyNormal(
201-
loc=tf.constant(0, dtype=tf.float64),
201+
loc=tf.constant(.5, dtype=tf.float64),
202202
scale=tf.constant(1, dtype=tf.float64))
203203
em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
204204
self.assertEqual(em.bottleneck_dtype, tf.float16)

0 commit comments

Comments
 (0)