@@ -77,17 +77,15 @@ def __init__(self, encoder, tensorspec):
77
77
# These dictionaries are filled inside of the initial_state_fn and encode_fn
78
78
# methods, to be used in encode_fn and decode_fn methods, respectively.
79
79
# Decorated by tf.function, their necessary side effects are realized during
80
- # call to get_concrete_function(). Because of fixed input_signatures, these
81
- # are traced only once. See the tf.function tutorial for more details on
82
- # the tracing semantics.
80
+ # call to get_concrete_function().
83
81
state_py_structure = {}
84
82
encoded_py_structure = {}
85
83
86
84
@tf .function
87
85
def initial_state_fn ():
88
86
state = encoder .initial_state ()
89
- assert not state_py_structure # This should be traced only once.
90
- state_py_structure ['state' ] = nest .map_structure (lambda _ : None , state )
87
+ if not state_py_structure :
88
+ state_py_structure ['state' ] = nest .map_structure (lambda _ : None , state )
91
89
# Simplify the structure that needs to be manipulated by the user.
92
90
return tuple (nest .flatten (state ))
93
91
@@ -119,10 +117,10 @@ def encode_fn(x, flat_state):
119
117
flat_encoded_py_structure , flat_encoded_tf_structure = (
120
118
py_utils .split_dict_py_tf (flat_encoded_structure ))
121
119
122
- assert not encoded_py_structure # This should be traced only once.
123
- encoded_py_structure ['full' ] = nest .map_structure (lambda _ : None ,
124
- full_encoded_structure )
125
- encoded_py_structure ['flat_py' ] = flat_encoded_py_structure
120
+ if not encoded_py_structure :
121
+ encoded_py_structure ['full' ] = nest .map_structure (
122
+ lambda _ : None , full_encoded_structure )
123
+ encoded_py_structure ['flat_py' ] = flat_encoded_py_structure
126
124
return flat_encoded_tf_structure , updated_flat_state
127
125
128
126
@tf .function (input_signature = [
@@ -145,6 +143,12 @@ def decode_fn(encoded_structure):
145
143
self ._initial_state_fn = initial_state_fn
146
144
self ._encode_fn = encode_fn
147
145
self ._decode_fn = decode_fn
146
+ self ._tensorspec = tensorspec
147
+
148
+ @property
149
+ def input_tensorspec (self ):
150
+ """Returns `tf.TensorSpec` describing input expected by `SimpleEncoder`."""
151
+ return self ._tensorspec
148
152
149
153
def initial_state (self , name = None ):
150
154
"""Returns the initial state.
@@ -182,8 +186,7 @@ def encode(self, x, state=None, name=None):
182
186
"""
183
187
if state is None :
184
188
state = self .initial_state ()
185
- with tf .name_scope (name , 'simple_encoder_encode' ,
186
- [x ] + list (state )):
189
+ with tf .name_scope (name , 'simple_encoder_encode' , [x ] + list (state )):
187
190
return self ._encode_fn (x , state )
188
191
189
192
def decode (self , encoded_x , name = None ):
@@ -205,4 +208,3 @@ def decode(self, encoded_x, name=None):
205
208
"""
206
209
with tf .name_scope (name , 'simple_encoder_decode' , encoded_x .values ()):
207
210
return self ._decode_fn (encoded_x )
208
-
0 commit comments