Skip to content

Commit 59d9f18

Browse files
Johannes Ballécopybara-github
authored andcommitted
Fixes tfp.Distribution._parameter_properties inheritance warning.
PiperOrigin-RevId: 424338195 Change-Id: If363b686b74593e93082e0a858e43ffd1263315a
1 parent aff2469 commit 59d9f18

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ def _upper_tail(self, tail_mass):
254254
return helpers.estimate_tails(
255255
self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)
256256

257+
@classmethod
258+
def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
259+
raise NotImplementedError(
260+
f"`{cls.__name__}` does not implement `_parameter_properties`.")
261+
257262

258263
class NoisyDeepFactorized(uniform_noise.UniformNoiseAdapter):
259264
"""DeepFactorized that is convolved with uniform noise."""

tensorflow_compression/python/distributions/round_adapters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ def _upper_tail(self, tail_mass):
157157
return self.transform(helpers.upper_tail(self.base, tail_mass))
158158
# pylint: enable=protected-access
159159

160+
@classmethod
161+
def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
162+
raise NotImplementedError(
163+
f"`{cls.__name__}` does not implement `_parameter_properties`.")
164+
160165

161166
class RoundAdapter(MonotonicAdapter):
162167
"""Continuous density function + round."""

tensorflow_compression/python/distributions/uniform_noise.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ def _lower_tail(self, tail_mass):
189189
def _upper_tail(self, tail_mass):
190190
return helpers.upper_tail(self.base, tail_mass)
191191

192+
@classmethod
193+
def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
194+
raise NotImplementedError(
195+
f"`{cls.__name__}` does not implement `_parameter_properties`.")
196+
192197

193198
class NoisyMixtureSameFamily(tfp.distributions.MixtureSameFamily):
194199
"""Mixture of distributions with additive i.i.d. uniform noise."""
@@ -211,6 +216,18 @@ def base(self):
211216
"""The base distribution (without uniform noise)."""
212217
return self._base
213218

219+
def _batch_shape_tensor(self):
220+
return self.base.batch_shape_tensor()
221+
222+
def _batch_shape(self):
223+
return self.base.batch_shape
224+
225+
def _event_shape_tensor(self):
226+
return self.base.event_shape_tensor()
227+
228+
def _event_shape(self):
229+
return self.base.event_shape
230+
214231
def _quantization_offset(self):
215232
# Picks the "peakiest" of the component quantization offsets.
216233
offsets = helpers.quantization_offset(self.components_distribution)
@@ -225,6 +242,11 @@ def _lower_tail(self, tail_mass):
225242
def _upper_tail(self, tail_mass):
226243
return helpers.upper_tail(self.base, tail_mass)
227244

245+
@classmethod
246+
def _parameter_properties(cls, dtype=tf.float32, num_classes=None):
247+
raise NotImplementedError(
248+
f"`{cls.__name__}` does not implement `_parameter_properties`.")
249+
228250

229251
class NoisyNormal(UniformNoiseAdapter):
230252
"""Gaussian distribution with additive i.i.d. uniform noise."""

0 commit comments

Comments
 (0)