|
1 | 1 | """Classes for iterating through xarray datarrays / datasets in batches."""
|
2 | 2 |
|
3 | 3 | import itertools
|
| 4 | +import json |
4 | 5 | import warnings
|
5 | 6 | from collections.abc import Hashable, Iterator, Sequence
|
6 | 7 | from operator import itemgetter
|
@@ -262,6 +263,49 @@ def _get_batch_in_range_per_batch(self, batch_multi_index):
|
262 | 263 | batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
|
263 | 264 | return batch_in_range_per_patch
|
264 | 265 |
|
| 266 | + def to_json(self): |
| 267 | + """ |
| 268 | + Dump the BatchSchema properties to a JSON file. |
| 269 | +
|
| 270 | + Returns |
| 271 | + ---------- |
| 272 | + out_json: str |
| 273 | + The JSON representation of the BatchSchema |
| 274 | + """ |
| 275 | + out_dict = {} |
| 276 | + out_dict['input_dims'] = self.input_dims |
| 277 | + out_dict['input_overlap'] = self.input_overlap |
| 278 | + out_dict['batch_dims'] = self.batch_dims |
| 279 | + out_dict['concat_input_dims'] = self.input_dims |
| 280 | + out_dict['preload_batch'] = self.preload_batch |
| 281 | + batch_selector_dict = {} |
| 282 | + for i in self.selectors.keys(): |
| 283 | + batch_selector_dict[i] = self.selectors[i] |
| 284 | + for member in batch_selector_dict[i]: |
| 285 | + out_member_dict = {} |
| 286 | + member_keys = [x for x in member.keys()] |
| 287 | + for member_key in member_keys: |
| 288 | + out_member_dict[member_key] = { |
| 289 | + 'start': member[member_key].start, |
| 290 | + 'stop': member[member_key].stop, |
| 291 | + 'step': member[member_key].step, |
| 292 | + } |
| 293 | + out_dict['selector'] = out_member_dict |
| 294 | + return json.dumps(out_dict) |
| 295 | + |
| 296 | + def to_file(self, out_file_name: str): |
| 297 | + """ |
| 298 | + Dumps the JSON representation of the BatchSchema object to a file. |
| 299 | +
|
| 300 | + Parameters |
| 301 | + ---------- |
| 302 | + out_file_name: str |
| 303 | + The path to the json file to write to. |
| 304 | + """ |
| 305 | + out_json = self.to_json() |
| 306 | + with open(out_file_name, mode='w') as out_file: |
| 307 | + out_file.write(out_json) |
| 308 | + |
265 | 309 |
|
266 | 310 | def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:
|
267 | 311 | # return a list of slices to chop up a single dimension
|
|
0 commit comments