Skip to content

Commit d85b2ab

Browse files
PraChetittensorflower-gardener
authored andcommitted
Implements input_tensorspec for SimpleEncoder.
This also changes detail in the use of tf.function, as retracing can be triggered multiple times in the context of different graphs, breaking the class in certain scenarios. PiperOrigin-RevId: 255297843
1 parent c77f4c2 commit d85b2ab

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,15 @@ def __init__(self, encoder, tensorspec):
7777
# These dictionaries are filled inside of the initial_state_fn and encode_fn
7878
# methods, to be used in encode_fn and decode_fn methods, respectively.
7979
# Decorated by tf.function, their necessary side effects are realized during
80-
# call to get_concrete_function(). Because of fixed input_signatures, these
81-
# are traced only once. See the tf.function tutorial for more details on
82-
# the tracing semantics.
80+
# call to get_concrete_function().
8381
state_py_structure = {}
8482
encoded_py_structure = {}
8583

8684
@tf.function
8785
def initial_state_fn():
8886
state = encoder.initial_state()
89-
assert not state_py_structure # This should be traced only once.
90-
state_py_structure['state'] = nest.map_structure(lambda _: None, state)
87+
if not state_py_structure:
88+
state_py_structure['state'] = nest.map_structure(lambda _: None, state)
9189
# Simplify the structure that needs to be manipulated by the user.
9290
return tuple(nest.flatten(state))
9391

@@ -119,10 +117,10 @@ def encode_fn(x, flat_state):
119117
flat_encoded_py_structure, flat_encoded_tf_structure = (
120118
py_utils.split_dict_py_tf(flat_encoded_structure))
121119

122-
assert not encoded_py_structure # This should be traced only once.
123-
encoded_py_structure['full'] = nest.map_structure(lambda _: None,
124-
full_encoded_structure)
125-
encoded_py_structure['flat_py'] = flat_encoded_py_structure
120+
if not encoded_py_structure:
121+
encoded_py_structure['full'] = nest.map_structure(
122+
lambda _: None, full_encoded_structure)
123+
encoded_py_structure['flat_py'] = flat_encoded_py_structure
126124
return flat_encoded_tf_structure, updated_flat_state
127125

128126
@tf.function(input_signature=[
@@ -145,6 +143,12 @@ def decode_fn(encoded_structure):
145143
self._initial_state_fn = initial_state_fn
146144
self._encode_fn = encode_fn
147145
self._decode_fn = decode_fn
146+
self._tensorspec = tensorspec
147+
148+
@property
149+
def input_tensorspec(self):
150+
"""Returns `tf.TensorSpec` describing input expected by `SimpleEncoder`."""
151+
return self._tensorspec
148152

149153
def initial_state(self, name=None):
150154
"""Returns the initial state.
@@ -182,8 +186,7 @@ def encode(self, x, state=None, name=None):
182186
"""
183187
if state is None:
184188
state = self.initial_state()
185-
with tf.name_scope(name, 'simple_encoder_encode',
186-
[x] + list(state)):
189+
with tf.name_scope(name, 'simple_encoder_encode', [x] + list(state)):
187190
return self._encode_fn(x, state)
188191

189192
def decode(self, encoded_x, name=None):
@@ -205,4 +208,3 @@ def decode(self, encoded_x, name=None):
205208
"""
206209
with tf.name_scope(name, 'simple_encoder_decode', encoded_x.values()):
207210
return self._decode_fn(encoded_x)
208-

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,14 @@ def test_input_signature_enforced(self):
202202
bad_encoded_x.update({'x': x})
203203
encoder.decode(bad_encoded_x)
204204

205+
def test_input_tensorspec(self):
206+
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
207+
encoder = simple_encoder.SimpleEncoder(
208+
core_encoder.EncoderComposer(
209+
test_utils.PlusOneOverNEncodingStage()).make(),
210+
tf.TensorSpec.from_tensor(x))
211+
self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
212+
205213
@parameterized.parameters([1.0, 'str', object])
206214
def test_not_an_encoder_raises(self, not_an_encoder):
207215
"""Tests invalid encoder argument."""

0 commit comments

Comments
 (0)