Skip to content

Commit 057e721

Browse files
Remove tf.function to prevent frequent amounts of retracing.
Outer level scopes should apply `tf.function` if desired. PiperOrigin-RevId: 447591780
1 parent e538361 commit 057e721

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313
# limitations under the License.
1414
"""Base Encoder class for encoding in the "many-to-one" case."""
1515

16-
from __future__ import absolute_import
17-
from __future__ import division
18-
from __future__ import print_function
19-
2016
import collections
2117
import tensorflow as tf
2218

@@ -28,7 +24,7 @@
2824
_TENSORS = 'tensors'
2925

3026

31-
class GatherEncoder(object):
27+
class GatherEncoder:
3228
"""A class for a gather-like operations with encoding.
3329
3430
This class provides functionality for encoding in the "many-to-one" case,
@@ -198,17 +194,19 @@ def _add_to_py_values(key, value):
198194
if key not in internal_py_values:
199195
internal_py_values[key] = value
200196

201-
@tf.function
202197
def initial_state_fn():
203198
"""See the `initial_state` method of this class."""
204-
state = encoder.initial_state()
199+
# Convert to tensor values here. If the initial state is returning
200+
# variables, this captures the value of the variable at the moment of
201+
# this call.
202+
state = tf.nest.map_structure(tf.convert_to_tensor,
203+
encoder.initial_state())
205204
_add_to_structure('state', state)
206205
return tuple(tf.nest.flatten(state))
207206

208207
state = initial_state_fn()
209208
flat_state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
210209

211-
@tf.function
212210
def get_params_fn(flat_state):
213211
"""See the `get_params` method of this class."""
214212
py_utils.assert_compatible(flat_state_spec, flat_state)
@@ -249,16 +247,15 @@ def get_params_fn(flat_state):
249247
tuple(tf.nest.flatten(decode_before_sum_params_tf)),
250248
tuple(tf.nest.flatten(decode_after_sum_params_tf)))
251249

252-
encode_params, decode_before_sum_params, decode_after_sum_params = (
253-
get_params_fn(state))
250+
encode_params, decode_before_sum_params, decode_after_sum_params = tf.nest.map_structure(
251+
tf.convert_to_tensor, (get_params_fn(state)))
254252
encode_params_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
255253
encode_params)
256254
decode_before_sum_params_spec = tf.nest.map_structure(
257255
tf.TensorSpec.from_tensor, decode_before_sum_params)
258256
decode_after_sum_params_spec = tf.nest.map_structure(
259257
tf.TensorSpec.from_tensor, decode_after_sum_params)
260258

261-
@tf.function
262259
def encode_fn(x, params):
263260
"""See the `encode` method of this class."""
264261
if not tensorspec.is_compatible_with(x):
@@ -294,7 +291,6 @@ def encode_fn(x, params):
294291
encoded_structure_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
295292
encoded_structure)
296293

297-
@tf.function
298294
def decode_before_sum_fn(encoded_structure, params):
299295
"""See the `decode_before_sum` method of this class."""
300296
py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
@@ -326,7 +322,6 @@ def decode_before_sum_fn(encoded_structure, params):
326322
part_decoded_structure_spec = tf.nest.map_structure(
327323
tf.TensorSpec.from_tensor, part_decoded_structure)
328324

329-
@tf.function
330325
def decode_after_sum_fn(part_decoded_structure, params, num_summands):
331326
"""See the `decode_after_sum` method of this class."""
332327
py_utils.assert_compatible(part_decoded_structure_spec,
@@ -350,7 +345,6 @@ def decode_after_sum_fn(part_decoded_structure, params, num_summands):
350345
decode_after_sum_params, 1)
351346
assert tensorspec.is_compatible_with(decoded_x)
352347

353-
@tf.function
354348
def update_state_fn(flat_state, state_update_tensors):
355349
"""See the `update_state` method of this class."""
356350
py_utils.assert_compatible(flat_state_spec, flat_state)

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from __future__ import absolute_import
16-
from __future__ import division
17-
from __future__ import print_function
18-
1915
import collections
2016

2117
from absl.testing import parameterized
@@ -136,8 +132,8 @@ def test_python_constants_not_exposed(self):
136132
test_utils.PlusOneEncodingStage(), P1_VALS).add_parent(
137133
test_utils.SimpleLinearEncodingStage(2.0, 3.0),
138134
SL_VALS).make(), tensorspec)
139-
a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
140-
b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
135+
a_var = tf.Variable(2.0, name='a_var')
136+
b_var = tf.Variable(3.0, name='b_var')
141137
encoder_tf = gather_encoder.GatherEncoder.from_encoder(
142138
core_encoder.EncoderComposer(
143139
test_utils.SimpleLinearEncodingStage(a_var, b_var)).add_parent(

0 commit comments

Comments
 (0)