Skip to content

Commit fd0244b

Browse files
rcjackson, Robert Jackson, andersy005, pre-commit-ci[bot]
authored
Schema to json (#159)
Co-authored-by: Robert Jackson <[email protected]> Co-authored-by: Anderson Banihirwe <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent 8e226e8 commit fd0244b

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

xbatcher/generators.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Classes for iterating through xarray datarrays / datasets in batches."""
22

33
import itertools
4+
import json
45
import warnings
56
from collections.abc import Hashable, Iterator, Sequence
67
from operator import itemgetter
@@ -262,6 +263,49 @@ def _get_batch_in_range_per_batch(self, batch_multi_index):
262263
batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
263264
return batch_in_range_per_patch
264265

266+
def to_json(self):
    """
    Serialize the BatchSchema properties to a JSON string.

    The output contains the ``input_dims``, ``input_overlap``,
    ``batch_dims``, ``concat_input_dims`` and ``preload_batch`` attributes,
    plus a ``selector`` entry in which every slice is expanded into its
    ``start`` / ``stop`` / ``step`` components so that it can be encoded
    as JSON.

    Returns
    -------
    out_json: str
        The JSON representation of the BatchSchema
    """
    out_dict = {
        'input_dims': self.input_dims,
        'input_overlap': self.input_overlap,
        'batch_dims': self.batch_dims,
        # Previously this key was (mistakenly) filled from ``input_dims``.
        'concat_input_dims': self.concat_input_dims,
        'preload_batch': self.preload_batch,
    }

    # Start from an empty mapping so that a schema without any selectors
    # still serializes instead of failing on an unbound name.
    out_member_dict = {}
    for members in self.selectors.values():
        for member in members:
            # slice objects are not JSON serializable; store their parts.
            out_member_dict = {
                dim: {
                    'start': dim_slice.start,
                    'stop': dim_slice.stop,
                    'step': dim_slice.step,
                }
                for dim, dim_slice in member.items()
            }

    # NOTE(review): only the last member visited is kept under 'selector',
    # which matches the existing output format. Confirm whether the full
    # selectors mapping should be exported instead.
    out_dict['selector'] = out_member_dict
    return json.dumps(out_dict)
295+
296+
def to_file(self, out_file_name: str) -> None:
    """
    Write the JSON representation of the BatchSchema object to a file.

    Parameters
    ----------
    out_file_name: str
        The path to the json file to write to. An existing file at this
        path is overwritten.
    """
    # Serialize before opening the file so that a failure in ``to_json``
    # cannot leave a truncated file behind.
    out_json = self.to_json()
    # JSON documents are UTF-8 by convention; do not rely on the locale's
    # default text encoding.
    with open(out_file_name, mode='w', encoding='utf-8') as out_file:
        out_file.write(out_json)
308+
265309

266310
def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:
267311
# return a list of slices to chop up a single dimension

xbatcher/tests/test_generators.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import json
2+
import tempfile
13
from typing import Any
24

35
import numpy as np
46
import pytest
57
import xarray as xr
68

7-
from xbatcher import BatchGenerator
9+
from xbatcher import BatchGenerator, BatchSchema
810
from xbatcher.testing import (
911
get_batch_dimensions,
1012
validate_batch_dimensions,
@@ -360,6 +362,21 @@ def test_input_overlap_exceptions(sample_ds_1d):
360362
assert len(e) == 1
361363

362364

365+
@pytest.mark.parametrize('input_size', [5, 10])
def test_to_json(sample_ds_3d, input_size):
    """Round-trip a BatchSchema through ``to_file`` and check ``input_dims``."""
    x_input_size = 20
    bg = BatchSchema(
        sample_ds_3d,
        input_dims={'time': input_size, 'x': x_input_size},
    )
    # The context manager guarantees the temporary file is closed (and thus
    # removed) even when ``to_file`` or the JSON parsing raises.
    with tempfile.NamedTemporaryFile(mode='w+b') as out_file:
        bg.to_file(out_file.name)
        in_dict = json.load(out_file)
    assert in_dict['input_dims']['time'] == input_size
    assert in_dict['input_dims']['x'] == x_input_size
378+
379+
363380
@pytest.mark.parametrize('preload', [True, False])
364381
def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
365382
pytest.importorskip('zarr')

0 commit comments

Comments
 (0)