Skip to content

Commit 05e8783

Browse files
PraChetittensorflower-gardener
authored andcommitted
Changes the use of dict to OrderedDict in gather_encoder.py and simple_encoder.py.
PiperOrigin-RevId: 421564691
1 parent c47d6e5 commit 05e8783

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import tensorflow as tf
2122

2223
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import core_encoder
@@ -147,7 +148,6 @@ def from_encoder(cls, encoder, tensorspec):
147148
if not tensorspec.shape.is_fully_defined():
148149
raise TypeError('The shape of provided tensorspec must be fully defined.')
149150

150-
tensorspec = tensorspec
151151
commuting_structure = encoder.commuting_structure
152152
state_update_aggregation_modes = tf.nest.flatten(
153153
encoder.state_update_aggregation_modes)
@@ -187,8 +187,8 @@ def from_encoder(cls, encoder, tensorspec):
187187
# of the tensor_encoding tool do not even need to be aware of it. This
188188
# argument is well supported for instance in the book of John Ousterhout,
189189
# "A Philosophy of Software Design".
190-
internal_structure = {}
191-
internal_py_values = {}
190+
internal_structure = collections.OrderedDict()
191+
internal_py_values = collections.OrderedDict()
192192

193193
def _add_to_structure(key, value):
194194
if key not in internal_structure:
@@ -226,10 +226,10 @@ def get_params_fn(flat_state):
226226
_, input_shapes_after_sum = (
227227
core_encoder.split_shapes_by_commuting_structure(
228228
input_shapes, commuting_structure))
229-
decode_after_sum_params = {
230-
_PARAMS: decode_after_sum_params,
231-
_SHAPES: input_shapes_after_sum
232-
}
229+
decode_after_sum_params = collections.OrderedDict([
230+
(_PARAMS, decode_after_sum_params),
231+
(_SHAPES, input_shapes_after_sum),
232+
])
233233

234234
encode_params_py, encode_params_tf = py_utils.split_dict_py_tf(
235235
encode_params)
@@ -274,18 +274,18 @@ def encode_fn(x, params):
274274
core_encoder.split_shapes_by_commuting_structure(
275275
input_shapes, commuting_structure))
276276

277-
encoded_structure = {
278-
_TENSORS: encoded_x,
279-
_SHAPES: input_shapes_before_sum
280-
}
277+
encoded_structure = collections.OrderedDict([
278+
(_TENSORS, encoded_x),
279+
(_SHAPES, input_shapes_before_sum),
280+
])
281281
encoded_structure_py, encoded_structure_tf = py_utils.split_dict_py_tf(
282282
encoded_structure)
283283

284284
_add_to_structure('encoded_structure', encoded_structure_tf)
285285
_add_to_structure('state_update_tensors', state_update_tensors)
286286
_add_to_py_values('encoded_structure', encoded_structure_py)
287287

288-
return (dict(
288+
return (collections.OrderedDict(
289289
py_utils.flatten_with_joined_string_paths(encoded_structure_tf)),
290290
tuple(tf.nest.flatten(state_update_tensors)))
291291

@@ -316,7 +316,7 @@ def decode_before_sum_fn(encoded_structure, params):
316316

317317
_add_to_structure('part_decoded_structure', part_decoded_structure)
318318
if isinstance(part_decoded_structure, dict):
319-
return dict(
319+
return collections.OrderedDict(
320320
py_utils.flatten_with_joined_string_paths(part_decoded_structure))
321321
else:
322322
return part_decoded_structure

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import tensorflow as tf
2122

2223
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import core_encoder
@@ -79,8 +80,8 @@ def __init__(self, encoder, tensorspec):
7980
# methods, to be used in encode_fn and decode_fn methods, respectively.
8081
# Decorated by tf.function, their necessary side effects are realized during
8182
# call to get_concrete_function().
82-
state_py_structure = {}
83-
encoded_py_structure = {}
83+
state_py_structure = collections.OrderedDict()
84+
encoded_py_structure = collections.OrderedDict()
8485

8586
@tf.function
8687
def initial_state_fn():
@@ -108,12 +109,12 @@ def encode_fn(x, flat_state):
108109
# The following code converts the nested structres necessary for the
109110
# underlying encoder, to a single flat dictionary, which is simpler to
110111
# manipulate by the users of SimpleEncoder.
111-
full_encoded_structure = {
112-
_TENSORS: encoded_x,
113-
_PARAMS: decode_params,
114-
_SHAPES: input_shapes
115-
}
116-
flat_encoded_structure = dict(
112+
full_encoded_structure = collections.OrderedDict([
113+
(_TENSORS, encoded_x),
114+
(_PARAMS, decode_params),
115+
(_SHAPES, input_shapes),
116+
])
117+
flat_encoded_structure = collections.OrderedDict(
117118
py_utils.flatten_with_joined_string_paths(full_encoded_structure))
118119
flat_encoded_py_structure, flat_encoded_tf_structure = (
119120
py_utils.split_dict_py_tf(flat_encoded_structure))

0 commit comments

Comments
 (0)