Skip to content

Commit a144c25

Browse files
authored
Add Keras tutorial for training neural networks with Xbatcher to user guide (#260)
1 parent ece38ba commit a144c25

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed

ci/requirements/doc.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ dependencies:
2424
- pooch
2525
- zarr
2626
- pytorch
27+
- keras
28+
- tensorflow
2729
# Editable xbatcher installation
2830
- pip
2931

doc/user-guide/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ User Guide
77

88
caching
99
training-a-neural-network-with-Pytorch-and-xbatcher
10+
training-a-neural-network-with-keras-and-xbatcher
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "b314e777-7ffb-4e62-b4c5-ce8a785c5181",
6+
"metadata": {},
7+
"source": [
8+
"# End-to-End Tutorial: Training a Neural Network with Keras and Xbatcher\n",
9+
"\n",
10+
"## Import Required Libraries"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "5d912ff0-d808-4704-8dea-b9e1b5a53bf1",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"import matplotlib.pyplot as plt\n",
21+
"import tensorflow as tf\n",
22+
"import xarray as xr\n",
23+
"from keras import layers, models, optimizers\n",
24+
"\n",
25+
"import xbatcher as xb\n",
26+
"import xbatcher.loaders.keras"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": null,
32+
"id": "7fb892c1-50fd-48c8-8567-b150946b53c9",
33+
"metadata": {},
34+
"outputs": [],
35+
"source": [
36+
"# Open the dataset stored in Zarr format\n",
37+
"ds = xr.open_dataset(\n",
38+
" 's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n",
39+
" engine='zarr',\n",
40+
" chunks={},\n",
41+
" backend_kwargs={'storage_options': {'anon': True}},\n",
42+
")"
43+
]
44+
},
45+
{
46+
"cell_type": "markdown",
47+
"id": "c98134fe-581f-412a-93e3-6b07b7706078",
48+
"metadata": {},
49+
"source": [
50+
"## Define Batch Generators"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"id": "c680ebd7-0310-4f40-91b5-e7cc1a59e853",
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"# Define batch generators for features (X) and labels (y)\n",
61+
"X_bgen = xb.BatchGenerator(\n",
62+
" ds['images'],\n",
63+
" input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n",
64+
" preload_batch=False, # Load each batch dynamically\n",
65+
")\n",
66+
"y_bgen = xb.BatchGenerator(\n",
67+
" ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n",
68+
")"
69+
]
70+
},
71+
{
72+
"cell_type": "markdown",
73+
"id": "91d63180-e3a6-49f7-a8e7-67b8b698b08c",
74+
"metadata": {},
75+
"source": [
76+
"## Map Batches to a Keras-Compatible Dataset"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"id": "d1195057-269b-44ba-a3e7-aeedaa4ba8df",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"# Use xbatcher's CustomTFDataset to wrap the generators\n",
87+
"dataset = xbatcher.loaders.keras.CustomTFDataset(X_bgen, y_bgen)\n",
88+
"\n",
89+
"# Build a tf.data.Dataset pipeline that yields (images, labels) batches\n",
90+
"train_dataloader = tf.data.Dataset.from_generator(\n",
91+
" lambda: iter(dataset),\n",
92+
" output_signature=(\n",
93+
" tf.TensorSpec(shape=(2000, 1, 28, 28), dtype=tf.float32), # Images\n",
94+
" tf.TensorSpec(shape=(2000,), dtype=tf.int64), # Labels\n",
95+
" ),\n",
96+
").prefetch(3) # Prefetch 3 batches to improve performance"
97+
]
98+
},
99+
{
100+
"cell_type": "markdown",
101+
"id": "1892411c-ca17-4d7f-b76b-5b5decaa78c1",
102+
"metadata": {},
105+
"source": [
106+
"## Visualize a Sample Batch"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"id": "133b24bc-e7bc-4734-ad0a-22a848dd204c",
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"# Extract a batch from the DataLoader\n",
117+
"for train_features, train_labels in train_dataloader.take(1):\n",
118+
" print(f'Feature batch shape: {train_features.shape}')\n",
119+
" print(f'Labels batch shape: {train_labels.shape}')\n",
120+
"\n",
121+
" img = train_features[0].numpy().squeeze() # Extract the first image\n",
122+
" label = train_labels[0].numpy()\n",
123+
" plt.imshow(img, cmap='gray')\n",
124+
" plt.title(f'Label: {label}')\n",
125+
" plt.show()\n",
126+
" break"
127+
]
128+
},
129+
{
130+
"cell_type": "markdown",
131+
"id": "1e5d6a66-1943-47da-be67-9b54d51defed",
132+
"metadata": {},
133+
"source": [
134+
"## Build a Simple Neural Network with Keras"
135+
]
136+
},
137+
{
138+
"cell_type": "code",
139+
"execution_count": null,
140+
"id": "8b0490e5-7ccc-47fe-90ec-d41a81c4eb20",
141+
"metadata": {},
142+
"outputs": [],
143+
"source": [
144+
"# Define a simple feedforward neural network\n",
145+
"model = models.Sequential(\n",
146+
" [\n",
147+
" layers.Flatten(input_shape=(1, 28, 28)), # Flatten input images\n",
148+
" layers.Dense(128, activation='relu'), # Fully connected layer with 128 units\n",
149+
" layers.Dense(10, activation='softmax'), # Output layer for 10 classes\n",
150+
" ]\n",
151+
")\n",
152+
"\n",
153+
"# Compile the model\n",
154+
"model.compile(\n",
155+
" optimizer=optimizers.Adam(learning_rate=0.001),\n",
156+
" loss='sparse_categorical_crossentropy',\n",
157+
" metrics=['accuracy'],\n",
158+
")\n",
159+
"\n",
160+
"# Display model summary\n",
161+
"model.summary()"
162+
]
163+
},
164+
{
165+
"cell_type": "markdown",
166+
"id": "838df9c6-0753-4120-a0e0-dcc1480416b4",
167+
"metadata": {},
168+
"source": [
169+
"## Train the Model"
170+
]
171+
},
172+
{
173+
"cell_type": "code",
174+
"execution_count": null,
175+
"id": "25e86eba-4d4e-47cc-a6a7-9f0be244b009",
176+
"metadata": {},
177+
"outputs": [],
178+
"source": [
179+
"%%time\n",
180+
"\n",
181+
"# Train the model for 5 epochs\n",
182+
"epochs = 5\n",
183+
"\n",
184+
"model.fit(\n",
185+
" train_dataloader, # Pass the DataLoader directly\n",
186+
" epochs=epochs,\n",
187+
" verbose=1, # Print progress during training\n",
188+
")"
189+
]
190+
},
191+
{
192+
"cell_type": "markdown",
193+
"id": "a0f4246c-6461-4e2a-a49d-df6c1ce770fc",
194+
"metadata": {},
195+
"source": [
196+
"## Visualize a Sample Prediction"
197+
]
198+
},
199+
{
200+
"cell_type": "code",
201+
"execution_count": null,
202+
"id": "9361cb65-3c0d-40d6-be5c-18b309626817",
203+
"metadata": {},
204+
"outputs": [],
205+
"source": [
206+
"# Visualize a prediction on a sample image\n",
207+
"for train_features, train_labels in train_dataloader.take(1):\n",
208+
" img = train_features[0].numpy().squeeze()\n",
209+
" label = train_labels[0].numpy()\n",
210+
" predicted_label = tf.argmax(model.predict(train_features[:1]), axis=1).numpy()[0]\n",
211+
"\n",
212+
" plt.imshow(img, cmap='gray')\n",
213+
" plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n",
214+
" plt.show()\n",
215+
" break"
216+
]
217+
},
218+
{
219+
"cell_type": "markdown",
220+
"id": "372d0e0a-1542-4aa0-b3b9-9fd4337459ba",
221+
"metadata": {},
222+
"source": [
223+
"## Key Highlights\n",
224+
"\n",
225+
"- **Dynamic Batching**: Xbatcher and the `CustomTFDataset` class allow for dynamic loading of batches, which reduces memory usage and speeds up data processing.\n",
226+
"- **Prefetching**: The prefetch feature in `tf.data.Dataset` overlaps data loading with model training to minimize idle time.\n",
227+
"- **Compatibility**: The pipeline works seamlessly with `keras.Model.fit`, simplifying training workflows."
228+
]
229+
}
230+
],
231+
"metadata": {
232+
"kernelspec": {
233+
"display_name": "Python 3 (ipykernel)",
234+
"language": "python",
235+
"name": "python3"
236+
},
237+
"language_info": {
238+
"codemirror_mode": {
239+
"name": "ipython",
240+
"version": 3
241+
},
242+
"file_extension": ".py",
243+
"mimetype": "text/x-python",
244+
"name": "python",
245+
"nbconvert_exporter": "python",
246+
"pygments_lexer": "ipython3",
247+
"version": "3.11.9"
248+
}
249+
},
250+
"nbformat": 4,
251+
"nbformat_minor": 5
252+
}

0 commit comments

Comments
 (0)