Skip to content

Commit 7ecc78c

Browse files
PraChetittensorflower-gardener
authored andcommitted
Changes to tensor_encoding to make code compatible with both TF 1.x and TF 2.0.
PiperOrigin-RevId: 265079527
1 parent a07b52d commit 7ecc78c

19 files changed

+118
-66
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/core_encoder.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@ def initial_state(self, name=None):
150150
keys as `self.children`, each of which maps to an object like this one,
151151
recursively.
152152
"""
153-
with tf.name_scope(name, 'encoder_initial_state'):
153+
with tf.compat.v1.name_scope(name, 'encoder_initial_state'):
154154
return self._initial_state_impl()
155155

156156
def _initial_state_impl(self):
157157
"""Implementation for the `initial_state` method."""
158158
children_state = {}
159159
for key, encoder in six.iteritems(self.children):
160-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
160+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
161161
children_state[key] = encoder._initial_state_impl() # pylint: disable=protected-access
162162
return {
163163
EncoderKeys.STATE: self.stage.initial_state(),
@@ -185,14 +185,14 @@ def update_state(self, state, state_update_tensors, name=None):
185185
recursively.
186186
"""
187187
values = tf.nest.flatten(state) + tf.nest.flatten(state_update_tensors)
188-
with tf.name_scope(name, 'encoder_update_state', values):
188+
with tf.compat.v1.name_scope(name, 'encoder_update_state', values):
189189
return self._update_state_impl(state, state_update_tensors)
190190

191191
def _update_state_impl(self, state, state_update_tensors):
192192
"""Implementation for the `update_state` method."""
193193
children_states = {}
194194
for key, encoder in six.iteritems(self.children):
195-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
195+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
196196
children_states[key] = encoder._update_state_impl( # pylint: disable=protected-access
197197
state[EncoderKeys.CHILDREN][key],
198198
state_update_tensors[EncoderKeys.CHILDREN][key])
@@ -222,7 +222,8 @@ def get_params(self, state, name=None):
222222
`self.children`, each of which maps to an object like this one,
223223
recursively.
224224
"""
225-
with tf.name_scope(name, 'encoder_get_params', tf.nest.flatten(state)):
225+
with tf.compat.v1.name_scope(name, 'encoder_get_params',
226+
tf.nest.flatten(state)):
226227
return self._get_params_impl(state)
227228

228229
def _get_params_impl(self, state):
@@ -234,7 +235,7 @@ def _get_params_impl(self, state):
234235
children_encode_params = {}
235236
children_decode_params = {}
236237
for key, encoder in six.iteritems(self.children):
237-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
238+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
238239
children_encode_params[key], children_decode_params[key] = (
239240
encoder._get_params_impl(state[EncoderKeys.CHILDREN][key])) # pylint: disable=protected-access
240241
encode_params[EncoderKeys.CHILDREN] = children_encode_params
@@ -272,8 +273,8 @@ def encode(self, x, encode_params, name=None):
272273
dictionary can be either `Tensor` objects, non-TensorFlow constants such
273274
as a `list` or numpy value, or `None`, if the shape is not needed.
274275
"""
275-
with tf.name_scope(name, 'encoder_encode',
276-
[x] + tf.nest.flatten(encode_params)):
276+
with tf.compat.v1.name_scope(name, 'encoder_encode',
277+
[x] + tf.nest.flatten(encode_params)):
277278
return self._encode_impl(x, encode_params)
278279

279280
def _encode_impl(self, x, encode_params):
@@ -290,7 +291,7 @@ def _encode_impl(self, x, encode_params):
290291
children_state_update_tensors = {}
291292
children_shapes = {}
292293
for key, encoder in six.iteritems(self.children):
293-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
294+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
294295
(encoded_tensors[key], children_state_update_tensors[key],
295296
children_shapes[key]) = encoder._encode_impl( # pylint: disable=protected-access
296297
encoded_tensors[key], encode_params[EncoderKeys.CHILDREN][key])
@@ -319,7 +320,7 @@ def decode(self, encoded_tensors, decode_params, shape, name=None):
319320
values = (
320321
tf.nest.flatten(shape) + tf.nest.flatten(decode_params) +
321322
tf.nest.flatten(encoded_tensors))
322-
with tf.name_scope(name, 'encoder_decode', values):
323+
with tf.compat.v1.name_scope(name, 'encoder_decode', values):
323324
# Calling _decode_before_sum_impl with force_decode=True will decode the
324325
# entire tree, regardless of potential commutativity with sum.
325326
return self._decode_before_sum_impl(
@@ -354,7 +355,7 @@ def decode_before_sum(self, encoded_tensors, decode_params, shape, name=None):
354355
values = (
355356
tf.nest.flatten(shape) + tf.nest.flatten(decode_params) +
356357
tf.nest.flatten(encoded_tensors))
357-
with tf.name_scope(name, 'encoder_decode_before_sum', values):
358+
with tf.compat.v1.name_scope(name, 'encoder_decode_before_sum', values):
358359
return self._decode_before_sum_impl(
359360
encoded_tensors, decode_params, shape, force_decode=False)
360361

@@ -383,7 +384,7 @@ def _decode_before_sum_impl(self, encoded_tensors, decode_params, shape,
383384
force_decode |= not self.stage.commutes_with_sum
384385
for key, value in six.iteritems(encoded_tensors):
385386
if key in self.children:
386-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
387+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
387388
temp_encoded_tensors[key] = (
388389
self.children[key]._decode_before_sum_impl( # pylint: disable=protected-access
389390
value,
@@ -435,7 +436,7 @@ def decode_after_sum(self,
435436
tf.nest.flatten(shape) + tf.nest.flatten(decode_params) +
436437
tf.nest.flatten(encoded_tensors) + [num_summands])
437438
num_summands = tf.convert_to_tensor(num_summands)
438-
with tf.name_scope(name, 'encoder_decode_after_sum', values):
439+
with tf.compat.v1.name_scope(name, 'encoder_decode_after_sum', values):
439440
return self._decode_after_sum_impl(encoded_tensors, decode_params,
440441
num_summands, shape)
441442

@@ -450,7 +451,7 @@ def _decode_after_sum_impl(self, encoded_tensors, decode_params, num_summands,
450451
temp_encoded_tensors = {}
451452
for key, value in six.iteritems(encoded_tensors):
452453
if key in self.children:
453-
with tf.name_scope(None, '/'.join([self.stage.name, key])):
454+
with tf.compat.v1.name_scope(None, '/'.join([self.stage.name, key])):
454455
temp_encoded_tensors[key] = self.children[key]._decode_after_sum_impl( # pylint: disable=protected-access
455456
value, decode_params[EncoderKeys.CHILDREN][key], num_summands,
456457
shape[EncoderKeys.CHILDREN][key])

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/core_encoder_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
AN_NORM_UPDATE = test_utils.AdaptiveNormalizeEncodingStage.NORM_STATE_UPDATE_KEY
5252

5353

54+
if tf.executing_eagerly():
55+
tf.compat.v1.disable_eager_execution()
56+
57+
5458
class EncoderTest(tf.test.TestCase):
5559

5660
def test_correct_structure(self):

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/encoding_stage.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def _tf_style_initial_state(initial_state_fn):
631631

632632
def actual_initial_state_fn(self, name=None):
633633
"""Modified `initial_state` method."""
634-
with tf.name_scope(name, self.name + INITIAL_STATE_SCOPE_SUFFIX):
634+
with tf.compat.v1.name_scope(name, self.name + INITIAL_STATE_SCOPE_SUFFIX):
635635
return initial_state_fn(self, name=name)
636636

637637
return actual_initial_state_fn
@@ -643,7 +643,8 @@ def _tf_style_update_state(update_state_fn):
643643
def actual_initial_state_fn(self, state, state_update_tensors, name=None):
644644
"""Modified `update_state` method."""
645645
values = list(state.values()) + list(state_update_tensors.values())
646-
with tf.name_scope(name, self.name + UPDATE_STATE_SCOPE_SUFFIX, values):
646+
with tf.compat.v1.name_scope(name, self.name + UPDATE_STATE_SCOPE_SUFFIX,
647+
values):
647648
state = tf.nest.map_structure(tf.convert_to_tensor, state)
648649
state_update_tensors = tf.nest.map_structure(tf.convert_to_tensor,
649650
state_update_tensors)
@@ -657,7 +658,7 @@ def _tf_style_get_params(get_params_fn):
657658

658659
def actual_get_params_fn(self, name=None):
659660
"""Modified `get_params` method."""
660-
with tf.name_scope(name, self.name + GET_PARAMS_SCOPE_SUFFIX):
661+
with tf.compat.v1.name_scope(name, self.name + GET_PARAMS_SCOPE_SUFFIX):
661662
return get_params_fn(self, name=name)
662663

663664
return actual_get_params_fn
@@ -668,8 +669,8 @@ def _tf_style_adaptive_get_params(get_params_fn):
668669

669670
def actual_get_params_fn(self, state, name=None):
670671
"""Modified `get_params` method."""
671-
with tf.name_scope(name, self.name + GET_PARAMS_SCOPE_SUFFIX,
672-
state.values()):
672+
with tf.compat.v1.name_scope(name, self.name + GET_PARAMS_SCOPE_SUFFIX,
673+
state.values()):
673674
state = tf.nest.map_structure(tf.convert_to_tensor, state)
674675
return get_params_fn(self, state, name=name)
675676

@@ -682,7 +683,8 @@ def _tf_style_encode(encode_fn):
682683
def actual_encode_fn(self, x, encode_params, name=None):
683684
"""Modified `encode` method."""
684685
values = list(encode_params.values()) + [x]
685-
with tf.variable_scope(name, self.name + ENCODE_SCOPE_SUFFIX, values):
686+
with tf.compat.v1.variable_scope(name, self.name + ENCODE_SCOPE_SUFFIX,
687+
values):
686688
x = tf.convert_to_tensor(x)
687689
encode_params = tf.nest.map_structure(tf.convert_to_tensor, encode_params)
688690
return encode_fn(self, x, encode_params, name=name)
@@ -701,9 +703,10 @@ def actual_decode_fn(self,
701703
name=None):
702704
"""Modified `decode` method."""
703705
values = list(encoded_tensors.values()) + list(decode_params.values())
704-
with tf.variable_scope(name, self.name + DECODE_SCOPE_SUFFIX, values):
706+
with tf.compat.v1.variable_scope(name, self.name + DECODE_SCOPE_SUFFIX,
707+
values):
705708
encoded_tensors = tf.nest.map_structure(tf.convert_to_tensor,
706-
encoded_tensors)
709+
encoded_tensors)
707710
decode_params = tf.nest.map_structure(tf.convert_to_tensor, decode_params)
708711
if shape is not None:
709712
shape = tf.convert_to_tensor(shape)

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/encoding_stage_test.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
2626

2727

28+
if tf.executing_eagerly():
29+
tf.compat.v1.disable_eager_execution()
30+
31+
2832
class TFStyleEncodeDecodeTest(tf.test.TestCase, parameterized.TestCase):
2933
"""Tests for `_tf_style_*` decorators.
3034
@@ -95,7 +99,7 @@ def test_initial_state_decorator(self, name):
9599
initial_state = self.evaluate(test_initial_state_fn(stage, name))
96100

97101
# The graph should contain a single node.
98-
graph = tf.get_default_graph()
102+
graph = tf.compat.v1.get_default_graph()
99103
self.assertLen(graph.as_graph_def().node, 1)
100104
if name is not None:
101105
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -117,7 +121,7 @@ def test_update_state_decorator(self, name):
117121

118122
# The graph should contain three nodes. Two for the constants created, and
119123
# one for their addition.
120-
graph = tf.get_default_graph()
124+
graph = tf.compat.v1.get_default_graph()
121125
self.assertLen(graph.as_graph_def().node, 3)
122126
if name is not None:
123127
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -136,7 +140,7 @@ def test_get_params_decorator(self, name):
136140
test_get_params_fn(stage, name))
137141

138142
# The graph should contain a single node.
139-
graph = tf.get_default_graph()
143+
graph = tf.compat.v1.get_default_graph()
140144
self.assertLen(graph.as_graph_def().node, 1)
141145
if name is not None:
142146
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -158,7 +162,7 @@ def test_adaptive_get_params_decorator(self, name):
158162

159163
# The graph should contain three nodes. Two for the constants created, and
160164
# one for the multiplication to create the params.
161-
graph = tf.get_default_graph()
165+
graph = tf.compat.v1.get_default_graph()
162166
self.assertLen(graph.as_graph_def().node, 3)
163167
if name is not None:
164168
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -178,7 +182,7 @@ def test_encode_decorator(self, name):
178182

179183
# The graph should contain three nodes. The two above Python constants
180184
# converted to a Tensor object, and the resulting sum.
181-
graph = tf.get_default_graph()
185+
graph = tf.compat.v1.get_default_graph()
182186
self.assertLen(graph.as_graph_def().node, 3)
183187
if name is not None:
184188
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -215,7 +219,7 @@ def test_decode_decorator(self, name):
215219

216220
# The graph should contain six nodes. The four above Python constants
217221
# converted to a Tensor object, the subtraction, and the final reshape.
218-
graph = tf.get_default_graph()
222+
graph = tf.compat.v1.get_default_graph()
219223
self.assertLen(graph.as_graph_def().node, 6)
220224
if name is not None:
221225
self._assert_all_graph_nodes_in_name_scope(graph, name)
@@ -246,8 +250,8 @@ class NoneStateAdaptiveEncodingStageTest(tf.test.TestCase,
246250

247251
def test_as_adaptive_encoding_stage(self):
248252
"""Tests correctness of the wrapped encoding stage."""
249-
a_var = tf.get_variable('a', initializer=2.0)
250-
b_var = tf.get_variable('b', initializer=3.0)
253+
a_var = tf.compat.v1.get_variable('a', initializer=2.0)
254+
b_var = tf.compat.v1.get_variable('b', initializer=3.0)
251255
stage = test_utils.SimpleLinearEncodingStage(a_var, b_var)
252256
wrapped_stage = encoding_stage.as_adaptive_encoding_stage(stage)
253257
self.assertIsInstance(wrapped_stage,
@@ -279,7 +283,7 @@ def test_as_adaptive_encoding_stage(self):
279283
self.assertEqual(stage.decode_needs_input_shape,
280284
wrapped_stage.decode_needs_input_shape)
281285

282-
self.evaluate(tf.global_variables_initializer())
286+
self.evaluate(tf.compat.v1.global_variables_initializer())
283287
test_data = test_utils.TestData(*self.evaluate([x, encoded_x, decoded_x]))
284288
self.assertEqual(2.0, test_data.x)
285289
self.assertEqual(

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def initial_state(self, name=None):
390390
Returns:
391391
A tuple of `Tensor` values, representing the initial state.
392392
"""
393-
with tf.name_scope(name, 'gather_encoder_initial_state'):
393+
with tf.compat.v1.name_scope(name, 'gather_encoder_initial_state'):
394394
return self._initial_state_fn()
395395

396396
def get_params(self, state=None, name=None):
@@ -417,7 +417,8 @@ def get_params(self, state=None, name=None):
417417
"""
418418
if state is None:
419419
state = self.initial_state()
420-
with tf.name_scope(name, 'gather_encoder_get_params', list(state)):
420+
with tf.compat.v1.name_scope(name, 'gather_encoder_get_params',
421+
list(state)):
421422
state = tf.nest.map_structure(tf.convert_to_tensor, state)
422423
return self._get_params_fn(state)
423424

@@ -445,7 +446,7 @@ def encode(self, x, encode_params, name=None):
445446
`get_params` method.
446447
"""
447448
values = [x] + list(encode_params)
448-
with tf.name_scope(name, 'gather_encoder_encode', values):
449+
with tf.compat.v1.name_scope(name, 'gather_encoder_encode', values):
449450
x = tf.convert_to_tensor(x)
450451
encode_params = tf.nest.map_structure(tf.convert_to_tensor, encode_params)
451452
return self._encode_fn(x, encode_params)
@@ -476,7 +477,8 @@ def decode_before_sum(self, encoded_x, decode_before_sum_params, name=None):
476477
`get_params` method.
477478
"""
478479
values = list(encoded_x.values()) + list(decode_before_sum_params)
479-
with tf.name_scope(name, 'gather_encoder_decode_before_sum', values):
480+
with tf.compat.v1.name_scope(name, 'gather_encoder_decode_before_sum',
481+
values):
480482
encoded_x = tf.nest.map_structure(tf.convert_to_tensor, encoded_x)
481483
decode_before_sum_params = tf.nest.map_structure(
482484
tf.convert_to_tensor, decode_before_sum_params)
@@ -513,7 +515,8 @@ def decode_after_sum(self,
513515
values = list(part_decoded_x.values()) if isinstance(
514516
part_decoded_x, dict) else [part_decoded_x]
515517
values = (values + list(decode_after_sum_params) + [num_summands])
516-
with tf.name_scope(name, 'gather_encoder_decode_after_sum', values):
518+
with tf.compat.v1.name_scope(name, 'gather_encoder_decode_after_sum',
519+
values):
517520
part_decoded_x = tf.nest.map_structure(tf.convert_to_tensor,
518521
part_decoded_x)
519522
decode_after_sum_params = tf.nest.map_structure(tf.convert_to_tensor,
@@ -545,7 +548,7 @@ def update_state(self, state, state_update_tensors, name=None):
545548
return value of the `initial_state` method.
546549
"""
547550
values = list(state) + list(state_update_tensors)
548-
with tf.name_scope(name, 'gather_encoder_update_state', values):
551+
with tf.compat.v1.name_scope(name, 'gather_encoder_update_state', values):
549552
state = tf.nest.map_structure(tf.convert_to_tensor, state)
550553
state_update_tensors = tf.nest.map_structure(tf.convert_to_tensor,
551554
state_update_tensors)

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def initial_state(self, name=None):
161161
Returns:
162162
A tuple of `Tensor` values, representing the initial state.
163163
"""
164-
with tf.name_scope(name, 'simple_encoder_initial_state'):
164+
with tf.compat.v1.name_scope(name, 'simple_encoder_initial_state'):
165165
return self._initial_state_fn()
166166

167167
def encode(self, x, state=None, name=None):
@@ -188,7 +188,8 @@ def encode(self, x, state=None, name=None):
188188
"""
189189
if state is None:
190190
state = self.initial_state()
191-
with tf.name_scope(name, 'simple_encoder_encode', [x] + list(state)):
191+
with tf.compat.v1.name_scope(name, 'simple_encoder_encode',
192+
[x] + list(state)):
192193
return self._encode_fn(x, state)
193194

194195
def decode(self, encoded_x, name=None):
@@ -208,5 +209,6 @@ def decode(self, encoded_x, name=None):
208209
If `encoded_x` is not of the same structure as returned by the `encode`
209210
method.
210211
"""
211-
with tf.name_scope(name, 'simple_encoder_decode', encoded_x.values()):
212+
with tf.compat.v1.name_scope(name, 'simple_encoder_decode',
213+
encoded_x.values()):
212214
return self._decode_fn(encoded_x)

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/clipping_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
2727

2828

29+
if tf.executing_eagerly():
30+
tf.compat.v1.disable_eager_execution()
31+
32+
2933
class ClipByNormEncodingStageTest(test_utils.BaseEncodingStageTest):
3034

3135
def default_encoding_stage(self):

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/kashin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def encode(self, x, encode_params):
206206
# If there is clipping in the last iteration, this can result in
207207
# biased representation of smaller magnitude. We compensate for this
208208
# by scaling such that the norm is preserved.
209-
kashin_coefficients *= tf.div_no_nan(
209+
kashin_coefficients *= tf.compat.v1.div_no_nan(
210210
tf.norm(x, axis=1, keepdims=True),
211211
tf.norm(kashin_coefficients, axis=1, keepdims=True))
212212

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/kashin_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
2727

2828

29+
if tf.executing_eagerly():
30+
tf.compat.v1.disable_eager_execution()
31+
32+
2933
class KashinHadamardEncodingStageTest(test_utils.BaseEncodingStageTest):
3034

3135
def default_encoding_stage(self):

0 commit comments

Comments
 (0)