Commit 8aea6cc

Update distributed model saving/loading tutorial: Add example showing how saving works after calling .fit. Also fix some typos, linting, adding some minor details.
PiperOrigin-RevId: 438140288
1 parent c1278c1 commit 8aea6cc

File tree

1 file changed

+123
-41
lines changed

site/en/tutorials/distribute/save_and_load.ipynb

Lines changed: 123 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@
7171
"source": [
7272
"## Overview\n",
7373
"\n",
74-
"It's common to save and load a model during training. There are two sets of APIs for saving and loading a keras model: a high-level API, and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using `tf.distribute.Strategy`. To learn about SavedModel and serialization in general, please read the [saved model guide](../../guide/saved_model.ipynb), and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example: "
74+
"This tutorial demonstrates how you can save and load models in a SavedModel format with `tf.distribute.Strategy` during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (`tf.keras.Model.save` and `tf.keras.models.load_model`) and low-level (`tf.saved_model.save` and `tf.saved_model.load`).\n",
75+
"\n",
76+
"To learn about SavedModel and serialization in general, please read the [saved model guide](../../guide/saved_model.ipynb), and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example: "
7577
]
7678
},
7779
{
@@ -102,7 +104,7 @@
102104
"id": "qqapWj98ptNV"
103105
},
104106
"source": [
105-
"Prepare the data and model using `tf.distribute.Strategy`:"
107+
"Load and prepare the data with TensorFlow Datasets and `tf.data`, and create the model using `tf.distribute.MirroredStrategy`:"
106108
]
107109
},
108110
{
@@ -116,7 +118,7 @@
116118
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
117119
"\n",
118120
"def get_data():\n",
119-
" datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)\n",
121+
" datasets = tfds.load(name='mnist', as_supervised=True)\n",
120122
" mnist_train, mnist_test = datasets['train'], datasets['test']\n",
121123
"\n",
122124
" BUFFER_SIZE = 10000\n",
@@ -157,7 +159,7 @@
157159
"id": "qmU4Y3feS9Na"
158160
},
159161
"source": [
160-
"Train the model: "
162+
"Train the model with `tf.keras.Model.fit`: "
161163
]
162164
},
163165
{
@@ -181,11 +183,11 @@
181183
"source": [
182184
"## Save and load the model\n",
183185
"\n",
184-
"Now that you have a simple model to work with, let's take a look at the saving/loading APIs. \n",
185-
"There are two sets of APIs available:\n",
186+
"Now that you have a simple model to work with, let's explore the saving/loading APIs. \n",
187+
"There are two kinds of APIs available:\n",
186188
"\n",
187-
"* High level keras `model.save` and `tf.keras.models.load_model`\n",
188-
"* Low level `tf.saved_model.save` and `tf.saved_model.load`\n"
189+
"* High-level (Keras): `Model.save` and `tf.keras.models.load_model`\n",
190+
"* Low-level: `tf.saved_model.save` and `tf.saved_model.load`\n"
189191
]
190192
},
191193
{
@@ -194,7 +196,7 @@
194196
"id": "FX_IF2F1tvFs"
195197
},
196198
"source": [
197-
"### The Keras APIs"
199+
"### The Keras API"
198200
]
199201
},
200202
{
@@ -203,7 +205,7 @@
203205
"id": "O8xfceg4Z3H_"
204206
},
205207
"source": [
206-
"Here is an example of saving and loading a model with the Keras APIs:"
208+
"Here is an example of saving and loading a model with the Keras API:"
207209
]
208210
},
209211
{
@@ -214,7 +216,7 @@
214216
},
215217
"outputs": [],
216218
"source": [
217-
"keras_model_path = \"/tmp/keras_save\"\n",
219+
"keras_model_path = '/tmp/keras_save'\n",
218220
"model.save(keras_model_path)"
219221
]
220222
},
@@ -245,9 +247,9 @@
245247
"id": "gYAnskzorda-"
246248
},
247249
"source": [
248-
"After restoring the model, you can continue training on it, even without needing to call `compile()` again, since it is already compiled before saving. The model is saved in the TensorFlow's standard `SavedModel` proto format. For more information, please refer to [the guide to `saved_model` format](../../guide/saved_model.ipynb).\n",
250+
"After restoring the model, you can continue training on it, even without needing to call `Model.compile` again, since it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, please refer to [the guide to `SavedModel` format](../../guide/saved_model.ipynb).\n",
249251
"\n",
250-
"Now to load the model and train it using a `tf.distribute.Strategy`:"
252+
"Now, restore the model and train it using a `tf.distribute.Strategy`:"
251253
]
252254
},
253255
{
@@ -258,7 +260,7 @@
258260
},
259261
"outputs": [],
260262
"source": [
261-
"another_strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")\n",
263+
"another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')\n",
262264
"with another_strategy.scope():\n",
263265
" restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)\n",
264266
" restored_keras_model_ds.fit(train_dataset, epochs=2)"
@@ -270,7 +272,7 @@
270272
"id": "PdiiPmL5tQk5"
271273
},
272274
"source": [
273-
"As you can see, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. "
275+
"As the `Model.fit` output shows, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. "
274276
]
275277
},
276278
{
@@ -279,7 +281,7 @@
279281
"id": "3CrXIbmFt0f6"
280282
},
281283
"source": [
282-
"### The `tf.saved_model` APIs"
284+
"### The `tf.saved_model` API"
283285
]
284286
},
285287
{
@@ -288,7 +290,7 @@
288290
"id": "HtGzPp6et4Em"
289291
},
290292
"source": [
291-
"Now let's take a look at the lower level APIs. Saving the model is similar to the keras API:"
293+
"Saving the model with the lower-level API is similar to saving it with the Keras API:"
292294
]
293295
},
294296
{
@@ -300,7 +302,7 @@
300302
"outputs": [],
301303
"source": [
302304
"model = get_model() # get a fresh model\n",
303-
"saved_model_path = \"/tmp/tf_save\"\n",
305+
"saved_model_path = '/tmp/tf_save'\n",
304306
"tf.saved_model.save(model, saved_model_path)"
305307
]
306308
},
@@ -310,7 +312,7 @@
310312
"id": "q1QNRYcwuRll"
311313
},
312314
"source": [
313-
"Loading can be done with `tf.saved_model.load()`. However, since it is an API that is on the lower level (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contain functions that can be used to do inference. For example:"
315+
"Loading can be done with `tf.saved_model.load`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:"
314316
]
315317
},
316318
{
@@ -321,7 +323,7 @@
321323
},
322324
"outputs": [],
323325
"source": [
324-
"DEFAULT_FUNCTION_KEY = \"serving_default\"\n",
326+
"DEFAULT_FUNCTION_KEY = 'serving_default'\n",
325327
"loaded = tf.saved_model.load(saved_model_path)\n",
326328
"inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]"
327329
]
@@ -332,7 +334,7 @@
332334
"id": "x65l7AaHUZCA"
333335
},
334336
"source": [
335-
"The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` is the default key for the inference function with a saved Keras model. To do an inference with this function: "
337+
"The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` key is the default key for the inference function with a saved Keras model. To do inference with this function: "
336338
]
337339
},
338340
{
@@ -375,7 +377,9 @@
375377
"\n",
376378
" # Calling the function in a distributed manner\n",
377379
" for batch in dist_predict_dataset:\n",
378-
" another_strategy.run(inference_func,args=(batch,))"
380+
" result = another_strategy.run(inference_func, args=(batch,))\n",
381+
" print(result)\n",
382+
" break"
379383
]
380384
},
381385
{
@@ -384,7 +388,7 @@
384388
"id": "hWGSukoyw3fF"
385389
},
386390
"source": [
387-
"Calling the restored function is just a forward pass on the saved model (predict). What if yout want to continue training the loaded function? Or embed the loaded function into a bigger model? A common practice is to wrap this loaded object to a Keras layer to achieve this. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [hub.KerasLayer](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:"
391+
"Calling the restored function is just a forward pass on the saved model (`tf.keras.Model.predict`). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:"
388392
]
389393
},
390394
{
@@ -421,7 +425,7 @@
421425
"id": "Oe1z_OtSJlu2"
422426
},
423427
"source": [
424-
"As you can see, `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load()` into a Keras layer that can be used to build another model. This is very useful for transfer learning. "
428+
"In the above example, TensorFlow Hub's `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load` into a Keras layer that is used to build another model. This is very useful for transfer learning. "
425429
]
426430
},
427431
{
@@ -439,11 +443,11 @@
439443
"id": "GC6GQ9HDLxD6"
440444
},
441445
"source": [
442-
"For saving, if you are working with a keras model, it is almost always recommended to use the Keras's `model.save()` API. If what you are saving is not a Keras model, then the lower level API is your only choice. \n",
446+
"For saving, if you are working with a Keras model, use the Keras `Model.save` API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, `tf.saved_model.save`, is your only choice. \n",
443447
"\n",
444-
"For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load()`. Otherwise, use `tf.keras.models.load_model()`. Note that you can get a Keras model back only if you saved a Keras model. \n",
448+
"For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load`. Otherwise, use `tf.keras.models.load_model`. Note that you can get a Keras model back only if you saved a Keras model. \n",
445449
"\n",
446-
"It is possible to mix and match the APIs. You can save a Keras model with `model.save`, and load a non-Keras model with the low-level API, `tf.saved_model.load`. "
450+
"It is possible to mix and match the APIs. You can save a Keras model with `Model.save`, and load a non-Keras model with the low-level API, `tf.saved_model.load`. "
447451
]
448452
},
449453
{
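The API-selection guidance in the markdown cell of this hunk can be summarized as a small decision helper. The following is only a sketch in plain Python: the function names `choose_save_api` and `choose_load_api` and the `needs_low_level_control` flag are hypothetical (they are not part of TensorFlow or of this commit), and the helpers only return the name of the API that the tutorial text recommends. They never call TensorFlow.

```python
def choose_save_api(is_keras_model: bool,
                    needs_low_level_control: bool = False) -> str:
    """Return the name of the saving API suggested by the tutorial text.

    Hypothetical helper for illustration only.
    """
    # Keras models default to the high-level API unless extra control is needed.
    if is_keras_model and not needs_low_level_control:
        return 'tf.keras.Model.save'
    # Non-Keras objects can only be saved with the low-level API.
    return 'tf.saved_model.save'


def choose_load_api(saved_keras_model: bool, want_keras_model: bool) -> str:
    """Return the name of the loading API suggested by the tutorial text.

    Hypothetical helper for illustration only.
    """
    # A Keras model can be restored only if a Keras model was saved.
    if saved_keras_model and want_keras_model:
        return 'tf.keras.models.load_model'
    # Otherwise fall back to the low-level loader.
    return 'tf.saved_model.load'


print(choose_save_api(is_keras_model=True))
print(choose_save_api(is_keras_model=False))
print(choose_load_api(saved_keras_model=True, want_keras_model=False))
```

The last call corresponds to the "mix and match" case in the tutorial: a model saved with `Model.save` is loaded back as a non-Keras object with `tf.saved_model.load`.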
@@ -456,11 +460,11 @@
456460
"source": [
457461
"model = get_model()\n",
458462
"\n",
459-
"# Saving the model using Keras's save() API\n",
460-
"model.save(keras_model_path) \n",
463+
"# Saving the model using Keras `Model.save`\n",
464+
"model.save(keras_model_path)\n",
461465
"\n",
462466
"another_strategy = tf.distribute.MirroredStrategy()\n",
463-
"# Loading the model using lower level API\n",
467+
"# Loading the model using the lower-level API\n",
464468
"with another_strategy.scope():\n",
465469
" loaded = tf.saved_model.load(keras_model_path)"
466470
]
@@ -471,7 +475,7 @@
471475
"id": "0Z7lSj8nZiW5"
472476
},
473477
"source": [
474-
"### Saving/Loading from local device"
478+
"### Saving/Loading from a local device"
475479
]
476480
},
477481
{
@@ -480,7 +484,7 @@
480484
"id": "NVAjWcosZodw"
481485
},
482486
"source": [
483-
"When saving and loading from a local io device while running remotely, for example using a cloud TPU, the option `experimental_io_device` must be used to set the io device to localhost."
487+
"When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option `experimental_io_device` in `tf.saved_model.SaveOptions` and `tf.saved_model.LoadOptions` to set the I/O device to `localhost`. For example:"
484488
]
485489
},
486490
{
@@ -494,7 +498,7 @@
494498
"model = get_model()\n",
495499
"\n",
496500
"# Saving the model to a path on localhost.\n",
497-
"saved_model_path = \"/tmp/tf_save\"\n",
501+
"saved_model_path = '/tmp/tf_save'\n",
498502
"save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')\n",
499503
"model.save(saved_model_path, options=save_options)\n",
500504
"\n",
@@ -517,14 +521,10 @@
517521
{
518522
"cell_type": "markdown",
519523
"metadata": {
520-
"id": "Tzog2ti7YYgy"
524+
"id": "2cCSZrD7VCxe"
521525
},
522526
"source": [
523-
"A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (`Sequential([Dense(3), ...]`). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower level APIs on both saving and loading, otherwise you will get an error. \n",
524-
"\n",
525-
"To check if your model has well-defined inputs, just check if `model.inputs` is `None`. If it is not `None`, you are all good. Input shapes are automatically defined when the model is used in `.fit`, `.evaluate`, `.predict`, or when calling the model (`model(inputs)`). \n",
526-
"\n",
527-
"Here is an example:"
527+
"One special case is when you create Keras models in certain ways, and then save them before training. For example:"
528528
]
529529
},
530530
{
@@ -536,6 +536,7 @@
536536
"outputs": [],
537537
"source": [
538538
"class SubclassedModel(tf.keras.Model):\n",
539+
" \"\"\"Example model defined by subclassing `tf.keras.Model`.\"\"\"\n",
539540
"\n",
540541
" output_name = 'output_layer'\n",
541542
"\n",
@@ -548,8 +549,89 @@
548549
" return self._dense_layer(inputs)\n",
549550
"\n",
550551
"my_model = SubclassedModel()\n",
551-
"# my_model.save(keras_model_path) # ERROR! \n",
552-
"tf.saved_model.save(my_model, saved_model_path)"
552+
"try:\n",
553+
" my_model.save(keras_model_path)\n",
554+
"except ValueError as e:\n",
555+
" print(f'{type(e).__name__}: ', *e.args)"
556+
]
557+
},
558+
{
559+
"cell_type": "markdown",
560+
"metadata": {
561+
"id": "D4qMyXFDSPDO"
562+
},
563+
"source": [
564+
"A SavedModel saves the `tf.types.experimental.ConcreteFunction` objects generated when you trace a `tf.function` (check _When is a Function tracing?_ in the [Introduction to graphs and tf.function](../../guide/intro_to_graphs.ipynb) guide to learn more). If you get a `ValueError` like this, it's because `Model.save` was not able to find or create a traced `ConcreteFunction`.\n",
565+
"\n",
566+
"**Caution:** You shouldn't save a model without at least one `ConcreteFunction`, since the low-level API will otherwise generate a SavedModel with no `ConcreteFunction` signatures ([learn more](../../guide/saved_model.ipynb) about the SavedModel format). For example:"
567+
]
568+
},
569+
{
570+
"cell_type": "code",
571+
"execution_count": null,
572+
"metadata": {
573+
"id": "064SE47mYDj8"
574+
},
575+
"outputs": [],
576+
"source": [
577+
"tf.saved_model.save(my_model, saved_model_path)\n",
578+
"x = tf.saved_model.load(saved_model_path)\n",
579+
"x.signatures"
580+
]
581+
},
582+
{
583+
"cell_type": "markdown",
584+
"metadata": {
585+
"id": "LRTxlASJX-cY"
586+
},
587+
"source": [
588+
"\n",
589+
"Usually the model's forward pass—the `call` method—will be traced automatically when the model is called for the first time, often via the Keras `Model.fit` method. A `ConcreteFunction` can also be generated by the Keras [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) and [Functional](https://www.tensorflow.org/guide/keras/functional) APIs, if you set the input shape, for example, by making the first layer either a `tf.keras.layers.InputLayer` or another layer type, and passing it the `input_shape` keyword argument. \n",
590+
"\n",
591+
"To verify whether your model has any traced `ConcreteFunction`s, check whether `Model.save_spec()` returns `None`:"
592+
]
593+
},
594+
{
595+
"cell_type": "code",
596+
"execution_count": null,
597+
"metadata": {
598+
"id": "xAXise4eR0YJ"
599+
},
600+
"outputs": [],
601+
"source": [
602+
"print(my_model.save_spec() is None)"
603+
]
604+
},
605+
{
606+
"cell_type": "markdown",
607+
"metadata": {
608+
"id": "G2G_FQrWJAO3"
609+
},
610+
"source": [
611+
"Let's use `tf.keras.Model.fit` to train the model, and notice that the `save_spec` gets defined and model saving will work:"
612+
]
613+
},
614+
{
615+
"cell_type": "code",
616+
"execution_count": null,
617+
"metadata": {
618+
"id": "cv5LTi0zDkKS"
619+
},
620+
"outputs": [],
621+
"source": [
622+
"BATCH_SIZE_PER_REPLICA = 4\n",
623+
"BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n",
624+
"\n",
625+
"dataset_size = 100\n",
626+
"dataset = tf.data.Dataset.from_tensors(\n",
627+
" (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))\n",
628+
" ).repeat(dataset_size).batch(BATCH_SIZE)\n",
629+
"\n",
630+
"my_model.compile(optimizer='adam', loss='mean_squared_error')\n",
631+
"my_model.fit(dataset, epochs=2)\n",
632+
"\n",
633+
"print(my_model.save_spec() is None)\n",
634+
"my_model.save(keras_model_path)"
553635
]
554636
}
555637
],
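The batch-size arithmetic in the final added cell can be checked without TensorFlow. The sketch below assumes that `Dataset.batch` keeps a final partial batch (its default, `drop_remainder=False`) and that `repeat(dataset_size)` over a single-element dataset yields `dataset_size` examples. The helper names `global_batch_size` and `steps_per_epoch` are hypothetical and exist only for this illustration.

```python
import math


def global_batch_size(per_replica_batch_size: int, num_replicas: int) -> int:
    """Global batch size, as in BATCH_SIZE_PER_REPLICA * num_replicas_in_sync."""
    return per_replica_batch_size * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# 4 examples per replica and 100 examples in total, as in the added cell.
for replicas in (1, 2, 8):
    batch = global_batch_size(4, replicas)
    print(f'replicas={replicas} batch={batch} steps={steps_per_epoch(100, batch)}')
# replicas=1 batch=4 steps=25
# replicas=2 batch=8 steps=13
# replicas=8 batch=32 steps=4
```

The replica counts 1, 2, and 8 are arbitrary examples; the actual value of `mirrored_strategy.num_replicas_in_sync` depends on how many devices are visible on the machine running the notebook.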
