Commit bafe829

SinaChavoshi authored and Tensorflow Cloud maintainers committed
Add sample for distributed training with MWMS.
PiperOrigin-RevId: 356563430
1 parent 3c76270 commit bafe829

File tree

4 files changed: +382 -6 lines changed
src/python/tensorflow_cloud/examples/distributed_training_nasnet_with_tensorflow_cloud.ipynb

Lines changed: 376 additions & 0 deletions
@@ -0,0 +1,376 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PRgkg_PlTviL"
   },
   "source": [
    "# Distributed training of NasNet with tensorflow_cloud and Google Cloud\n",
    "\n",
    "This example is based on [Image classification via fine-tuning with EfficientNet](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/) and demonstrates how to train a [NasNetMobile](https://keras.io/api/applications/nasnet/#nasnetmobile-function) model at scale with distributed training, using [tensorflow_cloud](https://github.com/tensorflow/cloud) and Google Cloud Platform.\n",
    "\n",
    "\u003ctable align=\"left\"\u003e\n",
    " \u003ctd\u003e\n",
    "  \u003ca href=\"https://colab.research.google.com/github/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/examples/distributed_training_nasnet_with_tensorflow_cloud.ipynb\"\u003e\n",
    "   \u003cimg width=\"50\" src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"\u003eRun in Colab\n",
    "  \u003c/a\u003e\n",
    " \u003c/td\u003e\n",
    " \u003ctd\u003e\n",
    "  \u003ca href=\"https://github.com/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/examples/distributed_training_nasnet_with_tensorflow_cloud.ipynb\"\u003e\n",
    "   \u003cimg src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"\u003eView on GitHub\n",
    "  \u003c/a\u003e\n",
    " \u003c/td\u003e\n",
    " \u003ctd\u003e\n",
    "  \u003ca href=\"https://www.kaggle.com/nitric/distributed-training-nasnet-with-tensorflow-cloud\"\u003e\n",
    "   \u003cimg width=\"90\" src=\"https://www.kaggle.com/static/images/site-logo.png\" alt=\"Kaggle logo\"\u003eRun in Kaggle\n",
    "  \u003c/a\u003e\n",
    " \u003c/td\u003e\n",
    "\u003c/table\u003e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kA1D8jB3TviQ",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "import tensorflow_cloud as tfc"
   ]
  },
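  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the original example): print the\n",
    "# TensorFlow and tensorflow_cloud versions available in this runtime. This\n",
    "# can help when debugging mismatches between the notebook and the remote job.\n",
    "print('TensorFlow version:', tf.__version__)\n",
    "print('tensorflow_cloud version:', getattr(tfc, '__version__', 'unknown'))"
   ]
  },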
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vdLn2dl2TviR"
   },
   "source": [
    "Set project parameters. For Google Cloud-specific parameters, refer to [Google Cloud Project Setup Instructions](https://www.kaggle.com/nitric/google-cloud-project-setup-instructions/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b2Ev1lz-TviR",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Set Google Cloud-specific parameters\n",
    "\n",
    "# TODO: Please set GCP_PROJECT_ID to your own Google Cloud project ID.\n",
    "GCP_PROJECT_ID = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
    "\n",
    "# TODO: Set GCS_BUCKET to your own Google Cloud Storage (GCS) bucket.\n",
    "GCS_BUCKET = 'YOUR_GCS_BUCKET_NAME' #@param {type:\"string\"}\n",
    "\n",
    "# DO NOT CHANGE: Currently only the 'us-central1' region is supported.\n",
    "REGION = 'us-central1'\n",
    "\n",
    "# OPTIONAL: You can change the job name to any string.\n",
    "JOB_NAME = 'nasnet' #@param {type:\"string\"}\n",
    "\n",
    "# Setting the location where training logs and checkpoints will be stored\n",
    "GCS_BASE_PATH = f'gs://{GCS_BUCKET}/{JOB_NAME}'\n",
    "TENSORBOARD_LOGS_DIR = os.path.join(GCS_BASE_PATH, \"logs\")\n",
    "MODEL_CHECKPOINT_DIR = os.path.join(GCS_BASE_PATH, \"checkpoints\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KQ4B0XjaTviR"
   },
   "source": [
    "## Authenticating the notebook to use your Google Cloud Project\n",
    "\n",
    "For Kaggle Notebooks, click on \"Add-ons\" -\u003e \"Google Cloud SDK\" before running the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vZp9qc3STviS",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Using tfc.remote() to ensure this code only runs in the notebook\n",
    "if not tfc.remote():\n",
    "\n",
    "    # Authentication for Kaggle Notebooks\n",
    "    if \"kaggle_secrets\" in sys.modules:\n",
    "        from kaggle_secrets import UserSecretsClient\n",
    "        UserSecretsClient().set_gcloud_credentials(project=GCP_PROJECT_ID)\n",
    "\n",
    "    # Authentication for Colab Notebooks\n",
    "    if \"google.colab\" in sys.modules:\n",
    "        from google.colab import auth\n",
    "        auth.authenticate_user()"
   ]
  },
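  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not part of the original example): verify that the GCS\n",
    "# bucket set above is reachable before submitting a remote job. This assumes\n",
    "# the google-cloud-storage client library is available in the notebook\n",
    "# runtime (it is preinstalled on Colab and Kaggle).\n",
    "if not tfc.remote():\n",
    "    try:\n",
    "        from google.cloud import storage\n",
    "        bucket = storage.Client(project=GCP_PROJECT_ID).lookup_bucket(GCS_BUCKET)\n",
    "        if bucket is None:\n",
    "            print(f'Bucket gs://{GCS_BUCKET} was not found; please create it first.')\n",
    "        else:\n",
    "            print(f'Bucket gs://{GCS_BUCKET} is accessible.')\n",
    "    except Exception as e:  # Best-effort check only\n",
    "        print('Could not verify the bucket:', e)"
   ]
  },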
119+
{
120+
"cell_type": "markdown",
121+
"metadata": {
122+
"id": "4Jix595FTviS"
123+
},
124+
"source": [
125+
"## Load and prepare data\n",
126+
"Read raw data and split to train and test data sets."
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"metadata": {
133+
"id": "5xEWEh2fTviS",
134+
"trusted": true
135+
},
136+
"outputs": [],
137+
"source": [
138+
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
139+
"\n",
140+
"# Setting input specific parameters\n",
141+
"# The model expects input of dimetions of (INPUT_IMG_SIZE, INPUT_IMG_SIZE, 3)\n",
142+
"INPUT_IMG_SIZE = 32\n",
143+
"NUM_CLASSES = 10"
144+
]
145+
},
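  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the original example): confirm the\n",
    "# CIFAR-10 arrays have the expected shapes and label range before training.\n",
    "print('x_train shape:', x_train.shape)  # expected (50000, 32, 32, 3)\n",
    "print('y_train shape:', y_train.shape)  # expected (50000, 1)\n",
    "print('x_test shape:', x_test.shape)    # expected (10000, 32, 32, 3)\n",
    "print('labels range from', y_train.min(), 'to', y_train.max())"
   ]
  },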
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "69fNjNqWTviT"
   },
   "source": [
    "Add preprocessing layers for image augmentation using the Keras preprocessing layers API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kstHXHtoTviT",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers.experimental import preprocessing\n",
    "from tensorflow.keras.models import Sequential\n",
    "\n",
    "\n",
    "img_augmentation = Sequential(\n",
    "    [\n",
    "        # Resizing input to better match ImageNet size\n",
    "        preprocessing.Resizing(256, 256),\n",
    "        preprocessing.RandomRotation(factor=0.15),\n",
    "        preprocessing.RandomFlip(),\n",
    "        preprocessing.RandomContrast(factor=0.1),\n",
    "    ],\n",
    "    name=\"img_augmentation\",\n",
    ")"
   ]
  },
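  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not part of the original example): preview what the\n",
    "# augmentation pipeline produces for a single training image. matplotlib is\n",
    "# assumed to be available in the notebook runtime.\n",
    "if not tfc.remote():\n",
    "    import numpy as np\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    sample = tf.expand_dims(tf.cast(x_train[0], tf.float32), axis=0)\n",
    "    augmented = img_augmentation(sample, training=True)\n",
    "    plt.imshow(np.clip(augmented[0].numpy(), 0, 255).astype('uint8'))\n",
    "    plt.title('Augmented sample')\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },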
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QkYgwEBgTviU"
   },
   "source": [
    "## Load the model and prepare for training\n",
    "We will load a pretrained NASNetMobile model (with ImageNet weights) and unfreeze a few layers to fine-tune the model to better match the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NhL5g2YoTviU",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers\n",
    "\n",
    "def build_model(num_classes, input_image_size):\n",
    "    inputs = layers.Input(shape=(input_image_size, input_image_size, 3))\n",
    "    x = img_augmentation(inputs)\n",
    "\n",
    "    model = tf.keras.applications.NASNetMobile(\n",
    "        input_shape=None,\n",
    "        include_top=False,\n",
    "        weights=\"imagenet\",\n",
    "        input_tensor=x,\n",
    "        pooling=None,\n",
    "        classes=num_classes,\n",
    "    )\n",
    "\n",
    "    # Freeze the pretrained weights\n",
    "    model.trainable = False\n",
    "\n",
    "    # We unfreeze the top 20 layers while leaving BatchNorm layers frozen\n",
    "    for layer in model.layers[-20:]:\n",
    "        if not isinstance(layer, layers.BatchNormalization):\n",
    "            layer.trainable = True\n",
    "\n",
    "    # Rebuild top\n",
    "    x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(model.output)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "\n",
    "    x = layers.Dense(128, activation=\"relu\")(x)\n",
    "    x = layers.Dense(64, activation=\"relu\")(x)\n",
    "    outputs = layers.Dense(num_classes, activation=\"softmax\", name=\"pred\")(x)\n",
    "\n",
    "    # Compile\n",
    "    model = tf.keras.Model(inputs, outputs, name=\"NASNetMobile\")\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)\n",
    "    model.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=\"sparse_categorical_crossentropy\",\n",
    "        metrics=[\"accuracy\"]\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jQwNarnJTviU",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "model = build_model(NUM_CLASSES, INPUT_IMG_SIZE)\n",
    "\n",
    "if tfc.remote():\n",
    "    # Configure TensorBoard logs\n",
    "    callbacks = [\n",
    "        tf.keras.callbacks.TensorBoard(log_dir=TENSORBOARD_LOGS_DIR),\n",
    "        tf.keras.callbacks.ModelCheckpoint(\n",
    "            MODEL_CHECKPOINT_DIR,\n",
    "            save_best_only=True),\n",
    "        tf.keras.callbacks.EarlyStopping(\n",
    "            monitor='loss',\n",
    "            min_delta=0.001,\n",
    "            patience=3)]\n",
    "\n",
    "    model.fit(x=x_train, y=y_train, epochs=100,\n",
    "              validation_split=0.2, callbacks=callbacks)\n",
    "\n",
    "else:\n",
    "    # Run the training for 1 epoch and a small subset of the data to validate setup\n",
    "    model.fit(x=x_train[:100], y=y_train[:100], validation_split=0.2, epochs=1)"
   ]
  },
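  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note (an illustrative sketch, not part of the original example): with\n",
    "# distribution_strategy='auto' and worker_count \u003e 0 in tfc.run() below, the\n",
    "# remote job is expected to train with MultiWorkerMirroredStrategy, so the\n",
    "# batch size passed to model.fit() is the global batch size split across all\n",
    "# workers. The numbers below are assumptions used only for illustration.\n",
    "PER_WORKER_BATCH_SIZE = 32  # assumed per-worker batch size\n",
    "NUM_WORKERS = 4             # 1 chief + 3 workers (matches tfc.run below)\n",
    "GLOBAL_BATCH_SIZE = PER_WORKER_BATCH_SIZE * NUM_WORKERS\n",
    "print('Example global batch size for distributed training:', GLOBAL_BATCH_SIZE)"
   ]
  },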
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "44CHwtcPTviV"
   },
   "source": [
    "## Start the remote training\n",
    "\n",
    "This step will prepare your code from this notebook for remote execution and start a distributed training job on Google Cloud Platform to train the model. Once the job is submitted, you can go to the next step to monitor the job's progress via TensorBoard.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I4gSaGXgTviV",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "if not tfc.remote():\n",
    "    print('Training on TensorFlow Cloud...')\n",
    "\n",
    "    # If you are using a custom image, you can install modules via a\n",
    "    # requirements.txt file.\n",
    "    with open('requirements.txt', 'w') as f:\n",
    "        f.write('tensorflow-cloud==0.1.12\\n')\n",
    "\n",
    "    # Optional: Some recommended base images. If you provide none, the system\n",
    "    # will choose one for you.\n",
    "    TF_GPU_IMAGE = \"tensorflow/tensorflow:latest-gpu\"\n",
    "    TF_CPU_IMAGE = \"tensorflow/tensorflow:latest\"\n",
    "\n",
    "    tfc.run(\n",
    "        distribution_strategy='auto',\n",
    "        requirements_txt='requirements.txt',\n",
    "        docker_config=tfc.DockerConfig(\n",
    "            parent_image=TF_GPU_IMAGE,\n",
    "            image_build_bucket=GCS_BUCKET\n",
    "        ),\n",
    "        chief_config=tfc.COMMON_MACHINE_CONFIGS['K80_1X'],\n",
    "        worker_config=tfc.COMMON_MACHINE_CONFIGS['K80_1X'],\n",
    "        worker_count=3,\n",
    "        job_labels={'job': JOB_NAME}\n",
    "    )"
   ]
  },
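  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not part of the original example): once the remote job has\n",
    "# finished, the best checkpoint saved by ModelCheckpoint can be loaded back\n",
    "# from GCS and evaluated on the held-out test set. This assumes the job has\n",
    "# completed and this notebook still has access to MODEL_CHECKPOINT_DIR.\n",
    "if not tfc.remote():\n",
    "    trained_model = tf.keras.models.load_model(MODEL_CHECKPOINT_DIR)\n",
    "    loss, accuracy = trained_model.evaluate(x_test, y_test)\n",
    "    print(f'Test loss: {loss:.4f}, test accuracy: {accuracy:.4f}')"
   ]
  },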
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fCN-XJCRTviV"
   },
   "source": [
    "## Training Results\n",
    "While the training is in progress, you can use TensorBoard to view the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-dz-XpATTviV",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "if not tfc.remote():\n",
    "\n",
    "    %load_ext tensorboard\n",
    "    %tensorboard --logdir {TENSORBOARD_LOGS_DIR}"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "distributed-training-nasnet-with-tensorflow-cloud.ipynb",
   "provenance": [
    {
     "file_id": "1SRsTqmUqBJVWTyuteK7rmJtZthAn7Bth",
     "timestamp": 1612656396209
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}

src/python/tensorflow_cloud/examples/google_cloud_project_setup_instructions.ipynb

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 "\u003ctable align=\"left\"\u003e\n",
 " \u003ctd\u003e\n",
 " \u003ca href=\"https://colab.research.google.com/github/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/examples/google_cloud_project_setup_instructions.ipynb\"\u003e\n",
-" \u003cimg width=\"50\" src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"\u003e Run in Colab\n",
+" \u003cimg width=\"50\" src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"\u003eRun in Colab\n",
 " \u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
@@ -22,7 +22,7 @@
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
 " \u003ca href=\"https://www.kaggle.com/nitric/google-cloud-project-setup-instructions\"\u003e\n",
-" \u003cimg width=\"90\" src=\"https://www.kaggle.com/static/images/site-logo.png\" alt=\"Kaggle logo\"\u003eView on Kaggle\n",
+" \u003cimg width=\"90\" src=\"https://www.kaggle.com/static/images/site-logo.png\" alt=\"Kaggle logo\"\u003eRun in Kaggle\n",
 " \u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 "\u003c/table\u003e\n"
