@@ -19,8 +19,10 @@
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.bijectors import bijector
+from tensorflow_probability.python.internal import callable_util
 from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
 from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensorshape_util


 __all__ = ['ScalarFunctionWithInferredInverse']
@@ -35,6 +37,7 @@ def __init__(self,
                max_iterations=50,
                require_convergence=True,
                additional_scalar_parameters_requiring_gradients=(),
+               dtype=None,
                validate_args=False,
                name='scalar_function_with_inferred_inverse'):
     """Initialize the ScalarFunctionWithInferredInverse bijector.
@@ -72,6 +75,9 @@ def __init__(self,
         anything in the closure of `fn`) will not, in general, receive
         gradients.
         Default value: `()`.
+      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
+        enforced.
+        Default value: `None`.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -91,14 +97,14 @@ def __init__(self,
     # VJPs and JVPs can be computed efficiently using actual matrix ops.
     self._additional_scalar_parameters_requiring_gradients = (
         additional_scalar_parameters_requiring_gradients)
-    self._cached_fn_batch_shape = None

     self._bound_fn = (
         lambda x: fn(x, *additional_scalar_parameters_requiring_gradients))
     self._inverse = self._wrap_inverse_with_implicit_gradient()

     super(ScalarFunctionWithInferredInverse, self).__init__(
         parameters=parameters,
+        dtype=dtype,
         forward_min_event_ndims=0,
         inverse_min_event_ndims=0,
         validate_args=validate_args,
@@ -129,15 +135,25 @@ def bound_fn(self):
     """Forward `fn` with any extra args bound, so that `y = bound_fn(x)`."""
     return self._bound_fn

-  def _fn_batch_shape(self):
-    if self._cached_fn_batch_shape is None:
-      # Evaluating at a scalar value (0.) exposes the function's batch shape.
-      # For example, evaluating
-      #   `fn = lambda x: x * constant([1., 2., 3.])`
-      # returns a result of shape `[3]`.
-      self._cached_fn_batch_shape = ps.shape(
-          self.bound_fn(self.domain_constraint_fn(0.)))  # pylint: disable=not-callable
-    return self._cached_fn_batch_shape
+  def _batch_shape(self, x_event_ndims):
+    try:
+      # Trace the function to extract its batch shape without executing it.
+      fn_shape = callable_util.get_output_spec(
+          lambda x: self.bound_fn(self.domain_constraint_fn(x)),  # pylint: disable=not-callable
+          tf.TensorSpec([], dtype=self.dtype if self.dtype else tf.float32)
+      ).shape
+    except TypeError:  # `dtype` wasn't specified.
+      return tf.TensorShape(None)
+
+    fn_rank = tensorshape_util.rank(fn_shape)
+    if fn_rank is not None:
+      return fn_shape[:fn_rank - x_event_ndims]
+    return fn_shape
+
+  def _batch_shape_tensor(self, x_event_ndims):
+    fn_shape = ps.shape(
+        self.bound_fn(self.domain_constraint_fn(0.)))  # pylint: disable=not-callable
+    return fn_shape[:ps.rank_from_shape(fn_shape) - x_event_ndims]

   def _forward(self, x):
     return self.bound_fn(x)
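The key idea in the new `_batch_shape` is deriving a static batch shape by *tracing* `fn` with a symbolic input spec instead of executing it, which is only possible when a concrete input `dtype` is known; when it isn't, the `TypeError` branch falls back to an unknown `TensorShape(None)`. A standalone sketch of the same trick, using the public `tf.function` API rather than TFP's internal `callable_util.get_output_spec` (the `fn` here is illustrative, not from the diff):

```python
import tensorflow as tf

# A function whose output shape depends on a constant baked into its closure.
fn = lambda x: x * tf.constant([1., 2., 3.])

# Tracing against a scalar TensorSpec records output shapes without
# running any ops, mirroring what `callable_util.get_output_spec` does.
concrete = tf.function(fn).get_concrete_function(tf.TensorSpec([], tf.float32))
print(concrete.output_shapes)  # TensorShape([3]), inferred purely by tracing.
```

Since this is a scalar bijector (`forward_min_event_ndims=0`), the traced output shape is all batch dimensions, and `_batch_shape` only needs to strip the trailing `x_event_ndims` dimensions the caller asks about.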
@@ -220,8 +236,8 @@ def _arg_broadcasting_wrapped_inverse(y):
       # TODO(davmre): Do gradient reductions directly in the VJP using
       # `tf.raw_ops.BroadcastGradientArgs` so we can remove this wrapper
       # and avoid spurious broadcasting.
-      full_batch_shape = ps.broadcast_shape(self._fn_batch_shape(),
-                                            ps.shape(y))
+      full_batch_shape = ps.broadcast_shape(
+          self.experimental_batch_shape_tensor(), ps.shape(y))
       args = [tf.broadcast_to(arg, full_batch_shape) for arg in args]
       return _inverse_with_gradient(y, *args)

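Taken together, these changes let the bijector report its batch shape through the standard `Bijector` machinery rather than a private cached helper. A hedged usage sketch, assuming the class is exposed as `tfp.experimental.bijectors.ScalarFunctionWithInferredInverse` and that `experimental_batch_shape` / `experimental_batch_shape_tensor` dispatch to the new private methods shown above:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# A batch of three strictly increasing (hence invertible) scalar functions.
scales = tf.constant([1., 2., 3.])
bij = tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(
    fn=lambda x: x * scales,
    dtype=tf.float32)  # `dtype` enables static shape inference via tracing.
print(bij.experimental_batch_shape())         # Expected: TensorShape([3])
print(bij.experimental_batch_shape_tensor())  # Expected: [3]

# Without `dtype` there is no input spec to trace with, so the static batch
# shape is unknown; the tensor path still works by evaluating `fn` directly.
bij_no_dtype = tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(
    fn=lambda x: x * scales)
print(bij_no_dtype.experimental_batch_shape())  # Expected: <unknown>
```

Note the inverse wrapper now broadcasts against `experimental_batch_shape_tensor()` instead of the removed `_fn_batch_shape()`, so the spurious-broadcasting workaround keeps functioning with the new shape plumbing.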