Skip to content

Commit 7df1042

Browse files
Merge pull request #2324 from sanskarmodi8:issue#75196-fix
PiperOrigin-RevId: 673591318
2 parents 45a981f + b90b419 commit 7df1042

File tree

1 file changed

+28
-58
lines changed

1 file changed

+28
-58
lines changed

site/en/tutorials/keras/save_and_load.ipynb

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@
142142
},
143143
"outputs": [],
144144
"source": [
145-
"!pip install pyyaml h5py # Required to save models in HDF5 format"
145+
"!pip install pyyaml h5py # Required to save models in HDF5 format."
146146
]
147147
},
148148
{
@@ -266,7 +266,7 @@
266266
},
267267
"outputs": [],
268268
"source": [
269-
"checkpoint_path = \"training_1/cp.ckpt\"\n",
269+
"checkpoint_path = \"training_1/cp.weights.h5\" # we are only saving weights therefore we need to use .weights.h5 extension instead we would use .keras for whole model\n",
270270
"checkpoint_dir = os.path.dirname(checkpoint_path)\n",
271271
"\n",
272272
"# Create a callback that saves the model's weights\n",
@@ -275,8 +275,8 @@
275275
" verbose=1)\n",
276276
"\n",
277277
"# Train the model with the new callback\n",
278-
"model.fit(train_images, \n",
279-
" train_labels, \n",
278+
"model.fit(train_images,\n",
279+
" train_labels,\n",
280280
" epochs=10,\n",
281281
" validation_data=(test_images, test_labels),\n",
282282
" callbacks=[cp_callback]) # Pass callback to training\n",
@@ -312,7 +312,7 @@
312312
"id": "wlRN_f56Pqa9"
313313
},
314314
"source": [
315-
"As long as two models share the same architecture you can share weights between them. So, when restoring a model from weights-only, create a model with the same architecture as the original model and then set its weights. \n",
315+
"As long as two models share the same architecture you can share weights between them. So, when restoring a model from weights-only, create a model with the same architecture as the original model and then set its weights.\n",
316316
"\n",
317317
"Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model will perform at chance levels (~10% accuracy):"
318318
]
@@ -380,8 +380,9 @@
380380
"outputs": [],
381381
"source": [
382382
"# Include the epoch in the file name (uses `str.format`)\n",
383-
"checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n",
383+
"checkpoint_path = \"training_2/cp-{epoch:04d}.weights.h5\"\n",
384384
"checkpoint_dir = os.path.dirname(checkpoint_path)\n",
385+
"os.mkdir(checkpoint_dir)\n",
385386
"\n",
386387
"batch_size = 32\n",
387388
"\n",
@@ -392,8 +393,8 @@
392393
"\n",
393394
"# Create a callback that saves the model's weights every 5 epochs\n",
394395
"cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
395-
" filepath=checkpoint_path, \n",
396-
" verbose=1, \n",
396+
" filepath=checkpoint_path,\n",
397+
" verbose=1,\n",
397398
" save_weights_only=True,\n",
398399
" save_freq=5*n_batches)\n",
399400
"\n",
@@ -404,10 +405,10 @@
404405
"model.save_weights(checkpoint_path.format(epoch=0))\n",
405406
"\n",
406407
"# Train the model with the new callback\n",
407-
"model.fit(train_images, \n",
408+
"model.fit(train_images,\n",
408409
" train_labels,\n",
409-
" epochs=50, \n",
410-
" batch_size=batch_size, \n",
410+
" epochs=50,\n",
411+
" batch_size=batch_size,\n",
411412
" callbacks=[cp_callback],\n",
412413
" validation_data=(test_images, test_labels),\n",
413414
" verbose=0)"
@@ -441,7 +442,11 @@
441442
},
442443
"outputs": [],
443444
"source": [
444-
"latest = tf.train.latest_checkpoint(checkpoint_dir)\n",
445+
"def load_latest_checkpoint(checkpoint_dir):\n",
446+
" latest = max(os.listdir(checkpoint_dir), key=lambda f: int(f.split('-')[1].split('.')[0]))\n",
447+
" return os.path.join(checkpoint_dir, latest)\n",
448+
"\n",
449+
"latest = load_latest_checkpoint(checkpoint_dir)\n",
445450
"latest"
446451
]
447452
},
@@ -505,7 +510,7 @@
505510
"source": [
506511
"## Manually save weights\n",
507512
"\n",
508-
"To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras`—and the `Model.save_weights` method in particular—uses the TensorFlow [Checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://www.tensorflow.org/guide/keras/save_and_serialize) guide."
513+
"To save weights manually, use `tf.keras.Model.save_weights`. You have to use .weights.h5 extension to save the weights. You can refer to the [Save and load models](https://www.tensorflow.org/guide/keras/save_and_serialize) guide."
509514
]
510515
},
511516
{
@@ -517,13 +522,14 @@
517522
"outputs": [],
518523
"source": [
519524
"# Save the weights\n",
520-
"model.save_weights('./checkpoints/my_checkpoint')\n",
525+
"os.mkdir('./checkpoints')\n",
526+
"model.save_weights('./checkpoints/my_checkpoint.weights.h5')\n",
521527
"\n",
522528
"# Create a new model instance\n",
523529
"model = create_model()\n",
524530
"\n",
525531
"# Restore the weights\n",
526-
"model.load_weights('./checkpoints/my_checkpoint')\n",
532+
"model.load_weights('./checkpoints/my_checkpoint.weights.h5')\n",
527533
"\n",
528534
"# Evaluate the model\n",
529535
"loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
@@ -657,7 +663,7 @@
657663
"id": "LtcN4VIb7JkK"
658664
},
659665
"source": [
660-
"The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to `serve/inspect` the SavedModel. The section below illustrates the steps to save and restore the model."
666+
"The SavedModel format is another way to serialize models. Models saved in this format can directly be used with TFLite/TFServing/etc for inferencing. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to `serve/inspect` the SavedModel. The section below illustrates the steps to save and restore the model."
661667
]
662668
},
663669
{
@@ -674,7 +680,7 @@
674680
"\n",
675681
"# Save the entire model as a SavedModel.\n",
676682
"!mkdir -p saved_model\n",
677-
"model.save('saved_model/my_model') "
683+
"tf.saved_model.save(model, 'saved_model/my_model')"
678684
]
679685
},
680686
{
@@ -707,7 +713,7 @@
707713
"id": "B7qfpvpY9HCe"
708714
},
709715
"source": [
710-
"Reload a fresh Keras model from the saved model:"
716+
"Reload the saved SavedModel file:"
711717
]
712718
},
713719
{
@@ -718,34 +724,8 @@
718724
},
719725
"outputs": [],
720726
"source": [
721-
"new_model = tf.keras.models.load_model('saved_model/my_model')\n",
722-
"\n",
723-
"# Check its architecture\n",
724-
"new_model.summary()"
725-
]
726-
},
727-
{
728-
"cell_type": "markdown",
729-
"metadata": {
730-
"id": "uWwgNaz19TH2"
731-
},
732-
"source": [
733-
"The restored model is compiled with the same arguments as the original model. Try running evaluate and predict with the loaded model:"
734-
]
735-
},
736-
{
737-
"cell_type": "code",
738-
"execution_count": null,
739-
"metadata": {
740-
"id": "Yh5Mu0yOgE5J"
741-
},
742-
"outputs": [],
743-
"source": [
744-
"# Evaluate the restored model\n",
745-
"loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
746-
"print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n",
747-
"\n",
748-
"print(new_model.predict(test_images).shape)"
727+
"saved_model = tf.saved_model.load('saved_model/my_model')\n",
728+
"saved_model"
749729
]
750730
},
751731
{
@@ -756,7 +736,7 @@
756736
"source": [
757737
"### HDF5 format\n",
758738
"\n",
759-
"Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard. "
739+
"Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard."
760740
]
761741
},
762742
{
@@ -773,7 +753,7 @@
773753
"\n",
774754
"# Save the entire model to a HDF5 file.\n",
775755
"# The '.h5' extension indicates that the model should be saved to HDF5.\n",
776-
"model.save('my_model.h5') "
756+
"model.save('my_model.h5')"
777757
]
778758
},
779759
{
@@ -859,21 +839,11 @@
859839
"\n",
860840
"Refer to the [Writing layers and models from scratch](https://www.tensorflow.org/guide/keras/custom_layers_and_models) tutorial for examples of custom objects and `get_config`.\n"
861841
]
862-
},
863-
{
864-
"cell_type": "code",
865-
"execution_count": null,
866-
"metadata": {
867-
"id": "jBVTkkUIkEF3"
868-
},
869-
"outputs": [],
870-
"source": []
871842
}
872843
],
873844
"metadata": {
874845
"accelerator": "GPU",
875846
"colab": {
876-
"collapsed_sections": [],
877847
"name": "save_and_load.ipynb",
878848
"toc_visible": true
879849
},

0 commit comments

Comments
 (0)