Commit 994ca84

Johannes Ballé authored and copybara-github committed
Refactors ntc.py.
PiperOrigin-RevId: 353149647 Change-Id: I3851907607b64d951a840d1bb9ca35904e38aa70
1 parent 5ef40db commit 994ca84

File tree

1 file changed (+52 −36 lines)

models/toy_sources/ntc.py

Lines changed: 52 additions & 36 deletions
@@ -21,26 +21,36 @@ class NTCModel(compression_model.CompressionModel):
   def __init__(self, analysis, synthesis, prior_type="deep",
                dither=(1, 1, 0, 0), soft_round=(1, 0), guess_offset=False,
                **kwargs):
+    """Initializer.
+
+    Args:
+      analysis: A `Layer` object implementing the analysis transform.
+      synthesis: A `Layer` object implementing the synthesis transform.
+      prior_type: String. Either 'deep' for `DeepFactorized` prior, or
+        'gsm/gmm/lsm/lmm-X' for Gaussian/Logistic Scale Mixture/Mixture Model
+        with X components.
+      dither: Sequence of 4 Booleans. Whether to use dither for: rate term
+        during training, distortion term during training, rate term during
+        testing, distortion term during testing, respectively.
+      soft_round: Sequence of 2 Booleans. Whether to use soft rounding during
+        training or testing, respectively.
+      guess_offset: Boolean. When not using soft rounding, whether to use the
+        mode centering heuristic to determine the quantization offset during
+        testing.
+      **kwargs: Other arguments passed through to `CompressionModel` class.
+    """
     super().__init__(**kwargs)
     self._analysis = analysis
     self._synthesis = synthesis
     self.prior_type = str(prior_type)
-    # train_rate, train_dist, test_rate, test_dist
     self.dither = tuple(bool(i) for i in dither)
-    # train, test
     self.soft_round = tuple(bool(i) for i in soft_round)
     self.guess_offset = bool(guess_offset)

     if self.prior_type == "deep":
       self._prior = tfc.DeepFactorized(
           batch_shape=[self.ndim_latent], dtype=self.dtype)
-    elif self.prior_type == "deep_uniform":
-      self._prior = tfc.DeepFactorized(
-          batch_shape=[self.ndim_latent], dtype=self.dtype)
-      self.log_uniform_width = tf.Variable(
-          0, "log_uniform_width", dtype=self.dtype)
-    else:
-      assert self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-")
+    elif self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-"):
       components = int(self.prior_type[4:])
       shape = (self.ndim_latent, components)
       self.logits = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
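Aside (not part of the diff): the new docstring fixes the layout of the two flag tuples, which the deleted inline comments previously encoded. A minimal standalone sketch of selecting the flags per phase, assuming the (train_rate, train_dist, test_rate, test_dist) and (train, test) orderings it documents; all names here are illustrative:

# Illustrative only: unpack the flag tuples per the docstring's ordering.
dither = tuple(bool(i) for i in (1, 1, 0, 0))  # train_rate, train_dist, test_rate, test_dist
soft_round = tuple(bool(i) for i in (1, 0))    # train, test

training = True
dither_rate, dither_dist = dither[:2] if training else dither[2:]
use_soft_round = soft_round[0 if training else 1]
print(dither_rate, dither_dist, use_soft_round)  # True True True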
@@ -50,6 +60,8 @@ def __init__(self, analysis, synthesis, prior_type="deep",
         self.loc = 0.
       else:
         self.loc = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
+    else:
+      raise ValueError(f"Unknown prior_type: '{prior_type}'.")

     self._logit_alpha = tf.Variable(-3, dtype=self.dtype, name="logit_alpha")
     self._force_alpha = tf.Variable(
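Aside (not part of the diff): with the `else` branch added, every `prior_type` string is now either handled or explicitly rejected, where the old code asserted. A minimal standalone sketch of the same naming convention; the helper name is hypothetical:

# Hypothetical helper mirroring the dispatch: "deep" or
# "<gsm|gmm|lsm|lmm>-<X>" with X mixture components; anything else rejected.
def parse_prior_type(prior_type):
  if prior_type == "deep":
    return "deep", None
  if prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-"):
    return prior_type[:3], int(prior_type[4:])
  raise ValueError(f"Unknown prior_type: '{prior_type}'.")

print(parse_prior_type("gmm-3"))  # ('gmm', 3)
print(parse_prior_type("deep"))   # ('deep', None)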
@@ -58,8 +70,7 @@ def __init__(self, analysis, synthesis, prior_type="deep",
   def prior(self, soft_round, scale=None, alpha=None, skip_noise=False):
     if self.prior_type == "deep":
       prior = self._prior
-    else:
-      assert self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-")
+    elif self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-"):
       cls = tfpd.Normal if self.prior_type.startswith("g") else tfpd.Logistic
       prior = tfpd.MixtureSameFamily(
           mixture_distribution=tfpd.Categorical(logits=self.logits),
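Aside (not part of the diff): for reference, a standalone sketch of a per-dimension mixture prior built the same way with TensorFlow Probability. The shapes (4 latent dimensions, 3 components) and the `exp` scale parametrization are illustrative assumptions, not taken from the file:

import tensorflow as tf
import tensorflow_probability as tfp

tfpd = tfp.distributions

ndim_latent, components = 4, 3  # illustrative shapes
logits = tf.random.normal((ndim_latent, components))
loc = tf.random.normal((ndim_latent, components))
log_scale = tf.random.normal((ndim_latent, components))

# One mixture per latent dimension; tfpd.Normal gives the Gaussian variants
# (gsm/gmm), tfpd.Logistic would give the logistic variants (lsm/lmm).
prior = tfpd.MixtureSameFamily(
    mixture_distribution=tfpd.Categorical(logits=logits),
    components_distribution=tfpd.Normal(loc=loc, scale=tf.exp(log_scale)))
print(prior.log_prob(tf.zeros(ndim_latent)).shape)  # (4,)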
@@ -74,6 +85,34 @@ def prior(self, soft_round, scale=None, alpha=None, skip_noise=False):
       return prior
     return tfc.UniformNoiseAdapter(prior)

+  @property
+  def ndim_latent(self):
+    return self._analysis.output_shape[-1]
+
+  def analysis(self, x):
+    y = tf.cast(x, self.dtype)
+    if y.shape[-1] != self.ndim_source:
+      raise ValueError(
+          f"Expected {self.ndim_source} trailing dimensions, "
+          f"received {y.shape[-1]}.")
+    batch_shape = tf.shape(y)[:-1]
+    y = tf.reshape(y, (-1, self.ndim_source))
+    y = self._analysis(y)
+    assert y.shape[-1] == self.ndim_latent
+    return tf.reshape(y, tf.concat([batch_shape, [self.ndim_latent]], 0))
+
+  def synthesis(self, y):
+    x = tf.cast(y, self.dtype)
+    if x.shape[-1] != self.ndim_latent:
+      raise ValueError(
+          f"Expected {self.ndim_latent} trailing dimensions, "
+          f"received {x.shape[-1]}.")
+    batch_shape = tf.shape(x)[:-1]
+    x = tf.reshape(x, (-1, self.ndim_latent))
+    x = self._synthesis(x)
+    assert x.shape[-1] == self.ndim_source
+    return tf.reshape(x, tf.concat([batch_shape, [self.ndim_source]], 0))
+
   @property
   def force_alpha(self):
     return tf.convert_to_tensor(self._force_alpha)
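Aside (not part of the diff): the relocated `analysis`/`synthesis` wrappers both follow a flatten-apply-restore pattern: collapse arbitrary leading batch dimensions, run a layer that expects rank-2 input, then restore the batch shape. A minimal standalone sketch of that pattern; the helper and layer here are illustrative:

import tensorflow as tf

def apply_flattened(layer, x, in_dim, out_dim):
  batch_shape = tf.shape(x)[:-1]       # dynamic leading dims, e.g. (2, 7)
  flat = tf.reshape(x, (-1, in_dim))   # (prod(batch_shape), in_dim)
  flat = layer(flat)
  return tf.reshape(flat, tf.concat([batch_shape, [out_dim]], 0))

dense = tf.keras.layers.Dense(5)
out = apply_flattened(dense, tf.zeros((2, 7, 3)), in_dim=3, out_dim=5)
print(out.shape)                       # (2, 7, 5)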
@@ -100,30 +139,6 @@ def get_logit_alpha():
     self._logit_alpha.assign(
         tf.cond(value < 0, lambda: self._logit_alpha, get_logit_alpha))

-  @property
-  def ndim_latent(self):
-    return self._analysis.output_shape[-1]
-
-  def analysis(self, x):
-    y = tf.cast(x, self.dtype)
-    assert y.shape[-1] == self.ndim_source
-    batch_shape = tf.shape(y)[:-1]
-    y = tf.reshape(y, (-1, self.ndim_source))
-    y = self._analysis(y)
-    assert y.shape[-1] == self.ndim_latent
-    y = tf.reshape(y, tf.concat([batch_shape, [self.ndim_latent]], 0))
-    return y
-
-  def synthesis(self, y):
-    x = tf.cast(y, self.dtype)
-    assert x.shape[-1] == self.ndim_latent
-    batch_shape = tf.shape(x)[:-1]
-    x = tf.reshape(x, (-1, self.ndim_latent))
-    x = self._synthesis(x)
-    assert x.shape[-1] == self.ndim_source
-    x = tf.reshape(x, tf.concat([batch_shape, [self.ndim_source]], 0))
-    return x
-
   def encode_decode(self, x, dither_rate, dither_dist, soft_round,
                     guess_offset=None, offset=0., seed=None):
     if guess_offset is None:
@@ -148,6 +163,7 @@ def perturb(inputs, dither, prior, offset):
     assert x.shape[-1] == self.ndim_source
     y = self.analysis(x)

+    rates = 0.
     prior = self.prior(soft_round=soft_round)

     y_dist = perturb(y, dither_dist, prior, offset)
@@ -158,7 +174,7 @@ def perturb(inputs, dither, prior, offset):

     x_hat = self.synthesis(y_dist)
     log_probs = prior.log_prob(y_rate)
-    rates = tf.reduce_sum(log_probs, axis=-1) / tf.cast(
+    rates += tf.reduce_sum(log_probs, axis=-1) / tf.cast(
         -tf.math.log(2.), self.dtype)

     return y_dist, x_hat, rates
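Aside (not part of the diff): the `rates +=` change pairs with the `rates = 0.` initialization added in the previous hunk. The natural-log densities are summed over the latent dimensions and divided by -ln 2, converting nats to bits. A small numeric check of that conversion, with made-up values:

import math
import tensorflow as tf

log_probs = tf.constant([[-1.2, -0.7, -2.1]])  # per-dim log densities (nats)
rates = tf.reduce_sum(log_probs, axis=-1) / -math.log(2.)
print(float(rates[0]))                         # 4.0 / ln 2 ≈ 5.77 bits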

0 commit comments