Skip to content

Commit 95d9fe6

Browse files
author
Joseph Hamman
committed
exploring xarray accessors
1 parent e885373 commit 95d9fe6

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

xbatcher/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from . generators import BatchGenerator
1+
from .generators import BatchGenerator
2+
from .accessors import BatchAccessor

xbatcher/accessors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import xarray as xr
2+
3+
from .generators import BatchGenerator
4+
5+
6+
@xr.register_dataarray_accessor("batch")
7+
@xr.register_dataset_accessor("batch")
8+
class BatchAccessor:
9+
def __init__(self, xarray_obj):
10+
self._obj = xarray_obj
11+
12+
def generator(self, *args, **kwargs):
13+
return BatchGenerator(self._obj, *args, **kwargs)

xbatcher/tests/test_accessors.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import xarray as xr
2+
import numpy as np
3+
import pytest
4+
5+
from xbatcher import BatchGenerator
6+
import xbatcher
7+
8+
9+
@pytest.fixture(scope='module')
10+
def sample_ds_3d():
11+
shape = (10, 50, 100)
12+
ds = xr.Dataset({'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
13+
'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape))},
14+
{'x': (['x'], np.arange(shape[-1])),
15+
'y': (['y'], np.arange(shape[-2]))})
16+
return ds
17+
18+
19+
def test_batch_accessor_ds(sample_ds_3d):
20+
bg_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
21+
bg_acc = sample_ds_3d.batch.generator(input_dims={'x': 5})
22+
assert isinstance(bg_acc, BatchGenerator)
23+
for batch_class, batch_acc in zip(bg_class, bg_acc):
24+
assert isinstance(batch_acc, xr.Dataset)
25+
assert batch_class.equals(batch_acc)
26+
27+
28+
def test_batch_accessor_da(sample_ds_3d):
29+
sample_da = sample_ds_3d['foo']
30+
bg_class = BatchGenerator(sample_da, input_dims={'x': 5})
31+
bg_acc = sample_da.batch.generator(input_dims={'x': 5})
32+
assert isinstance(bg_acc, BatchGenerator)
33+
for batch_class, batch_acc in zip(bg_class, bg_acc):
34+
assert batch_class.equals(batch_acc)

0 commit comments

Comments
 (0)