Skip to content

Commit 5ac694e

Browse files
g-luoTensorflow Cloud maintainers
authored andcommitted
Update documentation for Distributed Training TensorFlow Cloud tutorial
PiperOrigin-RevId: 375808742
1 parent f88e880 commit 5ac694e

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

g3doc/tutorials/distributed_training_nasnet_with_tensorflow_cloud.ipynb

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282
"cell_type": "code",
8383
"execution_count": null,
8484
"metadata": {
85-
"id": "Kw7YV8nN0gEb"
85+
"id": "Kw7YV8nN0gEb",
86+
"outputId": "7eed43ea-d108-4aa0-90b9-f52e7d7d8eab"
8687
},
8788
"outputs": [
8889
{
@@ -110,7 +111,8 @@
110111
"cell_type": "code",
111112
"execution_count": null,
112113
"metadata": {
113-
"id": "kA1D8jB3TviQ"
114+
"id": "kA1D8jB3TviQ",
115+
"outputId": "27c3a25f-ee92-47a1-edab-5e087e749d14"
114116
},
115117
"outputs": [
116118
{
@@ -136,7 +138,8 @@
136138
},
137139
"outputs": [],
138140
"source": [
139-
"import sys"
141+
"import sys\n",
142+
"import os"
140143
]
141144
},
142145
{
@@ -307,12 +310,12 @@
307310
" # Freeze the pretrained weights\n",
308311
" model.trainable = False\n",
309312
"\n",
310-
" # We unfreeze the top 20 layers while leaving BatchNorm layers frozen\n",
313+
" # Unfreeze the top 20 layers while leaving BatchNorm layers frozen\n",
311314
" for layer in model.layers[-20:]:\n",
312315
" if not isinstance(layer, layers.BatchNormalization):\n",
313316
" layer.trainable = True\n",
314317
"\n",
315-
" # Rebuild top\n",
318+
" # Rebuild non-frozen top layers of NASNetMobile, which was initialized with include_top=False\n",
316319
" x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(model.output)\n",
317320
" x = layers.BatchNormalization()(x)\n",
318321
"\n",
@@ -341,6 +344,7 @@
341344
"source": [
342345
"model = build_model(NUM_CLASSES, INPUT_IMG_SIZE)\n",
343346
"\n",
347+
"# Set up job on TensorFlow Cloud to train for 100 epochs on all the data\n",
344348
"if tfc.remote():\n",
345349
" # Configure Tensorboard logs\n",
346350
" callbacks=[\n",
@@ -359,7 +363,7 @@
359363
" model.save(SAVED_MODEL_DIR)\n",
360364
"\n",
361365
"else:\n",
362-
" # Run the training for 1 epoch and a small subset of the data to validate setup\n",
366+
" # Run the training locally for 1 epoch and a small subset of the data to validate setup\n",
363367
" model.fit(x=x_train[:100], y=y_train[:100], validation_split=0.2, epochs=1)"
364368
]
365369
},
@@ -382,8 +386,7 @@
382386
},
383387
"outputs": [],
384388
"source": [
385-
"# If you are using a custom image you can install modules via requirements\n",
386-
"# txt file.\n",
389+
"# You can install custom modules for your Docker images via requirements txt file.\n",
387390
"with open('requirements.txt','w') as f:\n",
388391
" f.write('tensorflow-cloud\\n')\n",
389392
"\n",
@@ -465,6 +468,7 @@
465468
"cIG5d4Kvls6m"
466469
],
467470
"name": "distributed_training_nasnet_with_tensorflow_cloud.ipynb",
471+
"provenance": [],
468472
"toc_visible": true
469473
},
470474
"kernelspec": {

g3doc/tutorials/hp_tuning_cifar10_using_google_cloud.ipynb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@
8383
"cell_type": "code",
8484
"execution_count": null,
8585
"metadata": {
86-
"id": "a751fAMsyu0q"
86+
"id": "a751fAMsyu0q",
87+
"outputId": "eb50e5e8-51ff-484f-c863-fba3ca7f6f49"
8788
},
8889
"outputs": [
8990
{
@@ -117,7 +118,8 @@
117118
"cell_type": "code",
118119
"execution_count": null,
119120
"metadata": {
120-
"id": "BnUf9-AA_pZ_"
121+
"id": "BnUf9-AA_pZ_",
122+
"outputId": "1bb6e925-0853-443c-9b8c-9ae90bae8b69"
121123
},
122124
"outputs": [
123125
{
@@ -406,7 +408,7 @@
406408
},
407409
"outputs": [],
408410
"source": [
409-
"# If you are using a custom image you can install modules via requirements txt file.\n",
411+
"# You can install custom modules for your Docker images via requirements txt file.\n",
410412
"with open('requirements.txt','w') as f:\n",
411413
" f.write('pandas==1.1.5\\n')\n",
412414
" f.write('numpy==1.18.5\\n')\n",
@@ -496,6 +498,7 @@
496498
"2ej8rnJkmoA8"
497499
],
498500
"name": "hp_tuning_cifar10_using_google_cloud.ipynb",
501+
"provenance": [],
499502
"toc_visible": true
500503
},
501504
"kernelspec": {

0 commit comments

Comments
 (0)