@@ -49,7 +49,7 @@ class TangentSpace(object):
49
49
50
50
"""
51
51
52
- def transform_general (self , x , f ):
52
+ def transform_general (self , x , f , ** kwargs ):
53
53
"""Returns the density correction, in log space, corresponding to f at x.
54
54
55
55
Also returns a new `TangentSpace` representing the tangent to fM at f(x).
@@ -58,6 +58,7 @@ def transform_general(self, x, f):
58
58
x: `Tensor` (structure). The point at which to calculate the density.
59
59
f: `Bijector` or one of its subclasses. The transformation that requires a
60
60
density correction based on this tangent space.
61
+ **kwargs: Optional keyword arguments as part of the Bijector.
61
62
62
63
Returns:
63
64
log_density: A `Tensor` representing the log density correction of f at x
@@ -69,7 +70,7 @@ def transform_general(self, x, f):
69
70
"""
70
71
raise NotImplementedError
71
72
72
- def transform_dimension_preserving (self , x , f ):
73
+ def transform_dimension_preserving (self , x , f , ** kwargs ):
73
74
"""Same as `transform_general`, assuming f goes from R^n to R^n.
74
75
75
76
Default falls back to `transform_general`, which may be overridden
@@ -78,6 +79,7 @@ def transform_dimension_preserving(self, x, f):
78
79
Args:
79
80
x: same as in `transform_general`.
80
81
f: same as in `transform_general`.
82
+ **kwargs: same as in `transform_general`.
81
83
82
84
Returns:
83
85
log_density: A `Tensor` representing the log density correction of f at x
@@ -88,9 +90,9 @@ def transform_dimension_preserving(self, x, f):
88
90
`transform_general`.
89
91
90
92
"""
91
- return self .transform_general (x , f )
93
+ return self .transform_general (x , f , ** kwargs )
92
94
93
- def transform_projection (self , x , f ):
95
+ def transform_projection (self , x , f , ** kwargs ):
94
96
"""Same as `transform_general`, with f a projection (or its inverse).
95
97
96
98
Default falls back to `transform_general`, which may be overridden
@@ -99,6 +101,7 @@ def transform_projection(self, x, f):
99
101
Args:
100
102
x: same as in `transform_general`.
101
103
f: same as in `transform_general`.
104
+ **kwargs: same as in `transform_general`.
102
105
103
106
Returns:
104
107
log_density: A `Tensor` representing the log density correction of f at x
@@ -108,9 +111,9 @@ def transform_projection(self, x, f):
108
111
NotImplementedError: if the `TangentSpace` subclass does not implement
109
112
`transform_general`.
110
113
"""
111
- return self .transform_general (x , f )
114
+ return self .transform_general (x , f , ** kwargs )
112
115
113
- def transform_coordinatewise (self , x , f ):
116
+ def transform_coordinatewise (self , x , f , ** kwargs ):
114
117
"""Same as `transform_dimension_preserving`, for a coordinatewise f.
115
118
116
119
Default falls back to `transform_dimension_preserving`, which may
@@ -119,6 +122,7 @@ def transform_coordinatewise(self, x, f):
119
122
Args:
120
123
x: same as in `transform_dimension_preserving`.
121
124
f: same as in `transform_dimension_preserving`.
125
+ **kwargs: same as in `transform_dimension_preserving`.
122
126
123
127
Returns:
124
128
log_density: A `Tensor` representing the log density correction of f at x
@@ -129,7 +133,7 @@ def transform_coordinatewise(self, x, f):
129
133
`transform_dimension_preserving`.
130
134
131
135
"""
132
- return self .transform_dimension_preserving (x , f )
136
+ return self .transform_dimension_preserving (x , f , ** kwargs )
133
137
134
138
135
139
def unit_basis ():
@@ -161,31 +165,32 @@ def __init__(self, axis_mask):
161
165
"""
162
166
self .axis_mask = axis_mask
163
167
164
- def transform_general (self , x , f ):
168
+ def transform_general (self , x , f , ** kwargs ):
165
169
as_general_space = GeneralSpace (unit_basis_on (self .axis_mask ), 1 )
166
- return as_general_space .transform_general (x , f )
170
+ return as_general_space .transform_general (x , f , ** kwargs )
167
171
168
- def transform_projection (self , x , f ):
172
+ def transform_projection (self , x , f , ** kwargs ):
169
173
if not hasattr (f , 'experimental_update_live_dimensions' ):
170
174
msg = ('When calling `transform_projection` the Bijector must implement '
171
175
'the `experimental_update_live_dimensions` method.' )
172
176
raise NotImplementedError (msg )
173
- new_live_dimensions = f .experimental_update_live_dimensions (self .axis_mask )
177
+ new_live_dimensions = f .experimental_update_live_dimensions (
178
+ self .axis_mask , ** kwargs )
174
179
if all (tf .get_static_value (new_live_dimensions )):
175
180
# Special-case a bijector (direction) that knows that the result
176
181
# of the projection will be a full space
177
182
return 0 , FullSpace ()
178
183
else :
179
184
return 0 , AxisAlignedSpace (new_live_dimensions )
180
185
181
- def transform_coordinatewise (self , x , f ):
186
+ def transform_coordinatewise (self , x , f , ** kwargs ):
182
187
# TODO(pravnar): compute the derivative of f along x along the
183
188
# live dimensions.
184
189
raise NotImplementedError
185
190
186
191
187
- def jacobian_determinant (x , f ):
188
- return f .forward_log_det_jacobian (x )
192
+ def jacobian_determinant (x , f , ** kwargs ):
193
+ return f .forward_log_det_jacobian (x , ** kwargs )
189
194
190
195
191
196
class FullSpace (TangentSpace ):
@@ -197,16 +202,17 @@ class FullSpace(TangentSpace):
197
202
at all.
198
203
"""
199
204
200
- def transform_general (self , x , f ):
205
+ def transform_general (self , x , f , ** kwargs ):
201
206
"""If the bijector is weird, fall back to the general case."""
202
207
as_general_space = GeneralSpace (unit_basis (), 1 )
203
- return as_general_space .transform_general (x , f )
208
+ return as_general_space .transform_general (x , f , ** kwargs )
204
209
205
- def transform_dimension_preserving (self , x , f ):
206
- return jacobian_determinant (x , f ), FullSpace ()
210
+ def transform_dimension_preserving (self , x , f , ** kwargs ):
211
+ return jacobian_determinant (x , f , ** kwargs ), FullSpace ()
207
212
208
- def transform_projection (self , x , f ):
209
- return AxisAlignedSpace (tf .ones_like (x )).transform_projection (x , f )
213
+ def transform_projection (self , x , f , ** kwargs ):
214
+ return AxisAlignedSpace (tf .ones_like (x )).transform_projection (
215
+ x , f , ** kwargs )
210
216
211
217
212
218
def volume_coefficient (basis ):
@@ -223,7 +229,7 @@ def __init__(self, basis, computed_volume=None):
223
229
computed_volume = volume_coefficient (basis )
224
230
self .volume = computed_volume
225
231
226
- def transform_general (self , x , f ):
232
+ def transform_general (self , x , f , ** kwargs ):
227
233
raise NotImplementedError
228
234
229
235
@@ -236,7 +242,7 @@ class ZeroSpace(TangentSpace):
236
242
237
243
"""
238
244
239
- def transform_general (self , x , f ):
245
+ def transform_general (self , x , f , ** kwargs ):
240
246
del x , f
241
247
return 0 , ZeroSpace ()
242
248
0 commit comments