Skip to content

Commit 0646fb1

Browse files
Johannes Ballé authored and copybara-github committed
Fixes example models to match the quantization strategy used in the papers.
PiperOrigin-RevId: 424439812 Change-Id: Iadc25866c4d6a5a769d4c9540b0d2f3123c94f4c
1 parent a7998c5 commit 0646fb1

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

models/bmshj2018.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def compress(self, x):
235235
x_shape = tf.shape(x)[1:-1]
236236
y_shape = tf.shape(y)[1:-1]
237237
z_shape = tf.shape(z)[1:-1]
238-
z_hat, _ = self.side_entropy_model(z, training=False)
238+
z_hat = self.side_entropy_model.quantize(z)
239239
indexes = self.hyper_synthesis_transform(z_hat)
240240
indexes = indexes[:, :y_shape[0], :y_shape[1], :]
241241
side_string = self.side_entropy_model.compress(z)

models/ms2020.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def call(self, x, training):
209209

210210
# Build the entropy model for the hyperprior (z).
211211
em_z = tfc.ContinuousBatchedEntropyModel(
212-
self.hyperprior, coding_rank=3, compression=False)
212+
self.hyperprior, coding_rank=3, compression=False,
213+
offset_heuristic=False)
213214

214215
# When training, z_bpp is based on the noisy version of z (z_tilde).
215216
_, z_bits = em_z(z, training=training)
@@ -255,7 +256,7 @@ def call(self, x, training):
255256

256257
# For the synthesis transform, use rounding. Note that quantize()
257258
# overrides the gradient to create a straight-through estimator.
258-
y_hat_slice = em_y.quantize(y_slice, sigma, loc=mu)
259+
y_hat_slice = em_y.quantize(y_slice, loc=mu)
259260

260261
# Add latent residual prediction (LRP).
261262
lrp_support = tf.concat([mean_support, y_hat_slice], axis=-1)
@@ -318,7 +319,8 @@ def fit(self, *args, **kwargs):
318319
retval = super().fit(*args, **kwargs)
319320
# After training, fix range coding tables.
320321
self.em_z = tfc.ContinuousBatchedEntropyModel(
321-
self.hyperprior, coding_rank=3, compression=True)
322+
self.hyperprior, coding_rank=3, compression=True,
323+
offset_heuristic=False)
322324
self.em_y = tfc.LocationScaleIndexedEntropyModel(
323325
tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
324326
coding_rank=3, compression=True)

0 commit comments

Comments (0)