|
23 | 23 | import contextlib
|
24 | 24 | import functools
|
25 | 25 | import inspect
|
| 26 | +import logging |
26 | 27 | import types
|
27 | 28 |
|
28 | 29 | import decorator
|
@@ -275,40 +276,61 @@ def __new__(mcs, classname, baseclasses, attrs):
|
275 | 276 | return super(_DistributionMeta, mcs).__new__(
|
276 | 277 | mcs, classname, baseclasses, attrs)
|
277 | 278 |
|
278 |
| - # Subclasses shouldn't inherit their parents' `_parameter_properties`, |
279 |
| - # since (in general) they'll have different parameters. Exceptions (for |
280 |
| - # convenience) are: |
| 279 | + # Warn when a subclass inherits `_parameter_properties` from its parent |
| 280 | + # (this is unsafe, since the subclass will in general have different |
| 281 | + # parameters). Exceptions are: |
281 | 282 | # - Subclasses that don't define their own `__init__` (handled above by
|
282 | 283 | # the short-circuit when `default_init is None`).
|
283 | 284 | # - Subclasses that define a passthrough `__init__(self, *args, **kwargs)`.
|
284 |
| - # - Direct children of `Distribution`, since the inherited method just |
285 |
| - # raises a NotImplementedError. |
| 285 | + # pylint: disable=protected-access |
286 | 286 | init_argspec = tf_inspect.getfullargspec(default_init)
|
287 | 287 | if ('_parameter_properties' not in attrs
|
288 |
| - and base != Distribution |
289 | 288 | # Passthrough exception: may only take `self` and at least one of
|
290 | 289 | # `*args` and `**kwargs`.
|
291 | 290 | and (len(init_argspec.args) > 1
|
292 | 291 | or not (init_argspec.varargs or init_argspec.varkw))):
|
293 |
| - # TODO(b/183457779) remove warning and raise `NotImplementedError`. |
294 |
| - attrs['_parameter_properties'] = deprecation.deprecated( |
295 |
| - date='2021-07-01', |
296 |
| - instructions=""" |
297 |
| -Calling `_parameter_properties` on subclass {classname} that redefines the |
298 |
| -parent ({basename}) `__init__` is unsafe and will raise an error in the future. |
299 |
| -Please implement an explicit `_parameter_properties` for the subclass. If the |
300 |
| -subclass `__init__` takes the same parameters as the parent, you may use the |
301 |
| -placeholder implementation: |
302 | 292 |
|
303 |
| - @classmethod |
304 |
| - def _parameter_properties(cls, dtype, num_classes=None): |
305 |
| - return {basename}._parameter_properties( |
306 |
| - dtype=dtype, num_classes=num_classes) |
| 293 | + @functools.wraps(base._parameter_properties) |
| 294 | + def wrapped_properties(*args, **kwargs): # pylint: disable=missing-docstring |
| 295 | + """Wrapper to warn if `parameter_properties` is inherited.""" |
| 296 | + properties = base._parameter_properties(*args, **kwargs) |
| 297 | + # Warn *after* calling the base method, so that we don't bother warning |
| 298 | + # if it just raised NotImplementedError anyway. |
| 299 | + logging.warning(""" |
| 300 | +Distribution subclass %s inherits `_parameter_properties from its parent (%s) |
| 301 | +while also redefining `__init__`. The inherited annotations cover the following |
| 302 | +parameters: %s. It is likely that these do not match the subclass parameters. |
| 303 | +This may lead to errors when computing batch shapes, slicing into batch |
| 304 | +dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor |
| 305 | +(e.g., when it is passed or returned from a `tf.function`), and possibly other |
| 306 | +cases. The recommended pattern for distribution subclasses is to define a new |
| 307 | +`_parameter_properties` method with the subclass parameters, and to store the |
| 308 | +corresponding parameter values as `self._parameters` in `__init__`, after |
| 309 | +calling the superclass constructor: |
| 310 | +
|
| 311 | +``` |
| 312 | +class MySubclass(tfd.SomeDistribution): |
| 313 | +
|
| 314 | + def __init__(self, param_a, param_b): |
| 315 | + parameters = dict(locals()) |
| 316 | + # ... do subclass initialization ... |
| 317 | + super(MySubclass, self).__init__(**base_class_params) |
| 318 | + # Ensure that the subclass (not base class) parameters are stored. |
| 319 | + self._parameters = parameters |
| 320 | +
|
| 321 | + def _parameter_properties(self, dtype, num_classes=None): |
| 322 | + return dict( |
| 323 | + # Annotations may optionally specify properties, such as `event_ndims`, |
| 324 | + # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see |
| 325 | + # the `ParameterProperties` documentation for details. |
| 326 | + param_a=tfp.util.ParameterProperties(), |
| 327 | + param_b=tfp.util.ParameterProperties()) |
| 328 | +``` |
| 329 | +""", classname, base.__name__, str(properties.keys())) |
| 330 | + return properties |
307 | 331 |
|
308 |
| -""".format(classname=classname, |
309 |
| - basename=base.__name__))(base._parameter_properties) |
| 332 | + attrs['_parameter_properties'] = wrapped_properties |
310 | 333 |
|
311 |
| - # pylint: disable=protected-access |
312 | 334 | # For a comparison of different methods for wrapping functions, see:
|
313 | 335 | # https://hynek.me/articles/decorators/
|
314 | 336 | @decorator.decorator
|
@@ -656,10 +678,7 @@ def _composite_tensor_shape_params(self):
|
656 | 678 | @classmethod
|
657 | 679 | def _parameter_properties(cls, dtype, num_classes=None):
|
658 | 680 | raise NotImplementedError(
|
659 |
| - '_parameter_properties` is not implemented: {}. ' |
660 |
| - 'Note that subclasses that redefine `__init__` are not assumed to ' |
661 |
| - 'share parameters with their parent class and must provide a separate ' |
662 |
| - 'implementation.'.format(cls.__name__)) |
| 681 | + '_parameter_properties` is not implemented: {}.'.format(cls.__name__)) |
663 | 682 |
|
664 | 683 | @classmethod
|
665 | 684 | def parameter_properties(cls, dtype=tf.float32, num_classes=None):
|
|
0 commit comments