Skip to content

Commit 397a53e

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Fix a bug that caused user-defined subclasses of JointDistributions to error in __new__.
PiperOrigin-RevId: 474175053
1 parent f563082 commit 397a53e

File tree

3 files changed

+39
-35
lines changed

3 files changed

+39
-35
lines changed

tensorflow_probability/python/distributions/joint_distribution_auto_batched.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -589,16 +589,17 @@ class JointDistributionSequentialAutoBatched(
589589

590590
def __new__(cls, *args, **kwargs):
591591
"""Maybe returns a `_JointDistributionSequentialAutobatched`."""
592-
if args:
593-
model = args[0]
594-
else:
595-
model = kwargs.get('model')
596-
597-
# Return a `_JointDistributionSequentialAutoBatched` instance if `model`
598-
# contains distributions that are not CompositeTensors.
599-
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
600-
for d in model):
601-
return _JointDistributionSequentialAutoBatched(*args, **kwargs)
592+
if cls is JointDistributionSequentialAutoBatched:
593+
if args:
594+
model = args[0]
595+
else:
596+
model = kwargs.get('model')
597+
598+
# Return a `_JointDistributionSequentialAutoBatched` instance if `model`
599+
# contains distributions that are not CompositeTensors.
600+
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
601+
for d in model):
602+
return _JointDistributionSequentialAutoBatched(*args, **kwargs)
602603
return super(JointDistributionSequentialAutoBatched, cls).__new__(cls)
603604

604605
@property
@@ -625,16 +626,17 @@ class JointDistributionNamedAutoBatched(
625626

626627
def __new__(cls, *args, **kwargs):
627628
"""Maybe returns a `_JointDistributionNamedAutoBatched`."""
628-
if args:
629-
model = args[0]
630-
else:
631-
model = kwargs.get('model')
632-
633-
# Return a `_JointDistributionNamedAutoBatched` instance if `model` contains
634-
# distributions that are not CompositeTensors.
635-
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
636-
for d in tf.nest.flatten(model)):
637-
return _JointDistributionNamedAutoBatched(*args, **kwargs)
629+
if cls is JointDistributionNamedAutoBatched:
630+
if args:
631+
model = args[0]
632+
else:
633+
model = kwargs.get('model')
634+
635+
# Return a `_JointDistributionNamedAutoBatched` instance if `model`
636+
# contains distributions that are not CompositeTensors.
637+
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
638+
for d in tf.nest.flatten(model)):
639+
return _JointDistributionNamedAutoBatched(*args, **kwargs)
638640
return super(JointDistributionNamedAutoBatched, cls).__new__(cls)
639641

640642
@property

tensorflow_probability/python/distributions/joint_distribution_named.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,15 @@ class JointDistributionNamed(_JointDistributionNamed,
464464

465465
def __new__(cls, *args, **kwargs):
466466
"""Returns a `_JointDistributionNamed` if `model` contains non-CT dists."""
467-
if args:
468-
model = args[0]
469-
else:
470-
model = kwargs.get('model')
467+
if cls is JointDistributionNamed:
468+
if args:
469+
model = args[0]
470+
else:
471+
model = kwargs.get('model')
471472

472-
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
473-
for d in tf.nest.flatten(model)):
474-
return _JointDistributionNamed(*args, **kwargs)
473+
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
474+
for d in tf.nest.flatten(model)):
475+
return _JointDistributionNamed(*args, **kwargs)
475476
return super(JointDistributionNamed, cls).__new__(cls)
476477

477478
@property

tensorflow_probability/python/distributions/joint_distribution_sequential.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -728,14 +728,15 @@ class JointDistributionSequential(_JointDistributionSequential,
728728

729729
def __new__(cls, *args, **kwargs):
730730
"""Returns a `_JointDistributionSequential` if `model` has non-CT dists."""
731-
if args:
732-
model = args[0]
733-
else:
734-
model = kwargs.get('model')
735-
736-
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
737-
for d in model):
738-
return _JointDistributionSequential(*args, **kwargs)
731+
if cls is JointDistributionSequential:
732+
if args:
733+
model = args[0]
734+
else:
735+
model = kwargs.get('model')
736+
737+
if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
738+
for d in model):
739+
return _JointDistributionSequential(*args, **kwargs)
739740
return super(JointDistributionSequential, cls).__new__(cls)
740741

741742
@property

0 commit comments

Comments
 (0)