
Commit 312df2b

Johannes Ballé authored and copybara-github committed
Adds toy models to TFC.

This code can be used to reproduce the results on toy distributions from the following papers:

Ballé et al., "Nonlinear Transform Coding," https://arxiv.org/abs/2007.03034
Wagner & Ballé, "Neural Networks Optimally Compress the Sawbridge," https://arxiv.org/abs/2011.05065

PiperOrigin-RevId: 349201435
Change-Id: Ie001876208333d5063f2c1da7c0ab08a6c11b491
1 parent 0ac9824 commit 312df2b

File tree: 5 files changed, +1886 −0 lines changed

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
"""Base class for coding experiment."""

import abc
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp


class CompressionModel(tf.keras.Model, metaclass=abc.ABCMeta):
  """Base class for coding experiment."""

  def __init__(self, source, lmbda, distortion_loss, **kwargs):
    super().__init__(**kwargs)
    self.source = source
    self.lmbda = float(lmbda)
    self.distortion_loss = str(distortion_loss)

  @property
  def ndim_source(self):
    return self.source.event_shape[0]

  @abc.abstractmethod
  def quantize(self, x):
    """Determines an equivalent vector quantizer for `x`.

    Arguments:
      x: A batch of source vectors.

    Returns:
      codebook: A codebook of vectors used to represent all elements of `x`.
      rates: For each codebook vector, the self-information in bits needed to
        encode it.
      indexes: Integer `Tensor`. For each batch element in `x`, returns the
        index into `codebook` that it is represented as.
    """

  @abc.abstractmethod
  def train_losses(self, x):
    """Computes the training losses for `x`.

    Arguments:
      x: A batch of source vectors.

    Returns:
      Either an RD loss value for each element in `x`, or a tuple which
      contains the rate and distortion losses for each element separately (as
      in `test_losses`).
    """

  @abc.abstractmethod
  def test_losses(self, x):
    """Computes the rate and distortion for each element of `x`.

    Arguments:
      x: A batch of source vectors.

    Returns:
      rates: For each element in `x`, the self-information in bits needed to
        encode it.
      distortions: The distortion loss for each element in `x`.
    """

  def distortion_fn(self, reference, reconstruction):
    reference = tf.cast(reference, self.dtype)
    if self.distortion_loss == "sse":
      diff = tf.math.squared_difference(reference, reconstruction)
      return tf.math.reduce_sum(diff, axis=-1)
    if self.distortion_loss == "mse":
      diff = tf.math.squared_difference(reference, reconstruction)
      return tf.math.reduce_mean(diff, axis=-1)
    # Fail loudly rather than silently returning `None` for an unknown name.
    raise ValueError(f"Unknown distortion loss: {self.distortion_loss}")
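
  # For intuition (numbers hypothetical): with `lmbda=0.1` and
  # `distortion_loss="sse"`, a batch element coded at 3.2 bits with squared
  # error 4.0 contributes 3.2 + 0.1 * 4.0 = 3.6 to the loss minimized in
  # `train_step` below.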

  def train_step(self, x):
    # If a subclass defines an annealed parameter `alpha`, pin it to the
    # requested value before taking a gradient step.
    if hasattr(self, "alpha"):
      self.alpha = self.force_alpha
    with tf.GradientTape() as tape:
      rates, distortions = self.train_losses(x)
      losses = rates + self.lmbda * distortions
      loss = tf.math.reduce_mean(losses)
    variables = self.trainable_variables
    gradients = tape.gradient(loss, variables)
    self.optimizer.apply_gradients(zip(gradients, variables))
    self.loss.update_state(losses)
    self.rate.update_state(rates)
    self.distortion.update_state(distortions)
    # Track the RMS over all gradient elements as a training diagnostic.
    energy = []
    size = []
    for grad in gradients:
      if grad is None:
        continue
      energy.append(tf.reduce_sum(tf.square(tf.cast(grad, tf.float64))))
      size.append(tf.cast(tf.size(grad), tf.float64))
    self.grad_rms.update_state(tf.sqrt(tf.add_n(energy) / tf.add_n(size)))
    return {
        m.name: m.result()
        for m in [self.loss, self.rate, self.distortion, self.grad_rms]
    }

  def test_step(self, x):
    rates, distortions = self.test_losses(x)
    losses = rates + self.lmbda * distortions
    self.loss.update_state(losses)
    self.rate.update_state(rates)
    self.distortion.update_state(distortions)
    return {m.name: m.result() for m in [self.loss, self.rate, self.distortion]}

  def compile(self, **kwargs):
    super().compile(
        loss=None,
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        **kwargs,
    )
    self.loss = tf.keras.metrics.Mean(name="loss")
    self.rate = tf.keras.metrics.Mean(name="rate")
    self.distortion = tf.keras.metrics.Mean(name="distortion")
    self.grad_rms = tf.keras.metrics.Mean(name="gradient RMS")

  def fit(self, batch_size, validation_size, validation_batch_size, **kwargs):
    # Training data: an endless stream of fresh samples from the source.
    train_data = tf.data.Dataset.from_tensors([])
    train_data = train_data.repeat()
    train_data = train_data.map(
        lambda _: self.source.sample(batch_size),
    )

    # Fixed seed, so the validation set is identical across epochs.
    seed = tfp.util.SeedStream(528374623, "compression_model_fit")
    # This rounds up to a whole number of batches (e.g., 10000 samples at 512
    # per batch gives 20 batches).
    validation_batches = (validation_size - 1) // validation_batch_size + 1
    validation_data = tf.data.Dataset.from_tensors([])
    validation_data = validation_data.repeat(validation_batches)
    validation_data = validation_data.map(
        lambda _: self.source.sample(validation_batch_size, seed=seed),
    )

    super().fit(
        train_data,
        validation_data=validation_data,
        shuffle=False,
        **kwargs,
    )

  def plot_quantization(self, intervals, figsize=None, **kwargs):
    if len(intervals) != self.ndim_source or self.ndim_source not in (1, 2):
      raise ValueError("This method is only defined for 1D or 2D models.")

    # Evaluate the quantizer on a regular grid spanning the given intervals.
    data = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals]
    data = tf.meshgrid(*data, indexing="ij")
    data = tf.stack(data, axis=-1)

    codebook, rates, indexes = self.quantize(data, **kwargs)
    codebook = codebook.numpy()
    rates = rates.numpy()
    indexes = indexes.numpy()

    data_dist = self.source.prob(data).numpy()
    counts = np.bincount(np.ravel(indexes), minlength=len(codebook))
    prior = 2 ** (-rates)

    if self.ndim_source == 1:
      data = np.squeeze(data, axis=-1)
      # Quantization boundaries lie between grid points whose indexes differ.
      boundaries = np.nonzero(indexes[1:] != indexes[:-1])[0]
      boundaries = (data[boundaries] + data[boundaries + 1]) / 2
      plt.figure(figsize=figsize or (16, 8))
      plt.plot(data, data_dist, label="source")
      markers, stems, base = plt.stem(
          codebook[counts > 0], prior[counts > 0], label="codebook")
      plt.setp(markers, color="black")
      plt.setp(stems, color="black")
      plt.setp(base, linestyle="None")
      plt.xticks(np.sort(codebook[counts > 0]))
      plt.grid(False, axis="x")
      for r in boundaries:
        plt.axvline(
            r, color="black", lw=1, ls=":",
            label="boundaries" if r == boundaries[0] else None)
      plt.xlim(np.min(data), np.max(data))
      plt.ylim(bottom=-.01)
      plt.legend(loc="upper left")
      plt.xlabel("source space")
    else:
      google_pink = (0xf4/255, 0x39/255, 0xa0/255)
      plt.figure(figsize=figsize or (16, 14))
      vmax = data_dist.max()
      plt.imshow(
          data_dist, vmin=0, vmax=vmax, origin="lower",
          extent=(
              data[0, 0, 1], data[0, -1, 1], data[0, 0, 0], data[-1, 0, 0]))
      # Quantization cell boundaries as contour lines of the index map.
      plt.contour(
          data[:, :, 1], data[:, :, 0], indexes,
          np.arange(len(codebook)) + .5,
          colors=[google_pink], linewidths=.5)
      plt.plot(
          codebook[counts > 0, 1], codebook[counts > 0, 0],
          "o", color=google_pink)
      plt.axis("image")
      plt.grid(False)
      plt.xlim(data[0, 0, 1], data[0, -1, 1])
      plt.ylim(data[0, 0, 0], data[-1, 0, 0])
      plt.xlabel("source dimension 1")
      plt.ylabel("source dimension 2")
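
For orientation, here is a minimal sketch of a concrete subclass and its usage, assuming only the interface above. It is hypothetical and not part of the commit: the class name `ToyScalarModel`, the fixed uniform codebook, and the learned categorical prior are illustrative choices, not the method of the papers cited in the commit message.

class ToyScalarModel(CompressionModel):
  """Hypothetical example: fixed scalar codebook with a learned prior."""

  def __init__(self, source, lmbda, num_levels=32, **kwargs):
    super().__init__(source, lmbda, "sse", **kwargs)
    # Codebook centers on a uniform grid; logits parameterize the prior.
    self.centers = tf.Variable(tf.linspace(-4., 4., num_levels))
    self.logits = tf.Variable(tf.zeros((num_levels,)))

  def quantize(self, x):
    codebook = self.centers[:, None]  # [K, 1]: one codebook vector per level.
    # Hard nearest-neighbor assignment of each batch element to a center.
    indexes = tf.argmin(tf.abs(x - self.centers), axis=-1,
                        output_type=tf.int32)
    # Self-information of each center in bits under the categorical prior.
    rates = -tf.math.log(tf.nn.softmax(self.logits)) / tf.math.log(2.)
    return codebook, rates, indexes

  def test_losses(self, x):
    codebook, rates, indexes = self.quantize(x)
    reconstruction = tf.gather(codebook, indexes)
    return tf.gather(rates, indexes), self.distortion_fn(x, reconstruction)

  def train_losses(self, x):
    # Hard quantization: gradients reach the prior (via `rates`) and the
    # selected centers (via `reconstruction`), but not the index assignment.
    return self.test_losses(x)

A corresponding (also hypothetical) training run on a standard normal source:

source = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=[0.], scale=[1.]),
    reinterpreted_batch_ndims=1)  # event_shape [1], as `ndim_source` expects.
model = ToyScalarModel(source, lmbda=1.)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
model.fit(batch_size=1024, validation_size=2**13, validation_batch_size=2**13,
          epochs=5, steps_per_epoch=1000)
model.plot_quantization([(-4., 4., 1000)])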

0 commit comments