|
697 | 697 | "source": [
|
698 | 698 | "### Checkpointing\n",
|
699 | 699 | "\n",
|
700 |
| - "You can checkpoint a DTensor model using `dtensor.DTensorCheckpoint`. The format of a DTensor checkpoint is fully compatible with a Standard TensorFlow Checkpoint. There is ongoing work to consolidate `dtensor.DTensorCheckpoint` into `tf.train.Checkpoint`.\n", |
| 700 | + "You can checkpoint a DTensor model using `tf.train.Checkpoint` out of the box. Saving and restoring sharded DVariables will perform an efficient sharded save and restore. Currently, when using `tf.train.Checkpoint.save` and `tf.train.Checkpoint.restore`, all DVariables must be on the same host mesh, and DVariables and regular variables cannot be saved together. You can learn more about checkpointing in [this guide](../../guide/checkpoint.ipynb).\n", |
701 | 701 | "\n",
|
702 |
| - "When a DTensor checkpoint is restored, `Layout`s of variables can be different from when the checkpoint is saved. This tutorial makes use of this feature to continue the training in the Model Parallel training and Spatial Parallel training sections.\n" |
| 702 | + "When a DTensor checkpoint is restored, `Layout`s of variables can be different from when the checkpoint is saved. That is, saving DTensor models is layout- and mesh-agnostic, and only affects the efficiency of sharded saving. You can save a DTensor model with one mesh and layout and restore it on a different mesh and layout. This tutorial makes use of this feature to continue the training in the Model Parallel training and Spatial Parallel training sections.\n" |
703 | 703 | ]
|
704 | 704 | },
|
705 | 705 | {
|
|
712 | 712 | "source": [
|
713 | 713 | "CHECKPOINT_DIR = tempfile.mkdtemp()\n",
|
714 | 714 | "\n",
|
715 |
| - "def start_checkpoint_manager(mesh, model):\n", |
716 |
| - " ckpt = dtensor.DTensorCheckpoint(mesh, root=model)\n", |
| 715 | + "def start_checkpoint_manager(model):\n", |
| 716 | + " ckpt = tf.train.Checkpoint(root=model)\n", |
717 | 717 | " manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=3)\n",
|
718 | 718 | "\n",
|
719 | 719 | " if manager.latest_checkpoint:\n",
|
720 | 720 | " print(\"Restoring a checkpoint\")\n",
|
721 | 721 | " ckpt.restore(manager.latest_checkpoint).assert_consumed()\n",
|
722 | 722 | " else:\n",
|
723 |
| - " print(\"new training\")\n", |
| 723 | + " print(\"New training\")\n", |
724 | 724 | " return manager\n"
|
725 | 725 | ]
|
726 | 726 | },
|
|
746 | 746 | "outputs": [],
|
747 | 747 | "source": [
|
748 | 748 | "num_epochs = 2\n",
|
749 |
| - "manager = start_checkpoint_manager(mesh, model)\n", |
| 749 | + "manager = start_checkpoint_manager(model)\n", |
750 | 750 | "\n",
|
751 | 751 | "for epoch in range(num_epochs):\n",
|
752 | 752 | " step = 0\n",
|
|
839 | 839 | "outputs": [],
|
840 | 840 | "source": [
|
841 | 841 | "num_epochs = 2\n",
|
842 |
| - "manager = start_checkpoint_manager(mesh, model)\n", |
| 842 | + "manager = start_checkpoint_manager(model)\n", |
843 | 843 | "\n",
|
844 | 844 | "for epoch in range(num_epochs):\n",
|
845 | 845 | " step = 0\n",
|
|
932 | 932 | "source": [
|
933 | 933 | "num_epochs = 2\n",
|
934 | 934 | "\n",
|
935 |
| - "manager = start_checkpoint_manager(mesh, model)\n", |
| 935 | + "manager = start_checkpoint_manager(model)\n", |
936 | 936 | "for epoch in range(num_epochs):\n",
|
937 | 937 | " step = 0\n",
|
938 | 938 | " metrics = {'epoch': epoch}\n",
|
|
956 | 956 | "source": [
|
957 | 957 | "## SavedModel and DTensor\n",
|
958 | 958 | "\n",
|
959 |
| - "The integration of DTensor and SavedModel is still under development. This section only describes the current status quo for TensorFlow 2.9.0.\n", |
| 959 | + "The integration of DTensor and SavedModel is still under development. \n", |
960 | 960 | "\n",
|
961 |
| - "As of TensorFlow 2.9.0, `tf.saved_model` only accepts DTensor models with fully replicated variables.\n", |
962 |
| - "\n", |
963 |
| - "As a workaround, you can convert a DTensor model to a fully replicated one by reloading a checkpoint. However, after a model is saved, all DTensor annotations are lost and the saved signatures can only be used with regular Tensors, not DTensors." |
| 961 | + "As of TensorFlow `2.11`, `tf.saved_model` can save sharded and replicated DTensor models, and saving will do an efficient sharded save on different devices of the mesh. However, after a model is saved, all DTensor annotations are lost and the saved signatures can only be used with regular Tensors, not DTensors." |
964 | 962 | ]
|
965 | 963 | },
|
966 | 964 | {
|
|
975 | 973 | "mlp = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh), \n",
|
976 | 974 | " dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)])\n",
|
977 | 975 | "\n",
|
978 |
| - "manager = start_checkpoint_manager(mesh, mlp)\n", |
| 976 | + "manager = start_checkpoint_manager(mlp)\n", |
979 | 977 | "\n",
|
980 | 978 | "model_for_saving = tf.keras.Sequential([\n",
|
981 | 979 | " text_vectorization,\n",
|
|
0 commit comments