
Commit 52af1e0

gtoderici authored and copybara-github committed
Various fixes to HiFiC.
PiperOrigin-RevId: 318389553
Change-Id: I422136163a9211cd5cb6f463c5568ab44fb42e46
1 parent 8054e6a · commit 52af1e0

File tree: 8 files changed (+308, -108 lines)

models/hific/README.md

Lines changed: 86 additions & 24 deletions
@@ -1,5 +1,7 @@
 # High-Fidelity Generative Image Compression
 
+## PRE-RELEASE
+
 <div align="center">
 <a href='https://hific.github.io'>
 <img src='https://hific.github.io/social/thumb.jpg' width="80%"/>
@@ -30,71 +32,131 @@ use more than 2&times; the bitrate.
 We show some images on the [demo page](https://hific.github.io) and we
 release a
 [colab](https://colab.research.google.com/github/tensorflow/compression/blob/master/models/hific/colab.ipynb)
-update for interactively using our models on your own images.
+for interactively using our models on your own images.
+
+## Running models trained by us locally
+
+Use `tfci.py` to run our models locally for encoding and decoding images:
+
+```bash
+git clone https://github.com/tensorflow/compression
+cd compression/models
+python tfci.py compress <model> <PNG file>
+```
+
+where `model` can be one of `"hific-lo", "hific-mi", "hific-hi"`.
+
+**NOTE**: This is also directly available in the
+[colab](https://colab.research.google.com/github/tensorflow/compression/blob/master/models/hific/colab.ipynb)!
 
 ## Using the code
 
-In addition to `tensorflow_compression`, you need to install [`compare_gan`](https://github.com/google/compare_gan)
-and TensorFlow 1.15:
+To use the code, create a conda environment with Python 3.6
+(newer versions are not supported at the moment) and the following packages.
+
+**NOTE**: We only support CUDA 10.0, Python 3.6, and TensorFlow 1.15.
+TensorFlow must be installed via pip, not conda.
+No other setup works (we tested newer versions of TensorFlow
+and Python, and they don't work). We're working on a fix.
 
 ```bash
-pip install -r requirements.txt
+conda create --name hific python=3.6 cudatoolkit=10.0 cudnn
+conda activate hific
+pip install tensorflow-gpu==1.15  # Make sure to install TF via pip, not conda!
+pip install git+git://github.com/google/compare_gan@19922d3004b675c1a49c4d7515c06f6f75acdcc8
+pip install tensorflow-compression==1.3
+pip install Pillow
 ```
 
-## Running our models locally
+#### Note on cuDNN Errors
 
-Use `tfci.py` for locally running our models to encode and decode images:
+On some of our test machines, the code crashes with one of the following
+errors: "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR",
+"terminate called after throwing an instance of 'std::bad_alloc'",
+"Segmentation fault", or "Unknown: Failed to get convolution algorithm.
+This is probably because cuDNN".
 
-```python
-python tfci.py compress <model> <PNG file>
+In this case, try setting `TF_FORCE_GPU_ALLOW_GROWTH=true`, e.g.:
+```bash
+TF_FORCE_GPU_ALLOW_GROWTH=true python train.py ...
 ```
 
-where `model` can be one of `"hific-lo", "hific-mi", "hific-hi"`.
+#### Note on Memory Consumption
 
-## Code
+This model trains best on a V100. If you get out-of-memory errors
+("Resource exhausted: OOM"), try lowering the batch size
+(e.g., `--batch_size 6`), or tweak `num_residual_blocks` in `archs.py/Decoder`.
+
+If training is slow or stalls, try tweaking the `DATASET_NUM_PARALLEL` and
+`DATASET_PREFETCH` constants in `model.py`.
 
-The architecture is defined in `arch.py` , which is used to build the model in
-`model.py`. Our configurations are in `configs.py`.
 
 ### Training your own models.
 
+The architecture is defined in `archs.py`, which is used to build the model in
+`model.py`. Our configurations are in `configs.py`.
+
 We release a _simplified_ trainer in `train.py` as a starting point for custom
-training. Note that it's using [LSUN]() from [tfds]() which likely needs to be
-adapted to a bigger dataset to obtain state-of-the-art results (see below).
+training. Note that it uses
+[coco2014](https://cocodataset.org) from
+[tfds](https://www.tensorflow.org/datasets/api_docs/python/tfds), which likely
+needs to be swapped for a bigger dataset to obtain good results
+(see below).
 
 For the paper, we initialize our GAN models from a MSE+LPIPS checkpoint. To
 replicate this, first train a model for MSE + LPIPS only, and then use that as a
 starting point:
-
 ```bash
 # First train a model for MSE+LPIPS:
-python train.py --config mselpips --ckpt_dir ckpts --num_steps 1M
+python train.py --config mselpips --ckpt_dir ckpts/mse_lpips --num_steps 1M \
+    --tfds_dataset_name coco2014
 
 # Once that finishes, train a GAN model:
-python train.py --config hific --ckpt_dir ckpts \
-    --init_from ckpts/mselpips --num_steps 1M
+python train.py --config hific --ckpt_dir ckpts/hific \
+    --init_autoencoder_from_ckpt_dir ckpts/mse_lpips --num_steps 1M \
+    --tfds_dataset_name coco2014
```
+Additional helpful arguments are `--tfds_dataset_name`
+and `--tfds_downloads_dir`; see `--help` for more.
 
-To test a trained model, use `eval.py`:
+Note that TensorBoard summaries will be saved in `--ckpt_dir` as well. By
+default, we create summaries of inputs and reconstructions, which can use a
+lot of memory. Disable them with `--no-image-summaries`.
+
+To test a trained model, use `evaluate.py` (it also supports the `--tfds_*`
+flags):
 
 ```bash
-python eval.py --config hific --ckpt_dir ckpts/hific
+python evaluate.py --config hific --ckpt_dir ckpts/hific --out_dir out/ \
+    --tfds_dataset_name coco2014
 ```
 
 #### Adapting the dataset
 
-You can change to any other TFDS dataset by changing the `tfds_name` flag for
-`build_input`. To train on a custom dataset, you can replace the `_get_dataset`
+You can switch to any other TFDS dataset by adapting the `--tfds_dataset_name`,
+`--tfds_features_key`, and `--tfds_downloads_dir` flags of `train.py`.
+
+Note that when using TFDS, the dataset first has to be downloaded, which can
+take time. To do this separately, use the following code snippet:
+```python
+import tensorflow_datasets as tfds
+builder = tfds.builder(TFDS_DATASET_NAME, data_dir=TFDS_DOWNLOADS_DIR)
+builder.download_and_prepare()
+```
+
+To train on a custom dataset, you can replace the `_get_dataset`
 call in `train.py`.
 
 ## Citation
 
 If you use the work released here for your research, please cite this paper:
 
 ```
-@inproceedings{mentzer2020hific,
+@article{mentzer2020high,
   title={High-Fidelity Generative Image Compression},
-  author={Fabian Mentzer and George Toderici and Michael Tschannen and Eirikur Agustsson},
+  author={Mentzer, Fabian and Toderici, George and Tschannen, Michael and Agustsson, Eirikur},
+  journal={arXiv preprint arXiv:2006.09965},
   year={2020}
 }
 ```
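
The cuDNN note in this README sets `TF_FORCE_GPU_ALLOW_GROWTH` in the shell. For reference, a minimal sketch of the equivalent session-level setting in TF 1.15 (the `sess` name is illustrative and not part of this commit):

```python
import tensorflow.compat.v1 as tf

# Ask TF 1.x to grow GPU memory on demand instead of grabbing it all upfront,
# mirroring TF_FORCE_GPU_ALLOW_GROWTH=true from the README note above.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
```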

models/hific/archs.py

Lines changed: 20 additions & 19 deletions
@@ -30,7 +30,6 @@
 
 from .helpers import ModelMode
 
-
 SCALES_MIN = 0.11
 SCALES_MAX = 256
 SCALES_LEVELS = 64
@@ -327,7 +326,13 @@ def __init__(self,
     self._num_layers = num_layers
     self._num_filters_base = num_filters_base
 
+  def __call__(self, x):
+    """Overrides compare_gan's __call__, as we only need `x`."""
+    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
+      return self.apply(x)
+
   def apply(self, x):
+    """Overrides compare_gan's apply, as we only need `x`."""
     if not isinstance(x, tuple) or len(x) != 2:
       raise ValueError("Expected 2-tuple, got {}".format(x))
     x, latent = x
@@ -398,18 +403,8 @@ def __init__(self,
     self._name = name
     self._model = compare_gan_cls(**compare_gan_kwargs)
 
-  def call(self, x, training):
-    # compare_gan code distinguishes between training and evaluation using
-    # two entry points: __call__ for training, inference for evaluation.
-    # Depending on the mode, different norm modes are set.
-    # We switch depending on the training flag.
-    if training:
-      return self._model(x)
-    else:
-      return self._model.inference(x)
-
-  def apply_directly(self, x):
-    return self._model.apply(x)
+  def call(self, x):
+    return self._model(x)
 
   @property
   def trainable_variables(self):
@@ -418,7 +413,7 @@ def trainable_variables(self):
     # don't have training as a flag to the constructor, so we always return.
     # However, we only call trainable_variables when we are training.
     return tf.get_collection(
-        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._name)
+        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._model.name)
 
 
 class Discriminator(_CompareGANLayer):
@@ -433,13 +428,11 @@ class Hyperprior(tf.keras.layers.Layer):
   """Hyperprior architecture (probability model)."""
 
   def __init__(self,
-               round_latents_for_training=True,
                num_chan_bottleneck=220,
                num_filters=320,
                name="Hyperprior"):
     super(Hyperprior, self).__init__(name=name)
 
-    self._round_latents_for_training = round_latents_for_training
     self._num_chan_bottleneck = num_chan_bottleneck
     self._num_filters = num_filters
     self._analysis = tf.keras.Sequential([
@@ -537,9 +530,7 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
 
     compressed = None
     if training:
-      latents_decoded = (entropy_info.quantized
-                         if self._round_latents_for_training else
-                         entropy_info.noisy)
+      latents_decoded = _quantize(latents, latent_means)
     elif validation:
       latents_decoded = entropy_info.quantized
     else:
@@ -565,6 +556,16 @@ def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
     return info
 
 
+def _quantize(inputs, mean):
+  half = tf.constant(.5, dtype=tf.float32)
+  outputs = inputs
+  outputs -= mean
+  # Rounding latents for the forward pass (straight-through).
+  outputs = outputs + tf.stop_gradient(tf.math.floor(outputs + half) - outputs)
+  outputs += mean
+  return outputs
+
+
 class FactorizedPriorLayer(tf.keras.layers.Layer):
   """Factorized prior to code a discrete tensor."""
 
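
The new `_quantize` helper is a straight-through quantizer: the forward pass rounds the mean-centered latents half-up via `floor(x + 0.5)`, while `tf.stop_gradient` makes the backward pass treat the rounding as the identity. A self-contained sketch of the behavior (TF 1.x graph mode assumed; the tensors are illustrative):

```python
import tensorflow.compat.v1 as tf

def quantize_st(inputs, mean):
  # Forward: round-half-up of (inputs - mean), then re-add the mean.
  # Backward: identity, since the rounding is wrapped in tf.stop_gradient.
  centered = inputs - mean
  rounded = centered + tf.stop_gradient(tf.floor(centered + 0.5) - centered)
  return rounded + mean

x = tf.constant([1.2, 3.7])
y = quantize_st(x, tf.constant(0.))
grad = tf.gradients(y, x)[0]
with tf.Session() as sess:
  print(sess.run([y, grad]))  # [array([1., 4.]), array([1., 1.])]
```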

models/hific/evaluate.py

Lines changed: 30 additions & 10 deletions
@@ -20,8 +20,9 @@
 
 import argparse
 import itertools
-
-from absl import app
+import os
+import sys
+from PIL import Image
 
 import tensorflow.compat.v1 as tf
 
@@ -30,27 +31,41 @@
 from . import model
 
 
-def eval_trained_model(config_name, ckpt_dir, max_images=None):
+def eval_trained_model(config_name,
+                       ckpt_dir,
+                       out_dir,
+                       tfds_arguments: helpers.TFDSArguments,
+                       max_images=None):
   """Evaluate a trained model."""
   config = configs.get_config(config_name)
   hific = model.HiFiC(config, helpers.ModelMode.EVALUATION)
 
-  # Automatically uses the validation split of LSUN.
-  dataset = hific.build_input(batch_size=1, crop_size=None, tfds_name='lsun')
+  # Automatically uses the validation split.
+  dataset = hific.build_input(
+      batch_size=1, crop_size=None, tfds_arguments=tfds_arguments)
   iterator = tf.data.make_one_shot_iterator(dataset)
   get_next_image = iterator.get_next()
 
   output_image, bpp = hific.build_model(**get_next_image)
   input_image = get_next_image['input_image']
+
+  input_image = tf.cast(tf.round(input_image[0, ...]), tf.uint8)
+  output_image = tf.cast(tf.round(output_image[0, ...]), tf.uint8)
+
+  os.makedirs(out_dir, exist_ok=True)
+
   with tf.Session() as sess:
     hific.restore_trained_model(sess, ckpt_dir)
     for i in itertools.count(0):
       if max_images and i == max_images:
        break
       try:
-        input_, output_, bpp_ = sess.run([input_image, output_image, bpp])
-        # TODO(fab-jul): Save image, report bpp, etc.
-        print(input_.shape, output_.shape, bpp_)
+        inp_np, otp_np, bpp_np = sess.run([input_image, output_image, bpp])
+        print(f'Image {i}: {bpp_np:.3} bpp, saving in {out_dir}...')
+        Image.fromarray(inp_np).save(
+            os.path.join(out_dir, f'img_{i:010d}inp.png'))
+        Image.fromarray(otp_np).save(
+            os.path.join(out_dir, f'img_{i:010d}otp_{bpp_np:.3f}.png'))
       except tf.errors.OutOfRangeError:
         print('No more inputs')
         break
@@ -67,13 +82,18 @@ def parse_args(argv):
   parser.add_argument('--ckpt_dir', required=True,
                       help=('Path to the folder where checkpoints of the '
                             'trained model are.'))
+  parser.add_argument('--out_dir', required=True, help='Where to save outputs.')
+
+  helpers.add_tfds_arguments(parser)
+
   args = parser.parse_args(argv[1:])
   return args
 
 
 def main(args):
-  eval_trained_model(args.config, args.ckpt_dir)
+  eval_trained_model(args.config, args.ckpt_dir, args.out_dir,
+                     helpers.parse_tfds_arguments(args))
 
 
 if __name__ == '__main__':
-  app.run(main, flags_parser=parse_args)
+  main(parse_args(sys.argv))
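
Given the `…inp.png` / `…otp_….png` pairs that `eval_trained_model` now writes, a quick fidelity check takes a few lines of numpy. A minimal sketch (the file names follow the pattern above; the `0.123` bpp suffix is a placeholder that differs per image):

```python
import numpy as np
from PIL import Image

# Load one input/reconstruction pair saved by evaluate.py (names illustrative).
inp = np.asarray(Image.open('out/img_0000000000inp.png'), dtype=np.float64)
otp = np.asarray(Image.open('out/img_0000000000otp_0.123.png'), dtype=np.float64)

mse = ((inp - otp) ** 2).mean()
psnr = 10 * np.log10(255.0 ** 2 / mse)  # peak value 255 for uint8 images
print(f'PSNR: {psnr:.2f} dB')
```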

models/hific/helpers.py

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,7 @@
 """Some helper enums and classes, as well as LPIPS downloader."""
 
 
+import collections
 import enum
 import os
 import pprint
@@ -24,6 +25,9 @@
 
 _LPIPS_URL = "http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb"
 
+TFDSArguments = collections.namedtuple(
+    "TFDSArguments", ["dataset_name", "features_key", "downloads_dir"])
+
 
 class ModelType(enum.Enum):
   # Train hyperprior model: encoder/decoder/entropy model.
@@ -66,3 +70,22 @@ def ensure_lpips_weights_exist(weight_path_out):
   if not os.path.isfile(weight_path_out):
     raise ValueError(f"Failed to download LPIPS weights from {_LPIPS_URL} "
                      f"to {weight_path_out}. Please manually download!")
+
+
+def add_tfds_arguments(parser):
+  parser.add_argument(
+      "--tfds_dataset_name", default="coco2014", help="TFDS dataset name.")
+  parser.add_argument(
+      "--tfds_downloads_dir",
+      default=None,
+      help=("Where TFDS stores data. "
+            "Defaults to ~/tensorflow_datasets."))
+  parser.add_argument(
+      "--tfds_features_key",
+      default="image",
+      help="Name of the TFDS feature to use.")
+
+
+def parse_tfds_arguments(args) -> TFDSArguments:
+  return TFDSArguments(args.tfds_dataset_name, args.tfds_features_key,
+                       args.tfds_downloads_dir)
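
`add_tfds_arguments` and `parse_tfds_arguments` are the two halves of a round trip shared by `train.py` and `evaluate.py`: the first registers the three `--tfds_*` flags on a parser, the second bundles the parsed values into the `TFDSArguments` namedtuple. A minimal sketch of that round trip (plain `import helpers` shown for a standalone run; inside the package it is `from . import helpers`):

```python
import argparse

import helpers  # within the hific package: from . import helpers

parser = argparse.ArgumentParser()
helpers.add_tfds_arguments(parser)
args = parser.parse_args(['--tfds_dataset_name', 'coco2014'])

# features_key and downloads_dir fall back to their defaults ('image', None).
tfds_args = helpers.parse_tfds_arguments(args)
print(tfds_args)
# TFDSArguments(dataset_name='coco2014', features_key='image', downloads_dir=None)
```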
