@@ -21,26 +21,36 @@ class NTCModel(compression_model.CompressionModel):
   def __init__(self, analysis, synthesis, prior_type="deep",
                dither=(1, 1, 0, 0), soft_round=(1, 0), guess_offset=False,
                **kwargs):
+    """Initializer.
+
+    Args:
+      analysis: A `Layer` object implementing the analysis transform.
+      synthesis: A `Layer` object implementing the synthesis transform.
+      prior_type: String. Either 'deep' for `DeepFactorized` prior, or
+        'gsm/gmm/lsm/lmm-X' for Gaussian/Logistic Scale Mixture/Mixture Model
+        with X components.
+      dither: Sequence of 4 Booleans. Whether to use dither for: rate term
+        during training, distortion term during training, rate term during
+        testing, distortion term during testing, respectively.
+      soft_round: Sequence of 2 Booleans. Whether to use soft rounding during
+        training or testing, respectively.
+      guess_offset: Boolean. When not using soft rounding, whether to use the
+        mode centering heuristic to determine the quantization offset during
+        testing.
+      **kwargs: Other arguments passed through to `CompressionModel` class.
+    """
     super().__init__(**kwargs)
     self._analysis = analysis
     self._synthesis = synthesis
     self.prior_type = str(prior_type)
-    # train_rate, train_dist, test_rate, test_dist
     self.dither = tuple(bool(i) for i in dither)
-    # train, test
     self.soft_round = tuple(bool(i) for i in soft_round)
     self.guess_offset = bool(guess_offset)
 
     if self.prior_type == "deep":
       self._prior = tfc.DeepFactorized(
           batch_shape=[self.ndim_latent], dtype=self.dtype)
-    elif self.prior_type == "deep_uniform":
-      self._prior = tfc.DeepFactorized(
-          batch_shape=[self.ndim_latent], dtype=self.dtype)
-      self.log_uniform_width = tf.Variable(
-          0, "log_uniform_width", dtype=self.dtype)
-    else:
-      assert self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-")
+    elif self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-"):
       components = int(self.prior_type[4:])
       shape = (self.ndim_latent, components)
       self.logits = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
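
To make the new docstring concrete, here is a minimal instantiation sketch. The two Keras MLP transforms and their layer sizes are hypothetical stand-ins, not part of this change; any remaining `CompressionModel` kwargs are omitted.

```python
import tensorflow as tf

# Hypothetical analysis/synthesis transforms; `ndim_latent` is read off the
# analysis transform's output shape, so the final Dense layer fixes it at 8.
analysis = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation="softplus", input_shape=(2,)),
    tf.keras.layers.Dense(8),
])
synthesis = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation="softplus", input_shape=(8,)),
    tf.keras.layers.Dense(2),
])

# "gmm-3" selects a Gaussian Mixture Model prior with 3 components; the
# dither and soft_round flags are the defaults documented above.
model = NTCModel(analysis, synthesis, prior_type="gmm-3",
                 dither=(1, 1, 0, 0), soft_round=(1, 0))
```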
@@ -50,6 +60,8 @@ def __init__(self, analysis, synthesis, prior_type="deep",
         self.loc = 0.
       else:
         self.loc = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
+    else:
+      raise ValueError(f"Unknown prior_type: '{prior_type}'.")
 
     self._logit_alpha = tf.Variable(-3, dtype=self.dtype, name="logit_alpha")
     self._force_alpha = tf.Variable(
@@ -58,8 +70,7 @@ def __init__(self, analysis, synthesis, prior_type="deep",
   def prior(self, soft_round, scale=None, alpha=None, skip_noise=False):
     if self.prior_type == "deep":
       prior = self._prior
-    else:
-      assert self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-")
+    elif self.prior_type[:4] in ("gsm-", "gmm-", "lsm-", "lmm-"):
       cls = tfpd.Normal if self.prior_type.startswith("g") else tfpd.Logistic
       prior = tfpd.MixtureSameFamily(
           mixture_distribution=tfpd.Categorical(logits=self.logits),
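
For readers unfamiliar with the `tfpd` API, below is a standalone sketch of the per-dimension scale mixture this branch builds. The shapes (4 latent dimensions, 3 components) and the log-scale parameterization are hypothetical stand-ins for the model's variables.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfpd = tfp.distributions

logits = tf.random.normal((4, 3))     # mixture weights, one row per dim
log_scale = tf.random.normal((4, 3))  # component scales, log domain

# A Gaussian scale mixture ("gsm-3"): zero-mean components that differ
# only in scale, mixed independently per latent dimension.
prior = tfpd.MixtureSameFamily(
    mixture_distribution=tfpd.Categorical(logits=logits),
    components_distribution=tfpd.Normal(loc=0., scale=tf.exp(log_scale)))

print(prior.batch_shape)              # (4,): one scalar mixture per dim
print(prior.log_prob(tf.zeros(4)))    # elementwise log densities
```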
@@ -74,6 +85,34 @@ def prior(self, soft_round, scale=None, alpha=None, skip_noise=False):
       return prior
     return tfc.UniformNoiseAdapter(prior)
 
+  @property
+  def ndim_latent(self):
+    return self._analysis.output_shape[-1]
+
+  def analysis(self, x):
+    y = tf.cast(x, self.dtype)
+    if y.shape[-1] != self.ndim_source:
+      raise ValueError(
+          f"Expected {self.ndim_source} trailing dimensions, "
+          f"received {y.shape[-1]}.")
+    batch_shape = tf.shape(y)[:-1]
+    y = tf.reshape(y, (-1, self.ndim_source))
+    y = self._analysis(y)
+    assert y.shape[-1] == self.ndim_latent
+    return tf.reshape(y, tf.concat([batch_shape, [self.ndim_latent]], 0))
+
+  def synthesis(self, y):
+    x = tf.cast(y, self.dtype)
+    if x.shape[-1] != self.ndim_latent:
+      raise ValueError(
+          f"Expected {self.ndim_latent} trailing dimensions, "
+          f"received {x.shape[-1]}.")
+    batch_shape = tf.shape(x)[:-1]
+    x = tf.reshape(x, (-1, self.ndim_latent))
+    x = self._synthesis(x)
+    assert x.shape[-1] == self.ndim_source
+    return tf.reshape(x, tf.concat([batch_shape, [self.ndim_source]], 0))
+
   @property
   def force_alpha(self):
     return tf.convert_to_tensor(self._force_alpha)
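
The relocated `analysis`/`synthesis` wrappers now raise a descriptive `ValueError` on a trailing-dimension mismatch instead of failing a bare assert. A shape-contract sketch, reusing the hypothetical `model` from the note above (`ndim_source == 2`, `ndim_latent == 8`):

```python
x = tf.zeros((16, 10, 2))    # any number of leading batch dims is fine
y = model.analysis(x)        # shape (16, 10, 8): batch dims preserved
x_hat = model.synthesis(y)   # shape (16, 10, 2)

model.analysis(tf.zeros((16, 3)))  # ValueError: Expected 2 trailing ...
```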
@@ -100,30 +139,6 @@ def get_logit_alpha():
     self._logit_alpha.assign(
         tf.cond(value < 0, lambda: self._logit_alpha, get_logit_alpha))
 
-  @property
-  def ndim_latent(self):
-    return self._analysis.output_shape[-1]
-
-  def analysis(self, x):
-    y = tf.cast(x, self.dtype)
-    assert y.shape[-1] == self.ndim_source
-    batch_shape = tf.shape(y)[:-1]
-    y = tf.reshape(y, (-1, self.ndim_source))
-    y = self._analysis(y)
-    assert y.shape[-1] == self.ndim_latent
-    y = tf.reshape(y, tf.concat([batch_shape, [self.ndim_latent]], 0))
-    return y
-
-  def synthesis(self, y):
-    x = tf.cast(y, self.dtype)
-    assert x.shape[-1] == self.ndim_latent
-    batch_shape = tf.shape(x)[:-1]
-    x = tf.reshape(x, (-1, self.ndim_latent))
-    x = self._synthesis(x)
-    assert x.shape[-1] == self.ndim_source
-    x = tf.reshape(x, tf.concat([batch_shape, [self.ndim_source]], 0))
-    return x
-
   def encode_decode(self, x, dither_rate, dither_dist, soft_round,
                     guess_offset=None, offset=0., seed=None):
     if guess_offset is None:
@@ -148,6 +163,7 @@ def perturb(inputs, dither, prior, offset):
     assert x.shape[-1] == self.ndim_source
     y = self.analysis(x)
 
+    rates = 0.
     prior = self.prior(soft_round=soft_round)
 
     y_dist = perturb(y, dither_dist, prior, offset)
@@ -158,7 +174,7 @@ def perturb(inputs, dither, prior, offset):
 
     x_hat = self.synthesis(y_dist)
     log_probs = prior.log_prob(y_rate)
-    rates = tf.reduce_sum(log_probs, axis=-1) / tf.cast(
+    rates += tf.reduce_sum(log_probs, axis=-1) / tf.cast(
         -tf.math.log(2.), self.dtype)
 
     return y_dist, x_hat, rates
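
On the `rates` change: initializing `rates = 0.` and accumulating with `+=` leaves room for further rate terms, and the division by `tf.cast(-tf.math.log(2.), self.dtype)` converts a log density in nats into a positive rate in bits. A tiny numeric check with a hypothetical value:

```python
import math

log_prob = -7 * math.log(2.)     # ln p(y) for a hypothetical density
bits = log_prob / -math.log(2.)  # == -log2 p(y)
print(bits)                      # 7.0
```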