21
21
import numpy as np
22
22
import tensorflow .compat .v2 as tf
23
23
24
+ from tensorflow_probability .python .bijectors import bijector as bijector_lib
24
25
from tensorflow_probability .python .bijectors import chain
26
+ from tensorflow_probability .python .bijectors import composition
25
27
from tensorflow_probability .python .bijectors import invert
26
28
from tensorflow_probability .python .bijectors import joint_map
27
29
from tensorflow_probability .python .bijectors import split
28
30
from tensorflow_probability .python .internal import assert_util
31
+ from tensorflow_probability .python .internal import auto_composite_tensor
29
32
from tensorflow_probability .python .internal import prefer_static as ps
30
33
from tensorflow_probability .python .internal import tensorshape_util
31
34
@@ -42,8 +45,7 @@ def _get_static_splits(splits):
42
45
return splits if static_splits is None else static_splits
43
46
44
47
45
- # TODO(b/182603117): Enable AutoCompositeTensor once Chain subclasses it.
46
- class Blockwise (chain .Chain ):
48
+ class _Blockwise (composition .Composition ):
47
49
"""Bijector which applies a list of bijectors to blocks of a `Tensor`.
48
50
49
51
More specifically, given [F_0, F_1, ... F_n] which are scalar or vector
@@ -151,9 +153,12 @@ def __init__(self,
151
153
name = 'concat' )
152
154
153
155
self ._maybe_changes_size = maybe_changes_size
154
- super (Blockwise , self ).__init__ (
155
- bijectors = [b_concat , b_joint , b_split ],
156
+ self ._chain = chain .Chain (
157
+ [b_concat , b_joint , b_split ], validate_args = validate_args )
158
+ super (_Blockwise , self ).__init__ (
159
+ bijectors = self ._chain .bijectors ,
156
160
validate_args = validate_args ,
161
+ validate_event_size = True ,
157
162
parameters = parameters ,
158
163
name = name )
159
164
@@ -186,13 +191,13 @@ def inverse_block_sizes(self):
186
191
return self ._b_concat .split_sizes
187
192
188
193
def _forward (self , x , ** kwargs ):
189
- y = super (Blockwise , self )._forward (x , ** kwargs )
194
+ y = super (_Blockwise , self )._forward (x , ** kwargs )
190
195
if not self ._maybe_changes_size :
191
196
tensorshape_util .set_shape (y , x .shape )
192
197
return y
193
198
194
199
def _inverse (self , y , ** kwargs ):
195
- x = super (Blockwise , self )._inverse (y , ** kwargs )
200
+ x = super (_Blockwise , self )._inverse (y , ** kwargs )
196
201
if not self ._maybe_changes_size :
197
202
tensorshape_util .set_shape (x , y .shape )
198
203
return x
@@ -220,19 +225,19 @@ def _inverse_event_shape(self, output_shape):
220
225
def _forward_event_shape_tensor (self , x , ** kwargs ):
221
226
if not self ._maybe_changes_size :
222
227
return x
223
- return super (Blockwise , self )._forward_event_shape_tensor (x , ** kwargs )
228
+ return super (_Blockwise , self )._forward_event_shape_tensor (x , ** kwargs )
224
229
225
230
def _inverse_event_shape_tensor (self , y , ** kwargs ):
226
231
if not self ._maybe_changes_size :
227
232
return y
228
- return super (Blockwise , self )._inverse_event_shape_tensor (y , ** kwargs )
233
+ return super (_Blockwise , self )._inverse_event_shape_tensor (y , ** kwargs )
229
234
230
235
def _walk_forward (self , step_fn , x , ** kwargs ):
231
- return super ( Blockwise , self ). _walk_forward (
236
+ return self . _chain . _walk_forward ( # pylint: disable=protected-access
232
237
step_fn , x , ** {self ._b_joint .name : kwargs })
233
238
234
239
def _walk_inverse (self , step_fn , x , ** kwargs ):
235
- return super ( Blockwise , self ). _walk_inverse (
240
+ return self . _chain . _walk_inverse ( # pylint: disable=protected-access
236
241
step_fn , x , ** {self ._b_joint .name : kwargs })
237
242
238
243
@@ -263,3 +268,32 @@ def _validate_block_sizes(block_sizes, bijectors, validate_args):
263
268
# Set the shape if missing to pass statically known structure to split.
264
269
tensorshape_util .set_shape (block_sizes , [len (bijectors )])
265
270
return block_sizes
271
+
272
+
273
+ @bijector_lib .auto_composite_tensor_bijector
274
+ class Blockwise (_Blockwise , auto_composite_tensor .AutoCompositeTensor ):
275
+
276
+ def __new__ (cls , * args , ** kwargs ):
277
+ """Returns a `_Blockwise` if any of `bijectors` is not `CompositeTensor."""
278
+ if cls is Blockwise :
279
+ if args :
280
+ bijectors = args [0 ]
281
+ elif 'bijectors' in kwargs :
282
+ bijectors = kwargs ['bijectors' ]
283
+ else :
284
+ raise TypeError (
285
+ '`Blockwise.__new__()` is missing argument `bijectors`.' )
286
+
287
+ if not all (isinstance (b , tf .__internal__ .CompositeTensor )
288
+ for b in bijectors ):
289
+ return _Blockwise (* args , ** kwargs )
290
+ return super (Blockwise , cls ).__new__ (cls )
291
+
292
+
293
+ Blockwise .__doc__ = _Blockwise .__doc__ + '\n ' + (
294
+ 'If every element of the `bijectors` list is a `CompositeTensor`, the '
295
+ 'resulting `Blockwise` bijector is a `CompositeTensor` as well. If any '
296
+ 'element of `bijectors` is not a `CompositeTensor`, then a '
297
+ 'non-`CompositeTensor` `_Blockwise` instance is created instead. Bijector '
298
+ 'subclasses that inherit from `Blockwise` will also inherit from '
299
+ '`CompositeTensor`.' )
0 commit comments