61
61
"HyperInfo" ,
62
62
"decoded latent_shape hyper_latent_shape "
63
63
"nbpp side_nbpp total_nbpp qbpp side_qbpp total_qbpp "
64
- "bitstring side_bitstring " ,
64
+ "bitstream_tensors " ,
65
65
)
66
66
67
67
@@ -86,7 +86,7 @@ def __init__(self,
86
86
model = [
87
87
tf .keras .layers .Conv2D (
88
88
filters = num_filters_base , kernel_size = 7 , padding = "same" ),
89
- LayerNorm (),
89
+ ChannelNorm (),
90
90
tf .keras .layers .ReLU ()
91
91
]
92
92
@@ -95,7 +95,7 @@ def __init__(self,
95
95
tf .keras .layers .Conv2D (
96
96
filters = num_filters_base * 2 ** (i + 1 ),
97
97
kernel_size = 3 , padding = "same" , strides = 2 ),
98
- LayerNorm (),
98
+ ChannelNorm (),
99
99
tf .keras .layers .ReLU ()])
100
100
101
101
model .append (
@@ -127,11 +127,11 @@ def __init__(self,
127
127
num_filters_base: base number of filters.
128
128
num_residual_blocks: number of residual blocks.
129
129
"""
130
- head = [LayerNorm (),
130
+ head = [ChannelNorm (),
131
131
tf .keras .layers .Conv2D (
132
132
filters = num_filters_base * (2 ** num_up ),
133
133
kernel_size = 3 , padding = "same" ),
134
- LayerNorm ()]
134
+ ChannelNorm ()]
135
135
136
136
residual_blocks = []
137
137
for block_idx in range (num_residual_blocks ):
@@ -151,7 +151,7 @@ def __init__(self,
151
151
filters = filters ,
152
152
kernel_size = 3 , padding = "same" ,
153
153
strides = 2 ),
154
- LayerNorm (),
154
+ ChannelNorm (),
155
155
tf .keras .layers .ReLU ()]
156
156
157
157
# Final conv layer.
@@ -201,19 +201,19 @@ def __init__(
201
201
202
202
block = [
203
203
tf .keras .layers .Conv2D (** kwargs_conv2d ),
204
- LayerNorm (),
204
+ ChannelNorm (),
205
205
tf .keras .layers .Activation (activation ),
206
206
tf .keras .layers .Conv2D (** kwargs_conv2d ),
207
- LayerNorm ()]
207
+ ChannelNorm ()]
208
208
209
209
self .block = tf .keras .Sequential (name = name , layers = block )
210
210
211
211
def call (self , inputs , ** kwargs ):
212
212
return inputs + self .block (inputs , ** kwargs )
213
213
214
214
215
- class LayerNorm (tf .keras .layers .Layer ):
216
- """Implement LayerNorm .
215
+ class ChannelNorm (tf .keras .layers .Layer ):
216
+ """Implement ChannelNorm .
217
217
218
218
Based on this paper and keras' InstanceNorm layer:
219
219
Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
@@ -238,7 +238,7 @@ def __init__(self,
238
238
gamma_initializer: Initializer for gamma.
239
239
**kwargs: Passed to keras.
240
240
"""
241
- super (LayerNorm , self ).__init__ (** kwargs )
241
+ super (ChannelNorm , self ).__init__ (** kwargs )
242
242
243
243
self .axis = - 1
244
244
self .epsilon = epsilon
@@ -478,6 +478,14 @@ def _make_synthesis(syn_name):
478
478
479
479
self ._side_entropy_model = FactorizedPriorLayer ()
480
480
481
+ @property
482
+ def losses (self ):
483
+ return self ._side_entropy_model .losses
484
+
485
+ @property
486
+ def updates (self ):
487
+ return self ._side_entropy_model .updates
488
+
481
489
@property
482
490
def transform_layers (self ):
483
491
return [self ._analysis , self ._synthesis_scale , self ._synthesis_mean ]
@@ -529,7 +537,7 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
529
537
530
538
compressed = None
531
539
if training :
532
- latents_decoded = _quantize (latents , latent_means )
540
+ latents_decoded = _ste_quantize (latents , latent_means )
533
541
elif validation :
534
542
latents_decoded = entropy_info .quantized
535
543
else :
@@ -546,16 +554,25 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
546
554
qbpp = entropy_info .qbpp ,
547
555
side_qbpp = side_info .total_qbpp ,
548
556
total_qbpp = entropy_info .qbpp + side_info .total_qbpp ,
549
- bitstring = compressed ,
550
- side_bitstring = side_info .bitstring )
557
+ # We put everything that's needed for real arithmetic coding into
558
+ # the bistream_tensors tuple.
559
+ bitstream_tensors = (compressed , side_info .bitstring ,
560
+ image_shape , latent_shape , side_info .latent_shape ))
551
561
552
562
tf .summary .scalar ("bpp/total/noisy" , info .total_nbpp )
553
563
tf .summary .scalar ("bpp/total/quantized" , info .total_qbpp )
554
564
565
+ tf .summary .scalar ("bpp/latent/noisy" , entropy_info .nbpp )
566
+ tf .summary .scalar ("bpp/latent/quantized" , entropy_info .qbpp )
567
+
568
+ tf .summary .scalar ("bpp/side/noisy" , side_info .total_nbpp )
569
+ tf .summary .scalar ("bpp/side/quantized" , side_info .total_qbpp )
570
+
555
571
return info
556
572
557
573
558
- def _quantize (inputs , mean ):
574
+ def _ste_quantize (inputs , mean ):
575
+ """Calculates quantize(inputs - mean) + mean, sets straight-through grads."""
559
576
half = tf .constant (.5 , dtype = tf .float32 )
560
577
outputs = inputs
561
578
outputs -= mean
0 commit comments