@@ -589,16 +589,17 @@ class JointDistributionSequentialAutoBatched(
589
589
590
590
def __new__ (cls , * args , ** kwargs ):
591
591
"""Maybe returns a `_JointDistributionSequentialAutobatched`."""
592
- if args :
593
- model = args [0 ]
594
- else :
595
- model = kwargs .get ('model' )
596
-
597
- # Return a `_JointDistributionSequentialAutoBatched` instance if `model`
598
- # contains distributions that are not CompositeTensors.
599
- if not all (isinstance (d , tf .__internal__ .CompositeTensor ) or callable (d )
600
- for d in model ):
601
- return _JointDistributionSequentialAutoBatched (* args , ** kwargs )
592
+ if cls is JointDistributionSequentialAutoBatched :
593
+ if args :
594
+ model = args [0 ]
595
+ else :
596
+ model = kwargs .get ('model' )
597
+
598
+ # Return a `_JointDistributionSequentialAutoBatched` instance if `model`
599
+ # contains distributions that are not CompositeTensors.
600
+ if not all (isinstance (d , tf .__internal__ .CompositeTensor ) or callable (d )
601
+ for d in model ):
602
+ return _JointDistributionSequentialAutoBatched (* args , ** kwargs )
602
603
return super (JointDistributionSequentialAutoBatched , cls ).__new__ (cls )
603
604
604
605
@property
@@ -625,16 +626,17 @@ class JointDistributionNamedAutoBatched(
625
626
626
627
def __new__ (cls , * args , ** kwargs ):
627
628
"""Maybe returns a `_JointDistributionNamedAutoBatched`."""
628
- if args :
629
- model = args [0 ]
630
- else :
631
- model = kwargs .get ('model' )
632
-
633
- # Return a `_JointDistributionNamedAutoBatched` instance if `model` contains
634
- # distributions that are not CompositeTensors.
635
- if not all (isinstance (d , tf .__internal__ .CompositeTensor ) or callable (d )
636
- for d in tf .nest .flatten (model )):
637
- return _JointDistributionNamedAutoBatched (* args , ** kwargs )
629
+ if cls is JointDistributionNamedAutoBatched :
630
+ if args :
631
+ model = args [0 ]
632
+ else :
633
+ model = kwargs .get ('model' )
634
+
635
+ # Return a `_JointDistributionNamedAutoBatched` instance if `model`
636
+ # contains distributions that are not CompositeTensors.
637
+ if not all (isinstance (d , tf .__internal__ .CompositeTensor ) or callable (d )
638
+ for d in tf .nest .flatten (model )):
639
+ return _JointDistributionNamedAutoBatched (* args , ** kwargs )
638
640
return super (JointDistributionNamedAutoBatched , cls ).__new__ (cls )
639
641
640
642
@property
0 commit comments