Skip to content

Commit 2d9aa29

Browse files
author
Obliman
authored
Clarify multi_worker_with_keras.ipynb example code (#1)
Suggesting that the "mnist.py" file be renamed to something else like "mnist_setup.py" since the resulting "import mnist" is confusingly in conflict with the mnist package. This threw me off until I re-read the instructions, since I had placed the code in one file which resulted in errors executing mnist.mnist_dataset and mnist.build_and_compile_cnn_model.
1 parent 83f5d3c commit 2d9aa29

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

site/en/tutorials/distribute/multi_worker_with_keras.ipynb

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@
194194
"id": "fLW6D2TzvC-4"
195195
},
196196
"source": [
197-
"Next, create an `mnist.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:"
197+
"Next, create an `mnist_setup.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:"
198198
]
199199
},
200200
{
@@ -205,7 +205,7 @@
205205
},
206206
"outputs": [],
207207
"source": [
208-
"%%writefile mnist.py\n",
208+
"%%writefile mnist_setup.py\n",
209209
"\n",
210210
"import os\n",
211211
"import tensorflow as tf\n",
@@ -256,11 +256,11 @@
256256
},
257257
"outputs": [],
258258
"source": [
259-
"import mnist\n",
259+
"import mnist_setup\n",
260260
"\n",
261261
"batch_size = 64\n",
262-
"single_worker_dataset = mnist.mnist_dataset(batch_size)\n",
263-
"single_worker_model = mnist.build_and_compile_cnn_model()\n",
262+
"single_worker_dataset = mnist_setup.mnist_dataset(batch_size)\n",
263+
"single_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
264264
"single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)"
265265
]
266266
},
@@ -492,7 +492,7 @@
492492
"source": [
493493
"with strategy.scope():\n",
494494
" # Model building/compiling need to be within `strategy.scope()`.\n",
495-
" multi_worker_model = mnist.build_and_compile_cnn_model()"
495+
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()"
496496
]
497497
},
498498
{
@@ -512,7 +512,7 @@
512512
"source": [
513513
"To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n",
514514
"\n",
515-
"Like the `mnist.py` file written earlier, here is the `main.py` that each of the workers will run:"
515+
"Like the `mnist_setup.py` file written earlier, here is the `main.py` that each of the workers will run:"
516516
]
517517
},
518518
{
@@ -529,7 +529,7 @@
529529
"import json\n",
530530
"\n",
531531
"import tensorflow as tf\n",
532-
"import mnist\n",
532+
"import mnist_setup\n",
533533
"\n",
534534
"per_worker_batch_size = 64\n",
535535
"tf_config = json.loads(os.environ['TF_CONFIG'])\n",
@@ -538,11 +538,11 @@
538538
"strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
539539
"\n",
540540
"global_batch_size = per_worker_batch_size * num_workers\n",
541-
"multi_worker_dataset = mnist.mnist_dataset(global_batch_size)\n",
541+
"multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)\n",
542542
"\n",
543543
"with strategy.scope():\n",
544544
" # Model building/compiling need to be within `strategy.scope()`.\n",
545-
" multi_worker_model = mnist.build_and_compile_cnn_model()\n",
545+
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
546546
"\n",
547547
"\n",
548548
"multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)"
@@ -820,7 +820,7 @@
820820
"options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n",
821821
"\n",
822822
"global_batch_size = 64\n",
823-
"multi_worker_dataset = mnist.mnist_dataset(batch_size=64)\n",
823+
"multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=64)\n",
824824
"dataset_no_auto_shard = multi_worker_dataset.with_options(options)"
825825
]
826826
},
@@ -1146,7 +1146,7 @@
11461146
"\n",
11471147
"callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]\n",
11481148
"with strategy.scope():\n",
1149-
" multi_worker_model = mnist.build_and_compile_cnn_model()\n",
1149+
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
11501150
"multi_worker_model.fit(multi_worker_dataset,\n",
11511151
" epochs=3,\n",
11521152
" steps_per_epoch=70,\n",

0 commit comments

Comments
 (0)