File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
tensorflow_probability/python/util Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -44,6 +44,7 @@ multi_substrate_py_library(
44
44
deps = [
45
45
# numpy dep,
46
46
# tensorflow dep,
47
+ "//tensorflow_probability/python/internal:auto_composite_tensor" ,
47
48
"//tensorflow_probability/python/internal:dtype_util" ,
48
49
"//tensorflow_probability/python/internal:name_util" ,
49
50
"//tensorflow_probability/python/internal:tensor_util" ,
Original file line number Diff line number Diff line change 23
23
import six
24
24
25
25
import tensorflow .compat .v2 as tf
26
+ from tensorflow_probability .python .internal import auto_composite_tensor
26
27
from tensorflow_probability .python .internal import dtype_util
27
28
from tensorflow_probability .python .internal import name_util
28
29
from tensorflow_probability .python .internal import tensor_util
@@ -734,7 +735,7 @@ def __hash__(self):
734
735
return hash (self .__get_cmp_key ())
735
736
736
737
737
- @type_spec . register ('tfp.util.DeferredTensorSpec' )
738
+ @auto_composite_tensor . type_spec_register ('tfp.util.DeferredTensorSpec' )
738
739
class _DeferredTensorSpec (_DeferredTensorSpecBase , type_spec .BatchableTypeSpec ):
739
740
"""`tf.TypeSpec` for `tfp.util.DeferredTensor`."""
740
741
@@ -837,7 +838,7 @@ def _unbatch(self):
837
838
also_track_spec = self ._also_track_spec )
838
839
839
840
840
- @type_spec . register ('tfp.util.TransformedVariableSpec' )
841
+ @auto_composite_tensor . type_spec_register ('tfp.util.TransformedVariableSpec' )
841
842
class _TransformedVariableSpec (
842
843
_DeferredTensorSpecBase , type_spec .BatchableTypeSpec ):
843
844
"""`tf.TypeSpec` for `tfp.util.TransformedVariable`."""
You can’t perform that action at this time.
0 commit comments