Skip to content

Commit 5060c0e

Browse files
committed
more stuff + travis
1 parent 12d876f commit 5060c0e

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

.travis.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
language: python
2+
3+
# sudo false implies containerized builds, so we can use cacheing
4+
sudo: false
5+
6+
notifications:
7+
email: false
8+
9+
python:
10+
- 3.6
11+
12+
env:
13+
- CONDA_DEPS="pip flake8 pytest coverage pandas xarray dask" PIP_DEPS="codecov pytest-cov"
14+
15+
before_install:
16+
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
17+
wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh;
18+
else
19+
wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh;
20+
fi
21+
- bash miniconda.sh -b -f -p $HOME/miniconda
22+
- export PATH="$HOME/miniconda/bin:$PATH"
23+
- hash -r
24+
- conda config --set always_yes yes --set changeps1 no
25+
- conda update -q conda
26+
- conda info -a
27+
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION $CONDA_DEPS
28+
- source activate test-environment
29+
- travis_retry pip install $PIP_DEPS
30+
31+
install:
32+
- python setup.py install --record installed_files.txt
33+
34+
script:
35+
- py.test xbatcher --cov=xbatcher --cov-config .coveragerc --cov-report term-missing -v
36+
37+
after_success:
38+
- codecov

xbatcher/features.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Functions for transforming xarray datasets into features that can
2+
be input to machine learning libraries."""
3+
4+
def dataset_to_feature_dataframe(ds, coords_as_features=False):
    """Convert an xarray dataset into a feature dataframe.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose data variables become the feature columns.
    coords_as_features : bool, optional
        If True, promote the coordinate index levels produced by
        ``to_dataframe()`` into ordinary columns so they can be used as
        features. Default False (coords stay in the index), which matches
        the previous behavior of this function.

    Returns
    -------
    pandas.DataFrame
    """
    df = ds.to_dataframe()
    if coords_as_features:
        # to_dataframe() places the coords in the (Multi)Index; reset_index
        # moves them into regular columns when the caller asks for it.
        # (Previously this flag was accepted but silently ignored.)
        df = df.reset_index()
    return df

xbatcher/generators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Classes for iterating through xarray datarrays / datasets in batches."""
2+
13
import xarray as xr
24
from collections import OrderedDict
35
import itertools
@@ -26,6 +28,7 @@ def __iter__(self):
2628
for slices in itertools.product(*[self._iterate_dim(dim)
2729
for dim in self.batch_dims]):
2830
selector = {key: slice for key, slice in zip(self.batch_dims, slices)}
31+
#print(selector)
2932
yield self.ds.isel(**selector)
3033

3134

xbatcher/tests/test_generators.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
@pytest.fixture(scope='module')
def sample_ds_1d():
    """Module-scoped fixture: a 1D dataset with two variables on dim 'x'."""
    n = 100
    data_vars = {'foo': (['x'], np.random.rand(n)),
                 'bar': (['x'], np.random.randint(0, 10, n))}
    coords = {'x': (['x'], np.arange(n))}
    return xr.Dataset(data_vars, coords)
1414

15+
1516
# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
1617
# Should we enforce that each batch size always has to be the same
1718
@pytest.mark.parametrize("bsize", [5, 10])
@@ -25,6 +26,7 @@ def test_batch_1d(sample_ds_1d, bsize):
2526
ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
2627
assert ds_batch.equals(ds_batch_expected)
2728

29+
2830
@pytest.mark.parametrize("olap", [1, 4])
2931
def test_batch_1d_overlap(sample_ds_1d, olap):
3032
bsize = 10
@@ -37,3 +39,38 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
3739
expected_slice = slice(stride*n, stride*n + bsize)
3840
ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
3941
assert ds_batch.equals(ds_batch_expected)
42+
43+
44+
@pytest.fixture(scope='module')
def sample_ds_2d():
    """Module-scoped fixture: a 2D dataset with two variables on ('y', 'x')."""
    ny, nx = 50, 100
    data_vars = {'foo': (['y', 'x'], np.random.rand(ny, nx)),
                 'bar': (['y', 'x'], np.random.randint(0, 10, (ny, nx)))}
    coords = {'x': (['x'], np.arange(nx)),
              'y': (['y'], np.arange(ny))}
    return xr.Dataset(data_vars, coords)
52+
53+
54+
@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_2d(sample_ds_2d, bsize):
    """Batch a 2D dataset first along one dim, then along both dims."""

    # Iterate over 'x' only: 'y' must stay at full size and each batch must
    # equal the corresponding contiguous isel slice.
    gen = BatchGenerator(sample_ds_2d, batch_sizes={'x': bsize})
    for n, batch in enumerate(gen):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == bsize
        assert batch.dims['y'] == sample_ds_2d.dims['y']
        expected = sample_ds_2d.isel(x=slice(bsize * n, bsize * (n + 1)))
        assert batch.equals(expected)

    # Iterate over both dims: check per-batch shapes and the total count.
    xbsize = 20
    gen = BatchGenerator(sample_ds_2d, batch_sizes={'y': bsize, 'x': xbsize})
    for n, batch in enumerate(gen):
        assert isinstance(batch, xr.Dataset)
        assert batch.dims['x'] == xbsize
        assert batch.dims['y'] == bsize
    # TODO? Is it worth it to try to reproduce the internal logic of the
    # generator and verify that the slices are correct?
    expected_count = ((sample_ds_2d.dims['x'] // xbsize)
                      * (sample_ds_2d.dims['y'] // bsize))
    assert (n + 1) == expected_count

0 commit comments

Comments
 (0)