@@ -209,7 +209,8 @@ def call(self, x, training):
209
209
210
210
# Build the entropy model for the hyperprior (z).
211
211
em_z = tfc .ContinuousBatchedEntropyModel (
212
- self .hyperprior , coding_rank = 3 , compression = False )
212
+ self .hyperprior , coding_rank = 3 , compression = False ,
213
+ offset_heuristic = False )
213
214
214
215
# When training, z_bpp is based on the noisy version of z (z_tilde).
215
216
_ , z_bits = em_z (z , training = training )
@@ -255,7 +256,7 @@ def call(self, x, training):
255
256
256
257
# For the synthesis transform, use rounding. Note that quantize()
257
258
# overrides the gradient to create a straight-through estimator.
258
- y_hat_slice = em_y .quantize (y_slice , sigma , loc = mu )
259
+ y_hat_slice = em_y .quantize (y_slice , loc = mu )
259
260
260
261
# Add latent residual prediction (LRP).
261
262
lrp_support = tf .concat ([mean_support , y_hat_slice ], axis = - 1 )
@@ -318,7 +319,8 @@ def fit(self, *args, **kwargs):
318
319
retval = super ().fit (* args , ** kwargs )
319
320
# After training, fix range coding tables.
320
321
self .em_z = tfc .ContinuousBatchedEntropyModel (
321
- self .hyperprior , coding_rank = 3 , compression = True )
322
+ self .hyperprior , coding_rank = 3 , compression = True ,
323
+ offset_heuristic = False )
322
324
self .em_y = tfc .LocationScaleIndexedEntropyModel (
323
325
tfc .NoisyNormal , num_scales = self .num_scales , scale_fn = self .scale_fn ,
324
326
coding_rank = 3 , compression = True )
0 commit comments