Skip to content

Commit 19c8f11

Browse files
davmretensorflower-gardener
authored andcommitted
Replace deprecation of parameter_properties inheritance with a warning.
This also removes a spurious deprecation warning that appears when `JointDistribution`s are used in public colabs. (JDs have a multi-level hierarchy in which no class defines its own _parameter_properties, but since they all inherit the base class def that raises `NotImplementedError`, there's no inheritance problem). Justification for this change: there are legitimate 'quick-and-dirty' uses of parameter_properties inheritance, even if we wouldn't do it in TFP. For example, somewhere in the DeepMind silo is a Normal distribution that takes a log_scale rather than a scale: class NormalWithLogScale(tfd.Normal): def __init__(self, loc, log_scale): super().__init__(loc=loc, scale=tf.exp(log_scale)) Ignoring whether this is the best way to accomplish any particular goal, from a general Pythonic standpoint one might expect that this class would at least be basically functional. It may not support batch slicing or AutoCompositeTensor, but one should at least be able to call sample and log_prob, which means it should at least define properties like batch_shape. But now that batch shape depends on parameter_properties (as of cl/373590501), breaking parameter_properties inheritance means breaking batch_shape inheritance. To allow quick subclasses like this, I propose we simply warn when an inherited `parameter_properties` is called. In this example, the batch shape would be computed (correctly) using the base Normal parameters, just as if an explicit batch_shape method had been inherited. For full functionality including batch slicing, CompositeTensor, etc., a subclass would need to both (a) set self.parameters = dict(locals()) in its own constructor, and (b) define its own _parameter_properties. I've tried to articulate these requirements in the warning message. PiperOrigin-RevId: 374554087
1 parent 3632afb commit 19c8f11

File tree

2 files changed

+61
-39
lines changed

2 files changed

+61
-39
lines changed

tensorflow_probability/python/distributions/distribution.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import contextlib
2424
import functools
2525
import inspect
26+
import logging
2627
import types
2728

2829
import decorator
@@ -275,40 +276,61 @@ def __new__(mcs, classname, baseclasses, attrs):
275276
return super(_DistributionMeta, mcs).__new__(
276277
mcs, classname, baseclasses, attrs)
277278

278-
# Subclasses shouldn't inherit their parents' `_parameter_properties`,
279-
# since (in general) they'll have different parameters. Exceptions (for
280-
# convenience) are:
279+
# Warn when a subclass inherits `_parameter_properties` from its parent
280+
# (this is unsafe, since the subclass will in general have different
281+
# parameters). Exceptions are:
281282
# - Subclasses that don't define their own `__init__` (handled above by
282283
# the short-circuit when `default_init is None`).
283284
# - Subclasses that define a passthrough `__init__(self, *args, **kwargs)`.
284-
# - Direct children of `Distribution`, since the inherited method just
285-
# raises a NotImplementedError.
285+
# pylint: disable=protected-access
286286
init_argspec = tf_inspect.getfullargspec(default_init)
287287
if ('_parameter_properties' not in attrs
288-
and base != Distribution
289288
# Passthrough exception: may only take `self` and at least one of
290289
# `*args` and `**kwargs`.
291290
and (len(init_argspec.args) > 1
292291
or not (init_argspec.varargs or init_argspec.varkw))):
293-
# TODO(b/183457779) remove warning and raise `NotImplementedError`.
294-
attrs['_parameter_properties'] = deprecation.deprecated(
295-
date='2021-07-01',
296-
instructions="""
297-
Calling `_parameter_properties` on subclass {classname} that redefines the
298-
parent ({basename}) `__init__` is unsafe and will raise an error in the future.
299-
Please implement an explicit `_parameter_properties` for the subclass. If the
300-
subclass `__init__` takes the same parameters as the parent, you may use the
301-
placeholder implementation:
302292

303-
@classmethod
304-
def _parameter_properties(cls, dtype, num_classes=None):
305-
return {basename}._parameter_properties(
306-
dtype=dtype, num_classes=num_classes)
293+
@functools.wraps(base._parameter_properties)
294+
def wrapped_properties(*args, **kwargs): # pylint: disable=missing-docstring
295+
"""Wrapper to warn if `parameter_properties` is inherited."""
296+
properties = base._parameter_properties(*args, **kwargs)
297+
# Warn *after* calling the base method, so that we don't bother warning
298+
# if it just raised NotImplementedError anyway.
299+
logging.warning("""
300+
Distribution subclass %s inherits `_parameter_properties from its parent (%s)
301+
while also redefining `__init__`. The inherited annotations cover the following
302+
parameters: %s. It is likely that these do not match the subclass parameters.
303+
This may lead to errors when computing batch shapes, slicing into batch
304+
dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor
305+
(e.g., when it is passed or returned from a `tf.function`), and possibly other
306+
cases. The recommended pattern for distribution subclasses is to define a new
307+
`_parameter_properties` method with the subclass parameters, and to store the
308+
corresponding parameter values as `self._parameters` in `__init__`, after
309+
calling the superclass constructor:
310+
311+
```
312+
class MySubclass(tfd.SomeDistribution):
313+
314+
def __init__(self, param_a, param_b):
315+
parameters = dict(locals())
316+
# ... do subclass initialization ...
317+
super(MySubclass, self).__init__(**base_class_params)
318+
# Ensure that the subclass (not base class) parameters are stored.
319+
self._parameters = parameters
320+
321+
def _parameter_properties(self, dtype, num_classes=None):
322+
return dict(
323+
# Annotations may optionally specify properties, such as `event_ndims`,
324+
# `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
325+
# the `ParameterProperties` documentation for details.
326+
param_a=tfp.util.ParameterProperties(),
327+
param_b=tfp.util.ParameterProperties())
328+
```
329+
""", classname, base.__name__, str(properties.keys()))
330+
return properties
307331

308-
""".format(classname=classname,
309-
basename=base.__name__))(base._parameter_properties)
332+
attrs['_parameter_properties'] = wrapped_properties
310333

311-
# pylint: disable=protected-access
312334
# For a comparison of different methods for wrapping functions, see:
313335
# https://hynek.me/articles/decorators/
314336
@decorator.decorator
@@ -656,10 +678,7 @@ def _composite_tensor_shape_params(self):
656678
@classmethod
657679
def _parameter_properties(cls, dtype, num_classes=None):
658680
raise NotImplementedError(
659-
'_parameter_properties` is not implemented: {}. '
660-
'Note that subclasses that redefine `__init__` are not assumed to '
661-
'share parameters with their parent class and must provide a separate '
662-
'implementation.'.format(cls.__name__))
681+
'_parameter_properties` is not implemented: {}.'.format(cls.__name__))
663682

664683
@classmethod
665684
def parameter_properties(cls, dtype=tf.float32, num_classes=None):

tensorflow_probability/python/distributions/distribution_test.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import collections
2020
# Dependency imports
2121

22+
from absl import logging
2223
from absl.testing import parameterized
2324

2425
import numpy as np
@@ -29,8 +30,6 @@
2930
from tensorflow_probability.python.internal import test_util
3031

3132
from tensorflow.python.framework import test_util as tf_test_util # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
32-
from tensorflow.python.platform import test as tf_test # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
33-
from tensorflow.python.platform import tf_logging # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
3433

3534

3635
class TupleDistribution(tfd.Distribution):
@@ -617,10 +616,7 @@ def normal_differential_entropy(scale):
617616
self.evaluate(normal_differential_entropy(scale)),
618617
err=1e-5)
619618

620-
@test_util.jax_disable_test_missing_functionality('tf_logging')
621-
@tf_test.mock.patch.object(tf_logging, 'warning', autospec=True)
622-
def testParameterPropertiesNotInherited(self, mock_warning):
623-
# TODO(b/183457779) Test for NotImplementedError (rather than just warning).
619+
def testParameterPropertiesNotInherited(self):
624620

625621
# Subclasses that don't redefine __init__ can inherit properties.
626622
class NormalTrivialSubclass(tfd.Normal):
@@ -640,20 +636,27 @@ class MyDistribution(tfd.Distribution):
640636
def __init__(self, param1, param2):
641637
pass
642638

643-
NormalTrivialSubclass.parameter_properties()
644-
NormalWithPassThroughInit.parameter_properties()
645-
with self.assertRaises(NotImplementedError):
646-
MyDistribution.parameter_properties()
647-
self.assertEqual(0, mock_warning.call_count)
639+
with self.assertLogs(level=logging.WARNING) as log:
640+
NormalTrivialSubclass.parameter_properties()
641+
NormalWithPassThroughInit.parameter_properties()
642+
with self.assertRaises(NotImplementedError):
643+
MyDistribution.parameter_properties()
644+
with self.assertRaises(NotImplementedError):
645+
# Ensure that the unimplemented JD propertoes don't raise a warning.
646+
tfd.JointDistributionCoroutine.parameter_properties()
647+
logging.warning('assertLogs context requires at least one warning.')
648+
# Assert that no warnings occurred other than the dummy warning.
649+
self.assertLen(log.records, 1)
648650

649651
class NormalWithExtraParam(tfd.Normal):
650652

651653
def __init__(self, extra_param, *args, **kwargs):
652654
self._extra_param = extra_param
653655
super(NormalWithExtraParam, self).__init__(*args, **kwargs)
654656

655-
NormalWithExtraParam.parameter_properties()
656-
self.assertEqual(1, mock_warning.call_count)
657+
with self.assertLogs(level=logging.WARNING) as log:
658+
NormalWithExtraParam.parameter_properties()
659+
self.assertLen(log.records, 1)
657660

658661

659662
@test_util.test_all_tf_execution_regimes

0 commit comments

Comments
 (0)