|
194 | 194 | "id": "fLW6D2TzvC-4"
|
195 | 195 | },
|
196 | 196 | "source": [
|
197 |
| - "Next, create an `mnist.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:" |
| 197 | + "Next, create an `mnist_setup.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:" |
198 | 198 | ]
|
199 | 199 | },
|
200 | 200 | {
|
|
205 | 205 | },
|
206 | 206 | "outputs": [],
|
207 | 207 | "source": [
|
208 |
| - "%%writefile mnist.py\n", |
| 208 | + "%%writefile mnist_setup.py\n", |
209 | 209 | "\n",
|
210 | 210 | "import os\n",
|
211 | 211 | "import tensorflow as tf\n",
|
|
256 | 256 | },
|
257 | 257 | "outputs": [],
|
258 | 258 | "source": [
|
259 |
| - "import mnist\n", |
| 259 | + "import mnist_setup\n", |
260 | 260 | "\n",
|
261 | 261 | "batch_size = 64\n",
|
262 |
| - "single_worker_dataset = mnist.mnist_dataset(batch_size)\n", |
263 |
| - "single_worker_model = mnist.build_and_compile_cnn_model()\n", |
| 262 | + "single_worker_dataset = mnist_setup.mnist_dataset(batch_size)\n", |
| 263 | + "single_worker_model = mnist_setup.build_and_compile_cnn_model()\n", |
264 | 264 | "single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)"
|
265 | 265 | ]
|
266 | 266 | },
|
|
492 | 492 | "source": [
|
493 | 493 | "with strategy.scope():\n",
|
494 | 494 | " # Model building/compiling need to be within `strategy.scope()`.\n",
|
495 |
| - " multi_worker_model = mnist.build_and_compile_cnn_model()" |
| 495 | + " multi_worker_model = mnist_setup.build_and_compile_cnn_model()" |
496 | 496 | ]
|
497 | 497 | },
|
498 | 498 | {
|
|
512 | 512 | "source": [
|
513 | 513 | "To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n",
|
514 | 514 | "\n",
|
515 |
| - "Like the `mnist.py` file written earlier, here is the `main.py` that each of the workers will run:" |
| 515 | + "Like the `mnist_setup.py` file written earlier, here is the `main.py` that each of the workers will run:" |
516 | 516 | ]
|
517 | 517 | },
|
518 | 518 | {
|
|
529 | 529 | "import json\n",
|
530 | 530 | "\n",
|
531 | 531 | "import tensorflow as tf\n",
|
532 |
| - "import mnist\n", |
| 532 | + "import mnist_setup\n", |
533 | 533 | "\n",
|
534 | 534 | "per_worker_batch_size = 64\n",
|
535 | 535 | "tf_config = json.loads(os.environ['TF_CONFIG'])\n",
|
|
538 | 538 | "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
|
539 | 539 | "\n",
|
540 | 540 | "global_batch_size = per_worker_batch_size * num_workers\n",
|
541 |
| - "multi_worker_dataset = mnist.mnist_dataset(global_batch_size)\n", |
| 541 | + "multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)\n", |
542 | 542 | "\n",
|
543 | 543 | "with strategy.scope():\n",
|
544 | 544 | " # Model building/compiling need to be within `strategy.scope()`.\n",
|
545 |
| - " multi_worker_model = mnist.build_and_compile_cnn_model()\n", |
| 545 | + " multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", |
546 | 546 | "\n",
|
547 | 547 | "\n",
|
548 | 548 | "multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)"
|
|
820 | 820 | "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n",
|
821 | 821 | "\n",
|
822 | 822 | "global_batch_size = 64\n",
|
823 |
| - "multi_worker_dataset = mnist.mnist_dataset(batch_size=64)\n", |
| 823 | + "multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=64)\n", |
824 | 824 | "dataset_no_auto_shard = multi_worker_dataset.with_options(options)"
|
825 | 825 | ]
|
826 | 826 | },
|
|
1146 | 1146 | "\n",
|
1147 | 1147 | "callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]\n",
|
1148 | 1148 | "with strategy.scope():\n",
|
1149 |
| - " multi_worker_model = mnist.build_and_compile_cnn_model()\n", |
| 1149 | + " multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", |
1150 | 1150 | "multi_worker_model.fit(multi_worker_dataset,\n",
|
1151 | 1151 | " epochs=3,\n",
|
1152 | 1152 | " steps_per_epoch=70,\n",
|
|
0 commit comments