 26 |  26 |   "outputs": [],
 27 |  27 |   "source": [
 28 |  28 |   "import xarray as xr\n",
    |  29 | + "\n",
 29 |  30 |   "import xbatcher"
 30 |  31 |   ]
 31 |  32 |   },

 46 |  47 |   "metadata": {},
 47 |  48 |   "outputs": [],
 48 |  49 |   "source": [
 49 |     | - "store = \"az://carbonplan-share/example_cmip6_data.zarr\"\n",
    |  50 | + "store = 'az://carbonplan-share/example_cmip6_data.zarr'\n",
 50 |  51 |   "ds = xr.open_dataset(\n",
 51 |  52 |   " store,\n",
 52 |     | - " engine=\"zarr\",\n",
    |  53 | + " engine='zarr',\n",
 53 |  54 |   " chunks={},\n",
 54 |     | - " backend_kwargs={\"storage_options\": {\"account_name\": \"carbonplan\"}},\n",
    |  55 | + " backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},\n",
 55 |  56 |   ")\n",
 56 |  57 |   "\n",
 57 |  58 |   "# the attributes contain a lot of useful information, but clutter the print out when we inspect the outputs\n",

 98 |  99 |   "\n",
 99 | 100 |   "bgen = xbatcher.BatchGenerator(\n",
100 | 101 |   " ds=ds,\n",
101 |     | - " input_dims={\"time\": n_timepoint_in_each_sample},\n",
    | 102 | + " input_dims={'time': n_timepoint_in_each_sample},\n",
102 | 103 |   ")\n",
103 | 104 |   "\n",
104 |     | - "print(f\"{len(bgen)} batches\")"
    | 105 | + "print(f'{len(bgen)} batches')"
105 | 106 |   ]
106 | 107 |   },
107 | 108 |   {

133 | 134 |   "outputs": [],
134 | 135 |   "source": [
135 | 136 |   "expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n",
136 |     | - "print(f\"Expecting {expected_n_batch} batches, getting {len(bgen)} batches\")"
    | 137 | + "print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')"
137 | 138 |   ]
138 | 139 |   },
139 | 140 |   {

153 | 154 |   "source": [
154 | 155 |   "expected_batch_size = len(ds.lat) * len(ds.lon)\n",
155 | 156 |   "print(\n",
156 |     | - " f\"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch\"\n",
    | 157 | + " f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n",
157 | 158 |   ")"
158 | 159 |   ]
159 | 160 |   },

179 | 180 |   "\n",
180 | 181 |   "bgen = xbatcher.BatchGenerator(\n",
181 | 182 |   " ds=ds,\n",
182 |     | - " input_dims={\"time\": n_timepoint_in_each_sample},\n",
183 |     | - " batch_dims={\"time\": n_timepoint_in_each_batch},\n",
    | 183 | + " input_dims={'time': n_timepoint_in_each_sample},\n",
    | 184 | + " batch_dims={'time': n_timepoint_in_each_batch},\n",
184 | 185 |   " concat_input_dims=True,\n",
185 | 186 |   ")\n",
186 | 187 |   "\n",
187 |     | - "print(f\"{len(bgen)} batches\")"
    | 188 | + "print(f'{len(bgen)} batches')"
188 | 189 |   ]
189 | 190 |   },
190 | 191 |   {

217 | 218 |   "source": [
218 | 219 |   "n_timepoint_in_batch = 31\n",
219 | 220 |   "\n",
220 |     | - "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={\"time\": n_timepoint_in_batch})\n",
    | 221 | + "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n",
221 | 222 |   "\n",
222 | 223 |   "for batch in bgen:\n",
223 |     | - " print(f\"last time point in ds is {ds.time[-1].values}\")\n",
224 |     | - " print(f\"last time point in batch is {batch.time[-1].values}\")\n",
    | 224 | + " print(f'last time point in ds is {ds.time[-1].values}')\n",
    | 225 | + " print(f'last time point in batch is {batch.time[-1].values}')\n",
225 | 226 |   "batch"
226 | 227 |   ]
227 | 228 |   },

249 | 250 |   "\n",
250 | 251 |   "bgen = xbatcher.BatchGenerator(\n",
251 | 252 |   " ds=ds,\n",
252 |     | - " input_dims={\"time\": n_timepoint_in_each_sample},\n",
253 |     | - " batch_dims={\"time\": n_timepoint_in_each_batch},\n",
    | 253 | + " input_dims={'time': n_timepoint_in_each_sample},\n",
    | 254 | + " batch_dims={'time': n_timepoint_in_each_batch},\n",
254 | 255 |   " concat_input_dims=True,\n",
255 |     | - " input_overlap={\"time\": input_overlap},\n",
    | 256 | + " input_overlap={'time': input_overlap},\n",
256 | 257 |   ")\n",
257 | 258 |   "\n",
258 | 259 |   "batch = bgen[0]\n",
259 | 260 |   "\n",
260 |     | - "print(f\"{len(bgen)} batches\")\n",
    | 261 | + "print(f'{len(bgen)} batches')\n",
261 | 262 |   "batch"
262 | 263 |   ]
263 | 264 |   },

283 | 284 |   "display(pixel)\n",
284 | 285 |   "\n",
285 | 286 |   "print(\n",
286 |     | - " f\"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}\"\n",
    | 287 | + " f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n",
287 | 288 |   ")\n",
288 | 289 |   "print(\n",
289 |     | - " f\"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}\"\n",
    | 290 | + " f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n",
290 | 291 |   ")"
291 | 292 |   ]
292 | 293 |   },

310 | 311 |   "outputs": [],
311 | 312 |   "source": [
312 | 313 |   "bgen = xbatcher.BatchGenerator(\n",
313 |     | - " ds=ds[[\"tasmax\"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
314 |     | - " input_dims={\"lat\": 9, \"lon\": 9, \"time\": 10},\n",
315 |     | - " batch_dims={\"lat\": 18, \"lon\": 18, \"time\": 15},\n",
    | 314 | + " ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
    | 315 | + " input_dims={'lat': 9, 'lon': 9, 'time': 10},\n",
    | 316 | + " batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n",
316 | 317 |   " concat_input_dims=True,\n",
317 |     | - " input_overlap={\"lat\": 8, \"lon\": 8, \"time\": 9},\n",
    | 318 | + " input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n",
318 | 319 |   ")\n",
319 | 320 |   "\n",
320 | 321 |   "for i, batch in enumerate(bgen):\n",
321 |     | - " print(f\"batch {i}\")\n",
    | 322 | + " print(f'batch {i}')\n",
322 | 323 |   " # make sure the ordering of dimension is consistent\n",
323 |     | - " batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
    | 324 | + " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
324 | 325 |   "\n",
325 | 326 |   " # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
326 | 327 |   " features = batch.tasmax.isel(time_input=slice(0, 9))\n",
327 | 328 |   " # select the center pixel at the last time point to be the label to be predicted\n",
328 | 329 |   " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
329 | 330 |   " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
330 | 331 |   "\n",
331 |     | - " print(\"feature shape\", features.shape)\n",
332 |     | - " print(\"label shape\", labels.shape)\n",
333 |     | - " print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape)\n",
334 |     | - " print(\"\")"
    | 332 | + " print('feature shape', features.shape)\n",
    | 333 | + " print('label shape', labels.shape)\n",
    | 334 | + " print('shape of lat of each sample', labels.coords['lat'].shape)\n",
    | 335 | + " print('')"
335 | 336 |   ]
336 | 337 |   },
337 | 338 |   {

350 | 351 |   "outputs": [],
351 | 352 |   "source": [
352 | 353 |   "for i, batch in enumerate(bgen):\n",
353 |     | - " print(f\"batch {i}\")\n",
    | 354 | + " print(f'batch {i}')\n",
354 | 355 |   " # make sure the ordering of dimension is consistent\n",
355 |     | - " batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
    | 356 | + " batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
356 | 357 |   "\n",
357 | 358 |   " # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
358 | 359 |   " features = batch.tasmax.isel(time_input=slice(0, 9))\n",
359 |     | - " features = features.stack(features=[\"lat_input\", \"lon_input\", \"time_input\"])\n",
    | 360 | + " features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n",
360 | 361 |   "\n",
361 | 362 |   " # select the center pixel at the last time point to be the label to be predicted\n",
362 | 363 |   " # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
363 | 364 |   " labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
364 | 365 |   "\n",
365 |     | - " print(\"feature shape\", features.shape)\n",
366 |     | - " print(\"label shape\", labels.shape)\n",
367 |     | - " print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape, \"\\n\")"
    | 366 | + " print('feature shape', features.shape)\n",
    | 367 | + " print('label shape', labels.shape)\n",
    | 368 | + " print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')"
368 | 369 |   ]
369 | 370 |   },
370 | 371 |   {
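
For orientation, the cells touched above boil down to the following pattern. This is a minimal sketch assembled from the source lines visible in this diff, not code lifted from the notebook verbatim: the sample, batch, and overlap sizes are illustrative stand-ins for the notebook's n_timepoint_in_each_sample, n_timepoint_in_each_batch, and input_overlap variables (their values are not shown here), and opening the az:// store assumes the Azure filesystem dependencies are installed.

import xarray as xr

import xbatcher

store = 'az://carbonplan-share/example_cmip6_data.zarr'
ds = xr.open_dataset(
    store,
    engine='zarr',
    chunks={},
    backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},
)

# slide a window along time to cut samples, group windows into larger batches,
# and let consecutive windows overlap
bgen = xbatcher.BatchGenerator(
    ds=ds,
    input_dims={'time': 31},      # illustrative sample length
    batch_dims={'time': 310},     # illustrative batch length
    concat_input_dims=True,
    input_overlap={'time': 10},   # illustrative overlap between samples
)

print(f'{len(bgen)} batches')
batch = bgen[0]  # batches can be indexed as well as iterated over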