|
71 | 71 | "source": [
|
72 | 72 | "## Overview\n",
|
73 | 73 | "\n",
|
74 |
| - "It's common to save and load a model during training. There are two sets of APIs for saving and loading a keras model: a high-level API, and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using `tf.distribute.Strategy`. To learn about SavedModel and serialization in general, please read the [saved model guide](../../guide/saved_model.ipynb), and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example: " |
| 74 | + "This tutorial demonstrates how you can save and load models in a SavedModel format with `tf.distribute.Strategy` during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (`tf.keras.Model.save` and `tf.keras.models.load_model`) and low-level (`tf.saved_model.save` and `tf.saved_model.load`).\n", |
| 75 | + "\n", |
| 76 | + "To learn about SavedModel and serialization in general, please read the [saved model guide](../../guide/saved_model.ipynb), and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example: " |
75 | 77 | ]
|
76 | 78 | },
|
77 | 79 | {
|
|
102 | 104 | "id": "qqapWj98ptNV"
|
103 | 105 | },
|
104 | 106 | "source": [
|
105 |
| - "Prepare the data and model using `tf.distribute.Strategy`:" |
| 107 | + "Load and prepare the data with TensorFlow Datasets and `tf.data`, and create the model using `tf.distribute.MirroredStrategy`:" |
106 | 108 | ]
|
107 | 109 | },
|
108 | 110 | {
|
|
116 | 118 | "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
|
117 | 119 | "\n",
|
118 | 120 | "def get_data():\n",
|
119 |
| - " datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)\n", |
| 121 | + " datasets = tfds.load(name='mnist', as_supervised=True)\n", |
120 | 122 | " mnist_train, mnist_test = datasets['train'], datasets['test']\n",
|
121 | 123 | "\n",
|
122 | 124 | " BUFFER_SIZE = 10000\n",
|
|
157 | 159 | "id": "qmU4Y3feS9Na"
|
158 | 160 | },
|
159 | 161 | "source": [
|
160 |
| - "Train the model: " |
| 162 | + "Train the model with `tf.keras.Model.fit`: " |
161 | 163 | ]
|
162 | 164 | },
|
163 | 165 | {
|
|
181 | 183 | "source": [
|
182 | 184 | "## Save and load the model\n",
|
183 | 185 | "\n",
|
184 |
| - "Now that you have a simple model to work with, let's take a look at the saving/loading APIs. \n", |
185 |
| - "There are two sets of APIs available:\n", |
| 186 | + "Now that you have a simple model to work with, let's explore the saving/loading APIs. \n", |
| 187 | + "There are two kinds of APIs available:\n", |
186 | 188 | "\n",
|
187 |
| - "* High level keras `model.save` and `tf.keras.models.load_model`\n", |
188 |
| - "* Low level `tf.saved_model.save` and `tf.saved_model.load`\n" |
| 189 | + "* High-level (Keras): `Model.save` and `tf.keras.models.load_model`\n", |
| 190 | + "* Low-level: `tf.saved_model.save` and `tf.saved_model.load`\n" |
189 | 191 | ]
|
190 | 192 | },
|
191 | 193 | {
|
|
194 | 196 | "id": "FX_IF2F1tvFs"
|
195 | 197 | },
|
196 | 198 | "source": [
|
197 |
| - "### The Keras APIs" |
| 199 | + "### The Keras API" |
198 | 200 | ]
|
199 | 201 | },
|
200 | 202 | {
|
|
203 | 205 | "id": "O8xfceg4Z3H_"
|
204 | 206 | },
|
205 | 207 | "source": [
|
206 |
| - "Here is an example of saving and loading a model with the Keras APIs:" |
| 208 | + "Here is an example of saving and loading a model with the Keras API:" |
207 | 209 | ]
|
208 | 210 | },
|
209 | 211 | {
|
|
214 | 216 | },
|
215 | 217 | "outputs": [],
|
216 | 218 | "source": [
|
217 |
| - "keras_model_path = \"/tmp/keras_save\"\n", |
| 219 | + "keras_model_path = '/tmp/keras_save'\n", |
218 | 220 | "model.save(keras_model_path)"
|
219 | 221 | ]
|
220 | 222 | },
|
|
245 | 247 | "id": "gYAnskzorda-"
|
246 | 248 | },
|
247 | 249 | "source": [
|
248 |
| - "After restoring the model, you can continue training on it, even without needing to call `compile()` again, since it is already compiled before saving. The model is saved in the TensorFlow's standard `SavedModel` proto format. For more information, please refer to [the guide to `saved_model` format](../../guide/saved_model.ipynb).\n", |
| 250 | + "After restoring the model, you can continue training on it, even without needing to call `Model.compile` again, since it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, please refer to [the guide to `SavedModel` format](../../guide/saved_model.ipynb).\n", |
249 | 251 | "\n",
|
250 |
| - "Now to load the model and train it using a `tf.distribute.Strategy`:" |
| 252 | + "Now, restore the model and train it using a `tf.distribute.Strategy`:" |
251 | 253 | ]
|
252 | 254 | },
|
253 | 255 | {
|
|
258 | 260 | },
|
259 | 261 | "outputs": [],
|
260 | 262 | "source": [
|
261 |
| - "another_strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")\n", |
| 263 | + "another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')\n", |
262 | 264 | "with another_strategy.scope():\n",
|
263 | 265 | " restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)\n",
|
264 | 266 | " restored_keras_model_ds.fit(train_dataset, epochs=2)"
|
|
270 | 272 | "id": "PdiiPmL5tQk5"
|
271 | 273 | },
|
272 | 274 | "source": [
|
273 |
| - "As you can see, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. " |
| 275 | + "As the `Model.fit` output shows, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. " |
274 | 276 | ]
|
275 | 277 | },
|
276 | 278 | {
|
|
279 | 281 | "id": "3CrXIbmFt0f6"
|
280 | 282 | },
|
281 | 283 | "source": [
|
282 |
| - "### The `tf.saved_model` APIs" |
| 284 | + "### The `tf.saved_model` API" |
283 | 285 | ]
|
284 | 286 | },
|
285 | 287 | {
|
|
288 | 290 | "id": "HtGzPp6et4Em"
|
289 | 291 | },
|
290 | 292 | "source": [
|
291 |
| - "Now let's take a look at the lower level APIs. Saving the model is similar to the keras API:" |
| 293 | + "Saving the model with lower-level API is similar to the Keras API:" |
292 | 294 | ]
|
293 | 295 | },
|
294 | 296 | {
|
|
300 | 302 | "outputs": [],
|
301 | 303 | "source": [
|
302 | 304 | "model = get_model() # get a fresh model\n",
|
303 |
| - "saved_model_path = \"/tmp/tf_save\"\n", |
| 305 | + "saved_model_path = '/tmp/tf_save'\n", |
304 | 306 | "tf.saved_model.save(model, saved_model_path)"
|
305 | 307 | ]
|
306 | 308 | },
|
|
310 | 312 | "id": "q1QNRYcwuRll"
|
311 | 313 | },
|
312 | 314 | "source": [
|
313 |
| - "Loading can be done with `tf.saved_model.load()`. However, since it is an API that is on the lower level (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contain functions that can be used to do inference. For example:" |
| 315 | + "Loading can be done with `tf.saved_model.load`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contain functions that can be used to do inference. For example:" |
314 | 316 | ]
|
315 | 317 | },
|
316 | 318 | {
|
|
321 | 323 | },
|
322 | 324 | "outputs": [],
|
323 | 325 | "source": [
|
324 |
| - "DEFAULT_FUNCTION_KEY = \"serving_default\"\n", |
| 326 | + "DEFAULT_FUNCTION_KEY = 'serving_default'\n", |
325 | 327 | "loaded = tf.saved_model.load(saved_model_path)\n",
|
326 | 328 | "inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]"
|
327 | 329 | ]
|
|
332 | 334 | "id": "x65l7AaHUZCA"
|
333 | 335 | },
|
334 | 336 | "source": [
|
335 |
| - "The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` is the default key for the inference function with a saved Keras model. To do an inference with this function: " |
| 337 | + "The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` key is the default key for the inference function with a saved Keras model. To do inference with this function: " |
336 | 338 | ]
|
337 | 339 | },
|
338 | 340 | {
|
|
375 | 377 | "\n",
|
376 | 378 | " # Calling the function in a distributed manner\n",
|
377 | 379 | " for batch in dist_predict_dataset:\n",
|
378 |
| - " another_strategy.run(inference_func,args=(batch,))" |
| 380 | + " result = another_strategy.run(inference_func, args=(batch,))\n", |
| 381 | + " print(result)\n", |
| 382 | + " break" |
379 | 383 | ]
|
380 | 384 | },
|
381 | 385 | {
|
|
384 | 388 | "id": "hWGSukoyw3fF"
|
385 | 389 | },
|
386 | 390 | "source": [
|
387 |
| - "Calling the restored function is just a forward pass on the saved model (predict). What if yout want to continue training the loaded function? Or embed the loaded function into a bigger model? A common practice is to wrap this loaded object to a Keras layer to achieve this. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [hub.KerasLayer](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:" |
| 391 | + "Calling the restored function is just a forward pass on the saved model (`tf.keras.Model.predict`). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:" |
388 | 392 | ]
|
389 | 393 | },
|
390 | 394 | {
|
|
421 | 425 | "id": "Oe1z_OtSJlu2"
|
422 | 426 | },
|
423 | 427 | "source": [
|
424 |
| - "As you can see, `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load()` into a Keras layer that can be used to build another model. This is very useful for transfer learning. " |
| 428 | + "In the above example, Tensorflow Hub's `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load` into a Keras layer that is used to build another model. This is very useful for transfer learning. " |
425 | 429 | ]
|
426 | 430 | },
|
427 | 431 | {
|
|
439 | 443 | "id": "GC6GQ9HDLxD6"
|
440 | 444 | },
|
441 | 445 | "source": [
|
442 |
| - "For saving, if you are working with a keras model, it is almost always recommended to use the Keras's `model.save()` API. If what you are saving is not a Keras model, then the lower level API is your only choice. \n", |
| 446 | + "For saving, if you are working with a Keras model, use the Keras `Model.save` API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, `tf.saved_model.save`, is your only choice. \n", |
443 | 447 | "\n",
|
444 |
| - "For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load()`. Otherwise, use `tf.keras.models.load_model()`. Note that you can get a Keras model back only if you saved a Keras model. \n", |
| 448 | + "For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load`. Otherwise, use `tf.keras.models.load_model`. Note that you can get a Keras model back only if you saved a Keras model. \n", |
445 | 449 | "\n",
|
446 |
| - "It is possible to mix and match the APIs. You can save a Keras model with `model.save`, and load a non-Keras model with the low-level API, `tf.saved_model.load`. " |
| 450 | + "It is possible to mix and match the APIs. You can save a Keras model with `Model.save`, and load a non-Keras model with the low-level API, `tf.saved_model.load`. " |
447 | 451 | ]
|
448 | 452 | },
|
449 | 453 | {
|
|
456 | 460 | "source": [
|
457 | 461 | "model = get_model()\n",
|
458 | 462 | "\n",
|
459 |
| - "# Saving the model using Keras's save() API\n", |
460 |
| - "model.save(keras_model_path) \n", |
| 463 | + "# Saving the model using Keras `Model.save`\n", |
| 464 | + "model.save(keras_model_path)\n", |
461 | 465 | "\n",
|
462 | 466 | "another_strategy = tf.distribute.MirroredStrategy()\n",
|
463 |
| - "# Loading the model using lower level API\n", |
| 467 | + "# Loading the model using the lower-level API\n", |
464 | 468 | "with another_strategy.scope():\n",
|
465 | 469 | " loaded = tf.saved_model.load(keras_model_path)"
|
466 | 470 | ]
|
|
471 | 475 | "id": "0Z7lSj8nZiW5"
|
472 | 476 | },
|
473 | 477 | "source": [
|
474 |
| - "### Saving/Loading from local device" |
| 478 | + "### Saving/Loading from a local device" |
475 | 479 | ]
|
476 | 480 | },
|
477 | 481 | {
|
|
480 | 484 | "id": "NVAjWcosZodw"
|
481 | 485 | },
|
482 | 486 | "source": [
|
483 |
| - "When saving and loading from a local io device while running remotely, for example using a cloud TPU, the option `experimental_io_device` must be used to set the io device to localhost." |
| 487 | + "When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option `experimental_io_device` in `tf.saved_model.SaveOptions` and `tf.saved_model.LoadOptions` to set the I/O device to `localhost`. For example:" |
484 | 488 | ]
|
485 | 489 | },
|
486 | 490 | {
|
|
494 | 498 | "model = get_model()\n",
|
495 | 499 | "\n",
|
496 | 500 | "# Saving the model to a path on localhost.\n",
|
497 |
| - "saved_model_path = \"/tmp/tf_save\"\n", |
| 501 | + "saved_model_path = '/tmp/tf_save'\n", |
498 | 502 | "save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')\n",
|
499 | 503 | "model.save(saved_model_path, options=save_options)\n",
|
500 | 504 | "\n",
|
|
517 | 521 | {
|
518 | 522 | "cell_type": "markdown",
|
519 | 523 | "metadata": {
|
520 |
| - "id": "Tzog2ti7YYgy" |
| 524 | + "id": "2cCSZrD7VCxe" |
521 | 525 | },
|
522 | 526 | "source": [
|
523 |
| - "A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (`Sequential([Dense(3), ...]`). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower level APIs on both saving and loading, otherwise you will get an error. \n", |
524 |
| - "\n", |
525 |
| - "To check if your model has well-defined inputs, just check if `model.inputs` is `None`. If it is not `None`, you are all good. Input shapes are automatically defined when the model is used in `.fit`, `.evaluate`, `.predict`, or when calling the model (`model(inputs)`). \n", |
526 |
| - "\n", |
527 |
| - "Here is an example:" |
| 527 | + "One special case is when you create Keras models in certain ways, and then save them before training. For example:" |
528 | 528 | ]
|
529 | 529 | },
|
530 | 530 | {
|
|
536 | 536 | "outputs": [],
|
537 | 537 | "source": [
|
538 | 538 | "class SubclassedModel(tf.keras.Model):\n",
|
| 539 | + " \"\"\"Example model defined by subclassing `tf.keras.Model`.\"\"\"\n", |
539 | 540 | "\n",
|
540 | 541 | " output_name = 'output_layer'\n",
|
541 | 542 | "\n",
|
|
548 | 549 | " return self._dense_layer(inputs)\n",
|
549 | 550 | "\n",
|
550 | 551 | "my_model = SubclassedModel()\n",
|
551 |
| - "# my_model.save(keras_model_path) # ERROR! \n", |
552 |
| - "tf.saved_model.save(my_model, saved_model_path)" |
| 552 | + "try:\n", |
| 553 | + " my_model.save(keras_model_path)\n", |
| 554 | + "except ValueError as e:\n", |
| 555 | + " print(f'{type(e).__name__}: ', *e.args)" |
| 556 | + ] |
| 557 | + }, |
| 558 | + { |
| 559 | + "cell_type": "markdown", |
| 560 | + "metadata": { |
| 561 | + "id": "D4qMyXFDSPDO" |
| 562 | + }, |
| 563 | + "source": [ |
| 564 | + "A SavedModel saves the `tf.types.experimental.ConcreteFunction` objects generated when you trace a `tf.function` (check _When is a Function tracing?_ in the [Introduction to graphs and tf.function](../../guide/intro_to_graphs.ipynb) guide to learn more). If you get a `ValueError` like this it's because `Model.save` was not able to find or create a traced `ConcreteFunction`.\n", |
| 565 | + "\n", |
| 566 | + "**Caution:** You shouldn't save a model without at least one `ConcreteFunction`, since the low-level API will otherwise generate a SavedModel with no `ConcreteFunction` signatures ([learn more](../../guide/saved_model.ipynb) about the SavedModel format). For example:" |
| 567 | + ] |
| 568 | + }, |
| 569 | + { |
| 570 | + "cell_type": "code", |
| 571 | + "execution_count": null, |
| 572 | + "metadata": { |
| 573 | + "id": "064SE47mYDj8" |
| 574 | + }, |
| 575 | + "outputs": [], |
| 576 | + "source": [ |
| 577 | + "tf.saved_model.save(my_model, saved_model_path)\n", |
| 578 | + "x = tf.saved_model.load(saved_model_path)\n", |
| 579 | + "x.signatures" |
| 580 | + ] |
| 581 | + }, |
| 582 | + { |
| 583 | + "cell_type": "markdown", |
| 584 | + "metadata": { |
| 585 | + "id": "LRTxlASJX-cY" |
| 586 | + }, |
| 587 | + "source": [ |
| 588 | + "\n", |
| 589 | + "Usually the model's forward pass—the `call` method—will be traced automatically when the model is called for the first time, often via the Keras `Model.fit` method. A `ConcreteFunction` can also be generated by the Keras [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) and [Functional](https://www.tensorflow.org/guide/keras/functional) APIs, if you set the input shape, for example, by making the first layer either a `tf.keras.layers.InputLayer` or another layer type, and passing it the `input_shape` keyword argument. \n", |
| 590 | + "\n", |
| 591 | + "To verify if your model has any traced `ConcreteFunction`s, check if `Model.save_spec` is `None`:" |
| 592 | + ] |
| 593 | + }, |
| 594 | + { |
| 595 | + "cell_type": "code", |
| 596 | + "execution_count": null, |
| 597 | + "metadata": { |
| 598 | + "id": "xAXise4eR0YJ" |
| 599 | + }, |
| 600 | + "outputs": [], |
| 601 | + "source": [ |
| 602 | + "print(my_model.save_spec() is None)" |
| 603 | + ] |
| 604 | + }, |
| 605 | + { |
| 606 | + "cell_type": "markdown", |
| 607 | + "metadata": { |
| 608 | + "id": "G2G_FQrWJAO3" |
| 609 | + }, |
| 610 | + "source": [ |
| 611 | + "Let's use `tf.keras.Model.fit` to train the model, and notice that the `save_spec` gets defined and model saving will work:" |
| 612 | + ] |
| 613 | + }, |
| 614 | + { |
| 615 | + "cell_type": "code", |
| 616 | + "execution_count": null, |
| 617 | + "metadata": { |
| 618 | + "id": "cv5LTi0zDkKS" |
| 619 | + }, |
| 620 | + "outputs": [], |
| 621 | + "source": [ |
| 622 | + "BATCH_SIZE_PER_REPLICA = 4\n", |
| 623 | + "BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n", |
| 624 | + "\n", |
| 625 | + "dataset_size = 100\n", |
| 626 | + "dataset = tf.data.Dataset.from_tensors(\n", |
| 627 | + " (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))\n", |
| 628 | + " ).repeat(dataset_size).batch(BATCH_SIZE)\n", |
| 629 | + "\n", |
| 630 | + "my_model.compile(optimizer='adam', loss='mean_squared_error')\n", |
| 631 | + "my_model.fit(dataset, epochs=2)\n", |
| 632 | + "\n", |
| 633 | + "print(my_model.save_spec() is None)\n", |
| 634 | + "my_model.save(keras_model_path)" |
553 | 635 | ]
|
554 | 636 | }
|
555 | 637 | ],
|
|
0 commit comments