Skip to content

Commit 8fa76c9

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Register _DeferredTensorSpec and _TransformedVariableSpec with auto_composite_tensor.type_spec_register to avoid errors in the type spec registry when reloading TFP in iPython.
PiperOrigin-RevId: 379608557
1 parent 698ded5 commit 8fa76c9

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tensorflow_probability/python/util/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ multi_substrate_py_library(
4444
deps = [
4545
# numpy dep,
4646
# tensorflow dep,
47+
"//tensorflow_probability/python/internal:auto_composite_tensor",
4748
"//tensorflow_probability/python/internal:dtype_util",
4849
"//tensorflow_probability/python/internal:name_util",
4950
"//tensorflow_probability/python/internal:tensor_util",

tensorflow_probability/python/util/deferred_tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import six
2424

2525
import tensorflow.compat.v2 as tf
26+
from tensorflow_probability.python.internal import auto_composite_tensor
2627
from tensorflow_probability.python.internal import dtype_util
2728
from tensorflow_probability.python.internal import name_util
2829
from tensorflow_probability.python.internal import tensor_util
@@ -734,7 +735,7 @@ def __hash__(self):
734735
return hash(self.__get_cmp_key())
735736

736737

737-
@type_spec.register('tfp.util.DeferredTensorSpec')
738+
@auto_composite_tensor.type_spec_register('tfp.util.DeferredTensorSpec')
738739
class _DeferredTensorSpec(_DeferredTensorSpecBase, type_spec.BatchableTypeSpec):
739740
"""`tf.TypeSpec` for `tfp.util.DeferredTensor`."""
740741

@@ -837,7 +838,7 @@ def _unbatch(self):
837838
also_track_spec=self._also_track_spec)
838839

839840

840-
@type_spec.register('tfp.util.TransformedVariableSpec')
841+
@auto_composite_tensor.type_spec_register('tfp.util.TransformedVariableSpec')
841842
class _TransformedVariableSpec(
842843
_DeferredTensorSpecBase, type_spec.BatchableTypeSpec):
843844
"""`tf.TypeSpec` for `tfp.util.TransformedVariable`."""

0 commit comments

Comments
 (0)