
Commit e5872ee

relational authored and copybara-github committed
Import of pull request #58: HiFiC for Camera Ready (author: fab-jul).
- Remove "PRE-RELEASE" warning. - Improve evaluate.py: - Do real arithmetic coding to report bitrate. This includes running the entropy model updates. - Support image folders as input (not just TFDS). - Report PSNR. - Rename LayerNorm to ChannelNorm. - Fix Colab to make it better to use: - Support different HiFiC models with drop down. - Support blindly running whole file. - Support skipping the upload cell. - support re-running with same files but potentially different model. - Fix tf.logging statements not getting printed. - Fix train.py doing back to back training. PiperOrigin-RevId: 342352674 Change-Id: If3a78a5c37cf827aa0ed71a990c178144084b766
1 parent 01ff785 commit e5872ee
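
For the evaluate.py changes above, two of the reported numbers are easy to pin down. Once a real arithmetic coder has produced a bitstring, bits per pixel is just its length in bits over the pixel count, and PSNR is available directly in TensorFlow. A minimal sketch, not the actual evaluate.py code; the function names and inputs are illustrative:

```python
import tensorflow as tf

def real_bpp(bitstring: bytes, height: int, width: int) -> float:
  """Bits per pixel of an actually-encoded bitstring (8 bits per byte)."""
  return 8 * len(bitstring) / (height * width)

def psnr(original, reconstruction):
  """PSNR between two uint8 image tensors of shape [H, W, 3]."""
  return tf.image.psnr(original, reconstruction, max_val=255)
```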

File tree

7 files changed: +236 −110 lines


models/hific/README.md

Lines changed: 0 additions & 3 deletions
@@ -1,7 +1,5 @@
 # High-Fidelity Generative Image Compression
 
-## PRE-RELEASE
-
 <div align="center">
 <a href='https://hific.github.io'>
 <img src='https://hific.github.io/social/thumb.jpg' width="80%"/>
@@ -96,7 +94,6 @@ If you get slow training/stalling, try tweaking the `DATASET_NUM_PARALLEL` and
 The architecture is defined in `arch.py`, which is used to build the model from
 `model.py`. Our configurations are in `configs.py`.
 
-
 We release a _simplified_ trainer in `train.py` as a starting point for custom
 training. Note that it's using
 [coco2014](https://cocodataset.org) from

models/hific/archs.py

Lines changed: 32 additions & 15 deletions
@@ -61,7 +61,7 @@
     "HyperInfo",
     "decoded latent_shape hyper_latent_shape "
     "nbpp side_nbpp total_nbpp qbpp side_qbpp total_qbpp "
-    "bitstring side_bitstring",
+    "bitstream_tensors",
 )
 
 
@@ -86,7 +86,7 @@ def __init__(self,
     model = [
         tf.keras.layers.Conv2D(
             filters=num_filters_base, kernel_size=7, padding="same"),
-        LayerNorm(),
+        ChannelNorm(),
         tf.keras.layers.ReLU()
     ]
 
@@ -95,7 +95,7 @@ def __init__(self,
           tf.keras.layers.Conv2D(
               filters=num_filters_base * 2 ** (i + 1),
               kernel_size=3, padding="same", strides=2),
-          LayerNorm(),
+          ChannelNorm(),
           tf.keras.layers.ReLU()])
 
     model.append(
@@ -127,11 +127,11 @@ def __init__(self,
       num_filters_base: base number of filters.
      num_residual_blocks: number of residual blocks.
    """
-    head = [LayerNorm(),
+    head = [ChannelNorm(),
            tf.keras.layers.Conv2D(
                filters=num_filters_base * (2 ** num_up),
                kernel_size=3, padding="same"),
-            LayerNorm()]
+            ChannelNorm()]
 
     residual_blocks = []
     for block_idx in range(num_residual_blocks):
@@ -151,7 +151,7 @@ def __init__(self,
               filters=filters,
               kernel_size=3, padding="same",
               strides=2),
-          LayerNorm(),
+          ChannelNorm(),
           tf.keras.layers.ReLU()]
 
     # Final conv layer.
@@ -201,19 +201,19 @@ def __init__(
 
     block = [
         tf.keras.layers.Conv2D(**kwargs_conv2d),
-        LayerNorm(),
+        ChannelNorm(),
         tf.keras.layers.Activation(activation),
         tf.keras.layers.Conv2D(**kwargs_conv2d),
-        LayerNorm()]
+        ChannelNorm()]
 
     self.block = tf.keras.Sequential(name=name, layers=block)
 
   def call(self, inputs, **kwargs):
     return inputs + self.block(inputs, **kwargs)
 
 
-class LayerNorm(tf.keras.layers.Layer):
-  """Implement LayerNorm.
+class ChannelNorm(tf.keras.layers.Layer):
+  """Implement ChannelNorm.
 
   Based on this paper and keras' InstanceNorm layer:
   Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
238238
gamma_initializer: Initializer for gamma.
239239
**kwargs: Passed to keras.
240240
"""
241-
super(LayerNorm, self).__init__(**kwargs)
241+
super(ChannelNorm, self).__init__(**kwargs)
242242

243243
self.axis = -1
244244
self.epsilon = epsilon
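
As the hunk above shows (`self.axis = -1`), the renamed layer normalizes over the channel axis only, rather than over all non-batch axes as in standard layer normalization, so ChannelNorm is the more accurate name. A functional sketch of the computation, assuming `gamma` and `beta` are the layer's learned per-channel scale and offset:

```python
import tensorflow as tf

def channel_norm(x, gamma, beta, epsilon=1e-3):
  """Normalizes each spatial position over its channel vector (last axis)."""
  mean, variance = tf.nn.moments(x, axes=[-1], keepdims=True)
  normalized = (x - mean) * tf.math.rsqrt(variance + epsilon)
  return gamma * normalized + beta
```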
@@ -478,6 +478,14 @@ def _make_synthesis(syn_name):
 
     self._side_entropy_model = FactorizedPriorLayer()
 
+  @property
+  def losses(self):
+    return self._side_entropy_model.losses
+
+  @property
+  def updates(self):
+    return self._side_entropy_model.updates
+
   @property
   def transform_layers(self):
     return [self._analysis, self._synthesis_scale, self._synthesis_mean]
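
Forwarding `losses` and `updates` from the side entropy model is what lets callers follow the usual tensorflow_compression pattern: an auxiliary loss fits the range-coding CDF tables to the learned prior, and the update ops keep the quantized tables in sync so that real arithmetic coding (compress/decompress) matches the model. A hedged TF1-style sketch of that pattern; `rd_loss` and the optimizers are stand-ins, not code from this repo:

```python
import tensorflow.compat.v1 as tf
import tensorflow_compression as tfc

entropy_bottleneck = tfc.EntropyBottleneck()
# ... build the model and a rate-distortion loss `rd_loss` (assumed) ...

main_step = tf.train.AdamOptimizer(1e-4).minimize(rd_loss)
# Auxiliary loss: fits the range-coding tables to the learned prior.
aux_step = tf.train.AdamOptimizer(1e-3).minimize(entropy_bottleneck.losses[0])
# Running the update op keeps the quantized CDF tables current, so that
# compress()/decompress() agree with the entropy model at evaluation time.
train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
```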
@@ -529,7 +537,7 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
 
     compressed = None
     if training:
-      latents_decoded = _quantize(latents, latent_means)
+      latents_decoded = _ste_quantize(latents, latent_means)
     elif validation:
       latents_decoded = entropy_info.quantized
     else:
@@ -546,16 +554,25 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
         qbpp=entropy_info.qbpp,
         side_qbpp=side_info.total_qbpp,
         total_qbpp=entropy_info.qbpp + side_info.total_qbpp,
-        bitstring=compressed,
-        side_bitstring=side_info.bitstring)
+        # We put everything that's needed for real arithmetic coding into
+        # the bitstream_tensors tuple.
+        bitstream_tensors=(compressed, side_info.bitstring,
+                           image_shape, latent_shape, side_info.latent_shape))
 
     tf.summary.scalar("bpp/total/noisy", info.total_nbpp)
     tf.summary.scalar("bpp/total/quantized", info.total_qbpp)
 
+    tf.summary.scalar("bpp/latent/noisy", entropy_info.nbpp)
+    tf.summary.scalar("bpp/latent/quantized", entropy_info.qbpp)
+
+    tf.summary.scalar("bpp/side/noisy", side_info.total_nbpp)
+    tf.summary.scalar("bpp/side/quantized", side_info.total_qbpp)
+
     return info
 
 
-def _quantize(inputs, mean):
+def _ste_quantize(inputs, mean):
+  """Calculates quantize(inputs - mean) + mean, sets straight-through grads."""
   half = tf.constant(.5, dtype=tf.float32)
   outputs = inputs
   outputs -= mean
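
The excerpt cuts off mid-function, but the new docstring names the standard straight-through trick: quantize in the forward pass, pass gradients through unchanged in the backward pass. A minimal sketch of that pattern (equivalent up to rounding tie-breaks; not necessarily the exact remaining lines of `_ste_quantize`):

```python
import tensorflow as tf

def ste_quantize(inputs, mean):
  """quantize(inputs - mean) + mean, with straight-through gradients."""
  rounded = tf.round(inputs - mean) + mean
  # Forward pass returns the rounded value; tf.stop_gradient hides the
  # rounding from autodiff, so the backward pass sees the identity.
  return inputs + tf.stop_gradient(rounded - inputs)
```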
