Skip to content

Commit 2d08c76

Browse files
davmretensorflower-gardener
authored andcommitted
Add copy method to Bijectors.
This duplicates the interface and behavior of Distribution.copy(). The implementation was copied from Distribution, then slightly simplified because Bijectors don't currently support batch slicing. PiperOrigin-RevId: 380024162
1 parent a28c3d5 commit 2d08c76

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,28 @@ def __call__(self, value, name=None, **kwargs):
943943

944944
return self.forward(value, name=name or 'forward', **kwargs)
945945

946+
def copy(self, **override_parameters_kwargs):
947+
"""Creates a copy of the bijector.
948+
949+
Note: the copy bijector may continue to depend on the original
950+
initialization arguments.
951+
952+
Args:
953+
**override_parameters_kwargs: String/value dictionary of initialization
954+
arguments to override with new values.
955+
956+
Returns:
957+
bijector: A new instance of `type(self)` initialized from the union
958+
of self.parameters and override_parameters_kwargs, i.e.,
959+
`dict(self.parameters, **override_parameters_kwargs)`.
960+
"""
961+
parameters = dict(self.parameters, **override_parameters_kwargs)
962+
b = type(self)(**parameters)
963+
# pylint: disable=protected-access
964+
b._parameters = self._no_dependency(parameters)
965+
# pylint: enable=protected-access
966+
return b
967+
946968
def _forward_event_shape_tensor(self, input_shape):
947969
"""Subclass implementation for `forward_event_shape_tensor` function."""
948970
# By default, we assume event_shape is unchanged.

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,28 @@ def _parameter_properties(cls, dtype, num_classes=None):
183183
self.assertAllClose(ildj, true_ildj)
184184
self.assertAllClose(ildj, -fldj)
185185

186+
def testCopyExtraArgs(self):
187+
# Note: we cannot easily test all bijectors since each requires
188+
# different initialization arguments. We therefore spot test a few.
189+
sigmoid = tfb.Sigmoid(low=-1., high=2., validate_args=True)
190+
self.assertEqual(sigmoid.parameters, sigmoid.copy().parameters)
191+
chain = tfb.Chain(
192+
[
193+
tfb.Softplus(hinge_softness=[1., 2.], validate_args=True),
194+
tfb.MatrixInverseTriL(validate_args=True)
195+
], validate_args=True)
196+
self.assertEqual(chain.parameters, chain.copy().parameters)
197+
198+
def testCopyOverride(self):
199+
sigmoid = tfb.Sigmoid(low=-1., high=2., validate_args=True)
200+
self.assertEqual(sigmoid.parameters, sigmoid.copy().parameters)
201+
unused_sigmoid_copy = sigmoid.copy(validate_args=False)
202+
base_params = sigmoid.parameters.copy()
203+
copy_params = sigmoid.copy(validate_args=False).parameters.copy()
204+
self.assertNotEqual(
205+
base_params.pop('validate_args'), copy_params.pop('validate_args'))
206+
self.assertEqual(base_params, copy_params)
207+
186208

187209
class IntentionallyMissingError(Exception):
188210
pass

0 commit comments

Comments
 (0)