Skip to content

Commit 9246c4e

Browse files
Johannes Ball?copybara-github
authored andcommitted
Adds example training code for BMSHJ2018 model.
Also adds some minor tweaks to other example scripts. PiperOrigin-RevId: 259822670 Change-Id: Iaf726302909190673eba2c174fa7c4989f7f5d3e
1 parent 47cff11 commit 9246c4e

File tree

4 files changed

+596
-6
lines changed

4 files changed

+596
-6
lines changed

BUILD

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ py_binary(
3939
],
4040
)
4141

42+
py_binary(
43+
name = "generate_docs",
44+
srcs = ["tools/generate_docs.py"],
45+
deps = [":tensorflow_compression"],
46+
)
47+
4248
py_binary(
4349
name = "tfci",
4450
srcs = ["examples/tfci.py"],
@@ -52,7 +58,7 @@ py_binary(
5258
)
5359

5460
py_binary(
55-
name = "generate_docs",
56-
srcs = ["tools/generate_docs.py"],
61+
name = "bmshj2018",
62+
srcs = ["examples/bmshj2018.py"],
5763
deps = [":tensorflow_compression"],
5864
)

examples/bls2017.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
https://arxiv.org/abs/1611.01704
2323
2424
With patches from Victor Xing <[email protected]>
25+
26+
This is meant as 'educational' code - you can use this to get started with your
27+
own experiments. To reproduce the exact results from the paper, tuning of hyper-
28+
parameters may be necessary. To compress images with published models, see
29+
`tfci.py`.
2530
"""
2631

2732
from __future__ import absolute_import
@@ -176,11 +181,11 @@ def build(self, input_shape):
176181
tfc.SignalConv2D(
177182
self.num_filters, (5, 5), name="layer_0", corr=False, strides_up=2,
178183
padding="same_zeros", use_bias=True,
179-
activation=tfc.GDN(name="gdn_0", inverse=True)),
184+
activation=tfc.GDN(name="igdn_0", inverse=True)),
180185
tfc.SignalConv2D(
181186
self.num_filters, (5, 5), name="layer_1", corr=False, strides_up=2,
182187
padding="same_zeros", use_bias=True,
183-
activation=tfc.GDN(name="gdn_1", inverse=True)),
188+
activation=tfc.GDN(name="igdn_1", inverse=True)),
184189
tfc.SignalConv2D(
185190
3, (9, 9), name="layer_2", corr=False, strides_up=4,
186191
padding="same_zeros", use_bias=True,
@@ -290,7 +295,7 @@ def compress(args):
290295
y_hat, likelihoods = entropy_bottleneck(y, training=False)
291296
x_hat = synthesis_transform(y_hat)
292297

293-
num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1]))
298+
num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)
294299

295300
# Total number of bits divided by number of pixels.
296301
eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)

0 commit comments

Comments
 (0)