Skip to content

Commit c47d6e5

Browse files
PraChetittensorflower-gardener
authored andcommitted
Changes the use of dict to OrderedDict.
PiperOrigin-RevId: 421306706
1 parent e12b3b3 commit c47d6e5

File tree

5 files changed

+72
-56
lines changed

5 files changed

+72
-56
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/clipping.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from __future__ import division
2424
from __future__ import print_function
2525

26+
import collections
2627
import tensorflow as tf
2728

2829
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
@@ -68,13 +69,16 @@ def decode_needs_input_shape(self):
6869

6970
def get_params(self):
7071
"""See base class."""
71-
return {self.NORM_PARAMS_KEY: self._clip_norm}, {}
72+
encode_params = collections.OrderedDict([(self.NORM_PARAMS_KEY,
73+
self._clip_norm)])
74+
decode_params = collections.OrderedDict()
75+
return encode_params, decode_params
7276

7377
def encode(self, x, encode_params):
7478
"""See base class."""
7579
clipped_x = tf.clip_by_norm(
7680
x, tf.cast(encode_params[self.NORM_PARAMS_KEY], x.dtype))
77-
return {self.ENCODED_VALUES_KEY: clipped_x}
81+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
7882

7983
def decode(self,
8084
encoded_tensors,
@@ -129,19 +133,19 @@ def decode_needs_input_shape(self):
129133

130134
def get_params(self):
131135
"""See base class."""
132-
params = {
133-
self.MIN_PARAMS_KEY: self._clip_value_min,
134-
self.MAX_PARAMS_KEY: self._clip_value_max
135-
}
136-
return params, {}
136+
params = collections.OrderedDict([
137+
(self.MIN_PARAMS_KEY, self._clip_value_min),
138+
(self.MAX_PARAMS_KEY, self._clip_value_max)
139+
])
140+
return params, collections.OrderedDict()
137141

138142
def encode(self, x, encode_params):
139143
"""See base class."""
140144
clipped_x = tf.clip_by_value(
141145
x,
142146
tf.cast(encode_params[self.MIN_PARAMS_KEY], x.dtype),
143147
tf.cast(encode_params[self.MAX_PARAMS_KEY], x.dtype))
144-
return {self.ENCODED_VALUES_KEY: clipped_x}
148+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
145149

146150
def decode(self,
147151
encoded_tensors,

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/kashin.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import numpy as np
2122
import tensorflow as tf
2223

@@ -170,12 +171,11 @@ def decode_needs_input_shape(self):
170171
def get_params(self):
171172
"""See base class."""
172173
seed = tf.random.uniform((2,), maxval=tf.int64.max, dtype=tf.int64)
173-
encode_params = {
174-
self.ETA_PARAMS_KEY: self._eta,
175-
self.DELTA_PARAMS_KEY: self._delta,
176-
self.SEED_PARAMS_KEY: seed,
177-
}
178-
decode_params = {self.SEED_PARAMS_KEY: seed}
174+
encode_params = collections.OrderedDict([(self.ETA_PARAMS_KEY, self._eta),
175+
(self.DELTA_PARAMS_KEY,
176+
self._delta),
177+
(self.SEED_PARAMS_KEY, seed)])
178+
decode_params = collections.OrderedDict([(self.SEED_PARAMS_KEY, seed)])
179179
return encode_params, decode_params
180180

181181
def encode(self, x, encode_params):
@@ -211,7 +211,8 @@ def encode(self, x, encode_params):
211211
tf.norm(x, axis=1, keepdims=True),
212212
tf.norm(kashin_coefficients, axis=1, keepdims=True))
213213

214-
return {self.ENCODED_VALUES_KEY: kashin_coefficients}
214+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY,
215+
kashin_coefficients)])
215216

216217
def decode(self,
217218
encoded_tensors,

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/misc.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import tensorflow as tf
2122

2223
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
@@ -68,7 +69,10 @@ def decode_needs_input_shape(self):
6869

6970
def get_params(self):
7071
"""See base class."""
71-
return {self.THRESHOLD_PARAMS_KEY: self._threshold}, {}
72+
encode_params = collections.OrderedDict([(self.THRESHOLD_PARAMS_KEY,
73+
self._threshold)])
74+
decode_params = collections.OrderedDict()
75+
return encode_params, decode_params
7276

7377
def encode(self, x, encode_params):
7478
"""See base class."""
@@ -77,10 +81,10 @@ def encode(self, x, encode_params):
7781
indices = tf.cast(tf.compat.v2.where(tf.abs(x) > threshold), tf.int32)
7882
non_zero_x = tf.gather_nd(x, indices)
7983
indices = tf.squeeze(indices, axis=1)
80-
return {
81-
self.ENCODED_INDICES_KEY: indices,
82-
self.ENCODED_VALUES_KEY: non_zero_x,
83-
}
84+
return collections.OrderedDict([
85+
(self.ENCODED_INDICES_KEY, indices),
86+
(self.ENCODED_VALUES_KEY, non_zero_x),
87+
])
8488

8589
def decode(self,
8690
encoded_tensors,
@@ -144,7 +148,7 @@ def decode_needs_input_shape(self):
144148

145149
def get_params(self):
146150
"""See base class."""
147-
return {}, {}
151+
return collections.OrderedDict(), collections.OrderedDict()
148152

149153
def encode(self, x, encode_params):
150154
"""See base class."""
@@ -157,9 +161,7 @@ def encode(self, x, encode_params):
157161
'Unsupported input type: %s. Support only integer types.' % x.dtype)
158162

159163
diff_x = x - tf.concat([[0], x[:-1]], 0)
160-
return {
161-
self.ENCODED_VALUES_KEY: diff_x,
162-
}
164+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, diff_x)])
163165

164166
def decode(self,
165167
encoded_tensors,

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/quantization.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import tensorflow as tf
2122

2223
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
@@ -92,7 +93,8 @@ def decode_needs_input_shape(self):
9293

9394
def get_params(self):
9495
"""See base class."""
95-
params = {self.MAX_INT_VALUE_PARAMS_KEY: 2**self._bits - 1}
96+
params = collections.OrderedDict([(self.MAX_INT_VALUE_PARAMS_KEY,
97+
2**self._bits - 1)])
9698
return params, params
9799

98100
def encode(self, x, encode_params):
@@ -117,11 +119,11 @@ def encode(self, x, encode_params):
117119

118120
# Include the random seed in the encoded tensors so that it can be used to
119121
# generate the same random sequence in the decode method.
120-
encoded_tensors = {
121-
self.ENCODED_VALUES_KEY: quantized_x,
122-
self.SEED_PARAMS_KEY: random_seed,
123-
self.MIN_MAX_VALUES_KEY: tf.stack([min_x, max_x])
124-
}
122+
encoded_tensors = collections.OrderedDict([
123+
(self.ENCODED_VALUES_KEY, quantized_x),
124+
(self.SEED_PARAMS_KEY, random_seed),
125+
(self.MIN_MAX_VALUES_KEY, tf.stack([min_x, max_x]))
126+
])
125127

126128
return encoded_tensors
127129

@@ -235,7 +237,8 @@ def decode_needs_input_shape(self):
235237

236238
def get_params(self):
237239
"""See base class."""
238-
params = {self.MAX_INT_VALUE_PARAMS_KEY: 2**self._bits - 1}
240+
params = collections.OrderedDict([(self.MAX_INT_VALUE_PARAMS_KEY,
241+
2**self._bits - 1)])
239242
return params, params
240243

241244
def encode(self, x, encode_params):
@@ -259,10 +262,10 @@ def encode(self, x, encode_params):
259262
else: # Deterministic rounding.
260263
quantized_x = tf.round(x)
261264

262-
encoded_tensors = {
263-
self.ENCODED_VALUES_KEY: quantized_x,
264-
self.MIN_MAX_VALUES_KEY: tf.stack([min_x, max_x])
265-
}
265+
encoded_tensors = collections.OrderedDict([
266+
(self.ENCODED_VALUES_KEY, quantized_x),
267+
(self.MIN_MAX_VALUES_KEY, tf.stack([min_x, max_x]))
268+
])
266269

267270
return encoded_tensors
268271

@@ -359,7 +362,8 @@ def decode_needs_input_shape(self):
359362

360363
def get_params(self):
361364
"""See base class."""
362-
params = {self.MAX_INT_VALUE_PARAMS_KEY: 2**self._bits - 1}
365+
params = collections.OrderedDict([(self.MAX_INT_VALUE_PARAMS_KEY,
366+
2**self._bits - 1)])
363367
return params, params
364368

365369
def encode(self, x, encode_params):
@@ -388,11 +392,11 @@ def encode(self, x, encode_params):
388392

389393
# Include the random seed in the encoded tensors so that it can be used to
390394
# generate the same random sequence in the decode method.
391-
encoded_tensors = {
392-
self.ENCODED_VALUES_KEY: quantized_x,
393-
self.SEED_PARAMS_KEY: random_seed,
394-
self.MIN_MAX_VALUES_KEY: tf.stack([min_x, max_x])
395-
}
395+
encoded_tensors = collections.OrderedDict([
396+
(self.ENCODED_VALUES_KEY, quantized_x),
397+
(self.SEED_PARAMS_KEY, random_seed),
398+
(self.MIN_MAX_VALUES_KEY, tf.stack([min_x, max_x]))
399+
])
396400

397401
return encoded_tensors
398402

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/stages_impl.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
2021
import numpy as np
2122
import tensorflow as tf
2223

@@ -52,12 +53,12 @@ def decode_needs_input_shape(self):
5253

5354
def get_params(self):
5455
"""See base class."""
55-
return {}, {}
56+
return collections.OrderedDict(), collections.OrderedDict()
5657

5758
def encode(self, x, encode_params):
5859
"""See base class."""
5960
del encode_params # Unused.
60-
return {self.ENCODED_VALUES_KEY: tf.identity(x)}
61+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, tf.identity(x))])
6162

6263
def decode(self,
6364
encoded_tensors,
@@ -97,12 +98,13 @@ def decode_needs_input_shape(self):
9798

9899
def get_params(self):
99100
"""See base class."""
100-
return {}, {}
101+
return collections.OrderedDict(), collections.OrderedDict()
101102

102103
def encode(self, x, encode_params):
103104
"""See base class."""
104105
del encode_params # Unused.
105-
return {self.ENCODED_VALUES_KEY: tf.reshape(x, [-1])}
106+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY,
107+
tf.reshape(x, [-1]))])
106108

107109
def decode(self,
108110
encoded_tensors,
@@ -173,10 +175,10 @@ def decode_needs_input_shape(self):
173175

174176
def get_params(self):
175177
"""See base class."""
176-
params = {
177-
self.SEED_PARAMS_KEY:
178-
tf.random.uniform((2,), maxval=tf.int64.max, dtype=tf.int64),
179-
}
178+
params = collections.OrderedDict()
179+
params[self.SEED_PARAMS_KEY] = tf.random.uniform((2,),
180+
maxval=tf.int64.max,
181+
dtype=tf.int64)
180182
return params, params
181183

182184
def encode(self, x, encode_params):
@@ -187,7 +189,7 @@ def encode(self, x, encode_params):
187189
x = x * signs
188190
x = self._pad(x)
189191
rotated_x = tf_utils.fast_walsh_hadamard_transform(x)
190-
return {self.ENCODED_VALUES_KEY: rotated_x}
192+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, rotated_x)])
191193

192194
def decode(self,
193195
encoded_tensors,
@@ -325,7 +327,8 @@ def decode_needs_input_shape(self):
325327

326328
def get_params(self):
327329
"""See base class."""
328-
params = {self.MAX_INT_VALUE_PARAMS_KEY: 2**self._bits - 1}
330+
params = collections.OrderedDict([(self.MAX_INT_VALUE_PARAMS_KEY,
331+
2**self._bits - 1)])
329332
if self._min_max is not None:
330333
# If fixed min and max is provided, expose them via params.
331334
params[self.MIN_MAX_VALUES_KEY] = self._min_max
@@ -353,7 +356,8 @@ def encode(self, x, encode_params):
353356
else: # Deterministic rounding.
354357
quantized_x = tf.round(x)
355358

356-
encoded_tensors = {self.ENCODED_VALUES_KEY: quantized_x}
359+
encoded_tensors = collections.OrderedDict([(self.ENCODED_VALUES_KEY,
360+
quantized_x)])
357361
if self.MIN_MAX_VALUES_KEY not in encode_params:
358362
encoded_tensors[self.MIN_MAX_VALUES_KEY] = tf.stack([min_x, max_x])
359363
return encoded_tensors
@@ -448,7 +452,7 @@ def decode_needs_input_shape(self):
448452

449453
def get_params(self):
450454
"""See base class."""
451-
return {}, {}
455+
return collections.OrderedDict(), collections.OrderedDict()
452456

453457
def encode(self, x, encode_params):
454458
"""See base class."""
@@ -461,10 +465,11 @@ def encode(self, x, encode_params):
461465
# If another type is provided, return a Tensor with a single value of that
462466
# type to be able to recover the type from encoded_tensors in decode method.
463467
if x.dtype == tf.float32:
464-
return {self.ENCODED_VALUES_KEY: packed_x}
468+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, packed_x)])
465469
elif x.dtype == tf.float64:
466-
return {self.ENCODED_VALUES_KEY: packed_x,
467-
self.DUMMY_TYPE_VALUES_KEY: tf.constant(0.0, dtype=tf.float64)}
470+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, packed_x),
471+
(self.DUMMY_TYPE_VALUES_KEY,
472+
tf.constant(0.0, dtype=tf.float64))])
468473
else:
469474
raise TypeError(
470475
'Unsupported packing type: %s. Supported types are tf.float32 and '

0 commit comments

Comments
 (0)