13
13
# limitations under the License.
14
14
"""Base Encoder class for encoding in the "many-to-one" case."""
15
15
16
- from __future__ import absolute_import
17
- from __future__ import division
18
- from __future__ import print_function
19
-
20
16
import collections
21
17
import tensorflow as tf
22
18
28
24
_TENSORS = 'tensors'
29
25
30
26
31
- class GatherEncoder ( object ) :
27
+ class GatherEncoder :
32
28
"""A class for a gather-like operations with encoding.
33
29
34
30
This class provides functionality for encoding in the "many-to-one" case,
@@ -198,17 +194,19 @@ def _add_to_py_values(key, value):
198
194
if key not in internal_py_values :
199
195
internal_py_values [key ] = value
200
196
201
- @tf .function
202
197
def initial_state_fn ():
203
198
"""See the `initial_state` method of this class."""
204
- state = encoder .initial_state ()
199
+ # Convert to tensor values here. If the initial state is returning
200
+ # variables, this captures the value of the variable at the moment of
201
+ # this call.
202
+ state = tf .nest .map_structure (tf .convert_to_tensor ,
203
+ encoder .initial_state ())
205
204
_add_to_structure ('state' , state )
206
205
return tuple (tf .nest .flatten (state ))
207
206
208
207
state = initial_state_fn ()
209
208
flat_state_spec = tf .nest .map_structure (tf .TensorSpec .from_tensor , state )
210
209
211
- @tf .function
212
210
def get_params_fn (flat_state ):
213
211
"""See the `get_params` method of this class."""
214
212
py_utils .assert_compatible (flat_state_spec , flat_state )
@@ -249,16 +247,15 @@ def get_params_fn(flat_state):
249
247
tuple (tf .nest .flatten (decode_before_sum_params_tf )),
250
248
tuple (tf .nest .flatten (decode_after_sum_params_tf )))
251
249
252
- encode_params , decode_before_sum_params , decode_after_sum_params = (
253
- get_params_fn (state ))
250
+ encode_params , decode_before_sum_params , decode_after_sum_params = tf . nest . map_structure (
251
+ tf . convert_to_tensor , ( get_params_fn (state ) ))
254
252
encode_params_spec = tf .nest .map_structure (tf .TensorSpec .from_tensor ,
255
253
encode_params )
256
254
decode_before_sum_params_spec = tf .nest .map_structure (
257
255
tf .TensorSpec .from_tensor , decode_before_sum_params )
258
256
decode_after_sum_params_spec = tf .nest .map_structure (
259
257
tf .TensorSpec .from_tensor , decode_after_sum_params )
260
258
261
- @tf .function
262
259
def encode_fn (x , params ):
263
260
"""See the `encode` method of this class."""
264
261
if not tensorspec .is_compatible_with (x ):
@@ -294,7 +291,6 @@ def encode_fn(x, params):
294
291
encoded_structure_spec = tf .nest .map_structure (tf .TensorSpec .from_tensor ,
295
292
encoded_structure )
296
293
297
- @tf .function
298
294
def decode_before_sum_fn (encoded_structure , params ):
299
295
"""See the `decode_before_sum` method of this class."""
300
296
py_utils .assert_compatible (encoded_structure_spec , encoded_structure )
@@ -326,7 +322,6 @@ def decode_before_sum_fn(encoded_structure, params):
326
322
part_decoded_structure_spec = tf .nest .map_structure (
327
323
tf .TensorSpec .from_tensor , part_decoded_structure )
328
324
329
- @tf .function
330
325
def decode_after_sum_fn (part_decoded_structure , params , num_summands ):
331
326
"""See the `decode_after_sum` method of this class."""
332
327
py_utils .assert_compatible (part_decoded_structure_spec ,
@@ -350,7 +345,6 @@ def decode_after_sum_fn(part_decoded_structure, params, num_summands):
350
345
decode_after_sum_params , 1 )
351
346
assert tensorspec .is_compatible_with (decoded_x )
352
347
353
- @tf .function
354
348
def update_state_fn (flat_state , state_update_tensors ):
355
349
"""See the `update_state` method of this class."""
356
350
py_utils .assert_compatible (flat_state_spec , flat_state )
0 commit comments