Skip to content

Commit 495e3f2

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Add TypeSpec for tfp.util.DeferredTensor.
PiperOrigin-RevId: 374957904
1 parent 9d4fc05 commit 495e3f2

File tree

2 files changed

+335
-89
lines changed

2 files changed

+335
-89
lines changed

tensorflow_probability/python/util/deferred_tensor.py

Lines changed: 186 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -530,40 +530,8 @@ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
530530
read_value=read_value)
531531

532532

533-
@type_spec.register('tfp.util.TransformedVariableSpec')
534-
class _TransformedVariableSpec(type_spec.BatchableTypeSpec):
535-
"""`tf.TypeSpec` for `tfp.util.TransformedVariable`."""
536-
537-
__slots__ = ('_input_spec', '_transform_or_spec', '_dtype', '_name', '_specs',
538-
'_unique_id_params', '_transform_is_composite')
539-
540-
def __init__(self, input_spec, transform_or_spec, dtype, name):
541-
"""Initializes a new `_TransformedVariableSpec`.
542-
543-
Args:
544-
input_spec: `tf.TypeSpec` instance describing the `TransformedVariable`s
545-
`pretransformed_input` attribute.
546-
transform_or_spec: The `bijector` passed to the `TransformedVariable`'s
547-
constructor, or `bijector._type_spec` if `bijector` is a
548-
`CompositeTensor`.
549-
dtype: `tf.DType`, `dtype` property of the `TransformedVariable`.
550-
name: `str`, name of the `TransformedVariable`.
551-
"""
552-
self._input_spec = input_spec
553-
self._transform_or_spec = transform_or_spec
554-
self._dtype = dtype
555-
self._name = name
556-
557-
self._unique_id_params = {'dtype': dtype}
558-
self._transform_is_composite = isinstance(transform_or_spec, tf.TypeSpec)
559-
560-
self._specs = {'input_spec': input_spec}
561-
if self._transform_is_composite:
562-
self._specs['transform_or_spec'] = transform_or_spec
563-
564-
@property
565-
def value_type(self):
566-
return TransformedVariable
533+
class _DeferredTensorSpecBase(object):
534+
"""Common methods for '_DeferredTensorSpec' and '_TransformedVariableSpec."""
567535

568536
@property
569537
def name(self):
@@ -637,58 +605,17 @@ def relax(value):
637605
else:
638606
return value
639607

640-
transform_or_spec = self._specs.pop(
608+
specs = self._specs.copy()
609+
transform_or_spec = specs.pop(
641610
'transform_or_spec', self.transform_or_spec)
642611
return type(self)(
643612
**tf.nest.map_structure(
644613
relax,
645-
dict(self._specs,
614+
dict(specs,
646615
transform_or_spec=transform_or_spec,
647616
**self._unique_id_params,
648617
name=self.name)))
649618

650-
def _to_components(self, value):
651-
"""Encodes `value` as a nested structure of Tensor/CompositeTensor."""
652-
components = dict(pretransformed_input=value.pretransformed_input)
653-
if isinstance(value.bijector, tf.__internal__.CompositeTensor):
654-
components['bijector'] = value.bijector
655-
return components
656-
657-
def _from_components(self, components):
658-
"""Reconstructs a value from a structure of Tensor/CompositeTensor."""
659-
bijector = components.pop('bijector', self.transform_or_spec)
660-
return TransformedVariable(
661-
**components, initial_value=None, bijector=bijector,
662-
dtype=self.dtype, name=self.name)
663-
664-
@property
665-
def _component_specs(self):
666-
"""A nested structure of TypeSpecs for the DeferredTensor's components."""
667-
specs = dict(pretransformed_input=self._input_spec)
668-
if self._transform_is_composite:
669-
specs['bijector'] = self.transform_or_spec
670-
return specs
671-
672-
def _batch(self, batch_size):
673-
"""Returns a TypeSpec representing a batch of DeferredTensors."""
674-
transform_or_spec = self._specs.pop(
675-
'transform_or_spec', self.transform_or_spec)
676-
return type(self)(
677-
self._get_batched_input_spec(batch_size),
678-
transform_or_spec=transform_or_spec,
679-
dtype=self.dtype,
680-
name=self.name)
681-
682-
def _unbatch(self):
683-
"""Returns a TypeSpec representing a single DeferredTensor."""
684-
transform_or_spec = self._specs.pop(
685-
'transform_or_spec', self.transform_or_spec)
686-
return type(self)(
687-
self._get_unbatched_input_spec(),
688-
transform_or_spec=transform_or_spec,
689-
dtype=self.dtype,
690-
name=self.name)
691-
692619
def _get_batched_input_spec(self, batch_size):
693620
"""Returns the batched `input_spec` for the given `batch_size`."""
694621
if isinstance(self._input_spec, type_spec.BatchableTypeSpec):
@@ -759,3 +686,184 @@ def __ne__(self, other):
759686
def __hash__(self):
760687
return hash(self.__get_cmp_key())
761688

689+
690+
@type_spec.register('tfp.util.DeferredTensorSpec')
691+
class _DeferredTensorSpec(_DeferredTensorSpecBase, type_spec.BatchableTypeSpec):
692+
"""`tf.TypeSpec` for `tfp.util.DeferredTensor`."""
693+
694+
__slots__ = ('_input_spec', '_transform_or_spec', '_also_track_spec',
695+
'_dtype', '_shape', '_name', '_specs', '_unique_id_params',
696+
'_transform_is_composite')
697+
698+
def __init__(self, input_spec, transform_or_spec, dtype, shape, name,
699+
also_track_spec=None):
700+
"""Initializes a new `_DeferredTensorSpec`.
701+
702+
Args:
703+
input_spec: `tf.TypeSpec` instance describing the `DeferredTensor`s
704+
`pretransformed_input` attribute.
705+
transform_or_spec: The `transform_fn` passed to the `DeferredTensor`'s
706+
constructor, or `transform_fn._type_spec` if `transform_fn` is a
707+
`CompositeTensor`.
708+
dtype: `tf.DType`, `dtype` property of the `DeferredTensor`.
709+
shape: `tf.TensorShape`, `shape` property of the `DeferredTensor`.
710+
name: `str`, name of the `DeferredTensor`.
711+
also_track_spec: Python `list` of `VariableSpec` instances describing the
712+
additional variables tracked by the `DeferredTensor`.
713+
"""
714+
self._input_spec = input_spec
715+
self._transform_or_spec = transform_or_spec
716+
self._also_track_spec = also_track_spec
717+
self._dtype = dtype
718+
self._shape = shape
719+
self._name = name
720+
721+
self._transform_is_composite = isinstance(transform_or_spec, tf.TypeSpec)
722+
self._unique_id_params = {'dtype': dtype, 'shape': shape}
723+
724+
self._specs = {'input_spec': input_spec}
725+
if self._transform_is_composite:
726+
self._specs['transform_or_spec'] = transform_or_spec
727+
if also_track_spec is not None:
728+
self._specs['also_track_spec'] = also_track_spec
729+
730+
@property
731+
def value_type(self):
732+
return DeferredTensor
733+
734+
@property
735+
def shape(self):
736+
return self._shape
737+
738+
def _to_components(self, value):
739+
"""Encodes `value` as a nested structure of Tensor/CompositeTensor."""
740+
components = dict(pretransformed_input=value.pretransformed_input)
741+
# pylint: disable=protected-access
742+
if isinstance(value._transform_fn, tf.__internal__.CompositeTensor):
743+
components['transform_fn'] = value._transform_fn
744+
if value.also_track is not None:
745+
components['also_track'] = tf.nest.flatten(
746+
tf.nest.map_structure(
747+
lambda x: x.variables if isinstance(x, tf.Module) else x,
748+
value.also_track))
749+
return components
750+
751+
def _from_components(self, components):
752+
"""Reconstructs a value from a structure of Tensor/CompositeTensor."""
753+
transform_fn = components.pop('transform_fn', self.transform_or_spec)
754+
return DeferredTensor(**components, transform_fn=transform_fn,
755+
dtype=self.dtype, shape=self.shape, name=self.name)
756+
757+
@property
758+
def _component_specs(self):
759+
"""A nested structure of TypeSpecs for the DeferredTensor's components."""
760+
specs = dict(pretransformed_input=self._input_spec)
761+
if self._transform_is_composite:
762+
specs['transform_fn'] = self.transform_or_spec
763+
if self._also_track_spec is not None:
764+
specs['also_track'] = self._also_track_spec
765+
return specs
766+
767+
def _batch(self, batch_size):
768+
"""Returns a TypeSpec representing a batch of DeferredTensors."""
769+
transform_or_spec = self._specs.get(
770+
'transform_or_spec', self.transform_or_spec)
771+
return _DeferredTensorSpec(
772+
self._get_batched_input_spec(batch_size),
773+
transform_or_spec=transform_or_spec,
774+
dtype=self.dtype,
775+
shape=(None if self.shape is None
776+
else tf.TensorShape([batch_size]).concatenate(self.shape)),
777+
name=self.name,
778+
also_track_spec=self._also_track_spec)
779+
780+
def _unbatch(self):
781+
"""Returns a TypeSpec representing a single DeferredTensor."""
782+
transform_or_spec = self._specs.get(
783+
'transform_or_spec', self.transform_or_spec)
784+
return _DeferredTensorSpec(
785+
self._get_unbatched_input_spec(),
786+
transform_or_spec=transform_or_spec,
787+
dtype=self.dtype,
788+
shape=(None if self.shape is None else self.shape[1:]),
789+
name=self.name,
790+
also_track_spec=self._also_track_spec)
791+
792+
793+
@type_spec.register('tfp.util.TransformedVariableSpec')
794+
class _TransformedVariableSpec(
795+
_DeferredTensorSpecBase, type_spec.BatchableTypeSpec):
796+
"""`tf.TypeSpec` for `tfp.util.TransformedVariable`."""
797+
798+
__slots__ = ('_input_spec', '_transform_or_spec', '_dtype', '_name', '_specs',
799+
'_unique_id_params', '_transform_is_composite')
800+
801+
def __init__(self, input_spec, transform_or_spec, dtype, name):
802+
"""Initializes a new `_TransformedVariableSpec`.
803+
804+
Args:
805+
input_spec: `tf.TypeSpec` instance describing the `TransformedVariable`s
806+
`pretransformed_input` attribute.
807+
transform_or_spec: The `bijector` passed to the `TransformedVariable`'s
808+
constructor, or `bijector._type_spec` if `bijector` is a
809+
`CompositeTensor`.
810+
dtype: `tf.DType`, `dtype` property of the `TransformedVariable`.
811+
name: `str`, name of the `TransformedVariable`.
812+
"""
813+
self._input_spec = input_spec
814+
self._transform_or_spec = transform_or_spec
815+
self._dtype = dtype
816+
self._name = name
817+
818+
self._unique_id_params = {'dtype': dtype}
819+
self._transform_is_composite = isinstance(transform_or_spec, tf.TypeSpec)
820+
821+
self._specs = {'input_spec': input_spec}
822+
if self._transform_is_composite:
823+
self._specs['transform_or_spec'] = transform_or_spec
824+
825+
@property
826+
def value_type(self):
827+
return TransformedVariable
828+
829+
def _to_components(self, value):
830+
"""Encodes `value` as a nested structure of Tensor/CompositeTensor."""
831+
components = dict(pretransformed_input=value.pretransformed_input)
832+
if isinstance(value.bijector, tf.__internal__.CompositeTensor):
833+
components['bijector'] = value.bijector
834+
return components
835+
836+
def _from_components(self, components):
837+
"""Reconstructs a value from a structure of Tensor/CompositeTensor."""
838+
bijector = components.pop('bijector', self.transform_or_spec)
839+
return TransformedVariable(
840+
**components, initial_value=None, bijector=bijector,
841+
dtype=self.dtype, name=self.name)
842+
843+
@property
844+
def _component_specs(self):
845+
"""A structure of TypeSpecs for the TransformedVariable's components."""
846+
specs = dict(pretransformed_input=self._input_spec)
847+
if self._transform_is_composite:
848+
specs['bijector'] = self.transform_or_spec
849+
return specs
850+
851+
def _batch(self, batch_size):
852+
"""Returns a TypeSpec representing a batch of TransformedVariable."""
853+
transform_or_spec = self._specs.get(
854+
'transform_or_spec', self.transform_or_spec)
855+
return _TransformedVariableSpec(
856+
self._get_batched_input_spec(batch_size),
857+
transform_or_spec=transform_or_spec,
858+
dtype=self.dtype,
859+
name=self.name)
860+
861+
def _unbatch(self):
862+
"""Returns a TypeSpec representing a single TransformedVariable."""
863+
transform_or_spec = self._specs.get(
864+
'transform_or_spec', self.transform_or_spec)
865+
return _TransformedVariableSpec(
866+
self._get_unbatched_input_spec(),
867+
transform_or_spec=transform_or_spec,
868+
dtype=self.dtype,
869+
name=self.name)

0 commit comments

Comments
 (0)