Skip to content

Commit a013182

Browse files
liyunlu0618tensorflower-gardener
authored andcommitted
Use public tf.keras model save/load method.
PiperOrigin-RevId: 303836282
1 parent 5dbfbca commit a013182

File tree

1 file changed

+4
-8
lines changed
  • tensorflow_model_optimization/python/core/sparsity/keras

1 file changed

+4
-8
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/test_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,16 @@ def _save_restore_keras_model(model):
145145
return loaded_model
146146

147147

148-
def _save_restore_saved_model(model):
148+
def _save_restore_tf_model(model):
149149
tmpdir = tempfile.mkdtemp()
150-
tf.keras.experimental.export_saved_model(model, tmpdir)
151-
150+
tf.keras.models.save_model(model, tmpdir, save_format='tf')
152151
with prune.prune_scope():
153-
loaded_model = tf.keras.experimental.load_from_saved_model(tmpdir)
154-
155-
loaded_model.compile(
156-
loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
152+
loaded_model = tf.keras.models.load_model(tmpdir)
157153
return loaded_model
158154

159155

160156
def save_restore_fns():
161-
return [_save_restore_keras_model, _save_restore_saved_model]
157+
return [_save_restore_keras_model, _save_restore_tf_model]
162158

163159

164160
# Assertion/Sparsity Verification functions.

0 commit comments

Comments
 (0)