Skip to content

Commit e2879ec

Browse files
nkovela1copybara-github
authored andcommitted
Creates non-breaking changes where necessary in preparation for switching all of Keras to new serialization format.
PiperOrigin-RevId: 509911382 Change-Id: Idc317015c81814169d4862cfa3149fa8203897c4
1 parent f629172 commit e2879ec

File tree

4 files changed

+27
-23
lines changed

4 files changed

+27
-23
lines changed

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ def test_compression_works_after_serialization(self):
148148
noisy = uniform_noise.NoisyNormal(loc=.5, scale=8.)
149149
em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
150150
self.assertIsNot(em._quantization_offset, None)
151-
json = tf.keras.utils.serialize_keras_object(em)
151+
json = tf.keras.utils.legacy.serialize_keras_object(em)
152152
weights = em.get_weights()
153153
x = noisy.base.sample([100])
154154
x_quantized = em.quantize(x)
155155
x_compressed = em.compress(x)
156-
em = tf.keras.utils.deserialize_keras_object(json)
156+
em = tf.keras.utils.legacy.deserialize_keras_object(json)
157157
em.set_weights(weights)
158158
self.assertAllEqual(em.compress(x), x_compressed)
159159
self.assertAllEqual(em.decompress(x_compressed, [100]), x_quantized)
@@ -162,12 +162,12 @@ def test_compression_works_after_serialization_no_offset(self):
162162
noisy = uniform_noise.NoisyNormal(loc=0, scale=5.)
163163
em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
164164
self.assertIs(em._quantization_offset, None)
165-
json = tf.keras.utils.serialize_keras_object(em)
165+
json = tf.keras.utils.legacy.serialize_keras_object(em)
166166
weights = em.get_weights()
167167
x = noisy.base.sample([100])
168168
x_quantized = em.quantize(x)
169169
x_compressed = em.compress(x)
170-
em = tf.keras.utils.deserialize_keras_object(json)
170+
em = tf.keras.utils.legacy.deserialize_keras_object(json)
171171
em.set_weights(weights)
172172
self.assertAllEqual(em.compress(x), x_compressed)
173173
self.assertAllEqual(em.decompress(x_compressed, [100]), x_quantized)

tensorflow_compression/python/layers/gdn.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def alpha_parameter(self, value):
218218
self._check_not_built()
219219
# This is necessary to make Keras deserialization via __init__ work.
220220
if isinstance(value, dict):
221-
value = tf.keras.utils.deserialize_keras_object(value)
221+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
222222
if value is not None and not callable(value):
223223
# It's a constant, so keep it in compute_dtype.
224224
value = tf.convert_to_tensor(value, dtype=self.compute_dtype)
@@ -233,7 +233,7 @@ def beta_parameter(self, value):
233233
self._check_not_built()
234234
# This is necessary to make Keras deserialization via __init__ work.
235235
if isinstance(value, dict):
236-
value = tf.keras.utils.deserialize_keras_object(value)
236+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
237237
if value is not None and not callable(value):
238238
# It's a constant, so keep it in compute_dtype.
239239
value = tf.convert_to_tensor(value, dtype=self.compute_dtype)
@@ -248,7 +248,7 @@ def gamma_parameter(self, value):
248248
self._check_not_built()
249249
# This is necessary to make Keras deserialization via __init__ work.
250250
if isinstance(value, dict):
251-
value = tf.keras.utils.deserialize_keras_object(value)
251+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
252252
if value is not None and not callable(value):
253253
# It's a constant, so keep it in compute_dtype.
254254
value = tf.convert_to_tensor(value, dtype=self.compute_dtype)
@@ -263,7 +263,7 @@ def epsilon_parameter(self, value):
263263
self._check_not_built()
264264
# This is necessary to make Keras deserialization via __init__ work.
265265
if isinstance(value, dict):
266-
value = tf.keras.utils.deserialize_keras_object(value)
266+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
267267
if value is not None and not callable(value):
268268
# It's a constant, so keep it in compute_dtype.
269269
value = tf.convert_to_tensor(value, dtype=self.compute_dtype)
@@ -431,7 +431,7 @@ def try_serialize(parameter, name):
431431
if parameter is None:
432432
return None
433433
try:
434-
return tf.keras.utils.serialize_keras_object(parameter)
434+
return tf.keras.utils.legacy.serialize_keras_object(parameter)
435435
except (ValueError, TypeError): # Should throw TypeError, but doesn't...
436436
try:
437437
return float(parameter)
@@ -449,18 +449,22 @@ def try_serialize(parameter, name):
449449
rectify=self.rectify,
450450
data_format=self.data_format,
451451
alpha_parameter=alpha_parameter,
452-
beta_parameter=tf.keras.utils.serialize_keras_object(
453-
self.beta_parameter),
454-
gamma_parameter=tf.keras.utils.serialize_keras_object(
455-
self.gamma_parameter),
452+
beta_parameter=tf.keras.utils.legacy.serialize_keras_object(
453+
self.beta_parameter
454+
),
455+
gamma_parameter=tf.keras.utils.legacy.serialize_keras_object(
456+
self.gamma_parameter
457+
),
456458
epsilon_parameter=epsilon_parameter,
457459
alpha_initializer=tf.keras.initializers.serialize(
458-
self.alpha_initializer),
459-
beta_initializer=tf.keras.initializers.serialize(
460-
self.beta_initializer),
460+
self.alpha_initializer
461+
),
462+
beta_initializer=tf.keras.initializers.serialize(self.beta_initializer),
461463
gamma_initializer=tf.keras.initializers.serialize(
462-
self.gamma_initializer),
464+
self.gamma_initializer
465+
),
463466
epsilon_initializer=tf.keras.initializers.serialize(
464-
self.epsilon_initializer),
467+
self.epsilon_initializer
468+
),
465469
)
466470
return config

tensorflow_compression/python/layers/parameters_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def test_name_and_value_are_reproduced_after_serialization(self):
3131
parameter = self.cls(initial_value, **self.kwargs)
3232
name_before = parameter.name
3333
value_before = parameter()
34-
json = tf.keras.utils.serialize_keras_object(parameter)
34+
json = tf.keras.utils.legacy.serialize_keras_object(parameter)
3535
weights = parameter.get_weights()
36-
parameter = tf.keras.utils.deserialize_keras_object(json)
36+
parameter = tf.keras.utils.legacy.deserialize_keras_object(json)
3737
self.assertIsInstance(parameter, self.cls)
3838
self.assertEqual(name_before, parameter.name)
3939
parameter.set_weights(weights)

tensorflow_compression/python/layers/signal_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def kernel_parameter(self, value):
480480
self._check_not_built()
481481
# This is necessary to make Keras deserialization via __init__ work.
482482
if isinstance(value, dict):
483-
value = tf.keras.utils.deserialize_keras_object(value)
483+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
484484
if isinstance(value, str):
485485
if value not in ("variable", "rdft"):
486486
raise ValueError(f"Unsupported value for kernel_parameter: '{value}'.")
@@ -498,7 +498,7 @@ def bias_parameter(self, value):
498498
self._check_not_built()
499499
# This is necessary to make Keras deserialization via __init__ work.
500500
if isinstance(value, dict):
501-
value = tf.keras.utils.deserialize_keras_object(value)
501+
value = tf.keras.utils.legacy.deserialize_keras_object(value)
502502
if isinstance(value, str):
503503
if value != "variable":
504504
raise ValueError(f"Unsupported value for bias_parameter: '{value}'.")
@@ -991,7 +991,7 @@ def try_serialize(parameter, name):
991991
if isinstance(parameter, str):
992992
return parameter
993993
try:
994-
return tf.keras.utils.serialize_keras_object(parameter)
994+
return tf.keras.utils.legacy.serialize_keras_object(parameter)
995995
except (ValueError, TypeError): # Should throw TypeError, but doesn't...
996996
if isinstance(parameter, tf.Variable):
997997
return "variable"

0 commit comments

Comments
 (0)