19
19
from __future__ import print_function
20
20
21
21
import tensorflow .compat .v2 as tf
22
+ from tensorflow_probability .python .bijectors import bijector as bijector_lib
22
23
from tensorflow_probability .python .bijectors import composition
23
24
from tensorflow .python .util import nest # pylint: disable=g-direct-tensorflow-import
24
25
28
29
]
29
30
30
31
31
- class JointMap (composition .Composition ):
32
+ class _JointMap (composition .Composition ):
32
33
"""Bijector which applies a structure of bijectors in parallel.
33
34
34
35
This is the "structured" counterpart to `Chain`. Whereas `Chain` applies an
@@ -92,7 +93,7 @@ def __init__(self,
92
93
self ._nested_structure = self ._no_dependency (
93
94
nest .map_structure (lambda b : None , bijectors ))
94
95
95
- super (JointMap , self ).__init__ (
96
+ super (_JointMap , self ).__init__ (
96
97
bijectors = bijectors ,
97
98
validate_args = validate_args ,
98
99
forward_min_event_ndims = self ._nested_structure ,
@@ -115,3 +116,27 @@ def _walk_inverse(self, step_fn, ys, **kwargs):
115
116
self ._nested_structure ,
116
117
lambda bij , y : step_fn (bij , y , ** kwargs .get (bij .name , {})), # pylint: disable=unnecessary-lambda
117
118
self ._bijectors , ys , check_types = False )
119
+
120
+
121
+ class JointMap (_JointMap , bijector_lib .AutoCompositeTensorBijector ):
122
+
123
+ def __new__ (cls , * args , ** kwargs ):
124
+ """Returns a `_JointMap` any of `bijectors` is not a `CompositeTensor."""
125
+ if cls is JointMap :
126
+ if args :
127
+ bijectors = args [0 ]
128
+ else :
129
+ bijectors = kwargs .get ('bijectors' )
130
+ if bijectors is not None :
131
+ if not all (isinstance (b , tf .__internal__ .CompositeTensor )
132
+ for b in tf .nest .flatten (bijectors )):
133
+ return _JointMap (* args , ** kwargs )
134
+ return super (JointMap , cls ).__new__ (cls )
135
+
136
+
137
+ JointMap .__doc__ = _JointMap .__doc__ + '\n ' + (
138
+ 'If every element of `bijectors` is a `CompositeTensor`, the resulting '
139
+ '`JointMap` bijector is a `CompositeTensor` as well. If any element of '
140
+ '`bijectors` is not a `CompositeTensor`, then a non-`CompositeTensor` '
141
+ '`_JointMap` instance is created instead. Bijector subclasses that inherit '
142
+ 'from `JointMap` will also inherit from `CompositeTensor`.' )
0 commit comments