Skip to content

Commit 12d876f

Browse files
committed
first commit
0 parents  commit 12d876f

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed

.gitignore

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
5+
# C extensions
6+
*.so
7+
8+
# Distribution / packaging
9+
.Python
10+
env/
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
*.egg-info/
23+
.installed.cfg
24+
*.egg
25+
26+
# PyInstaller
27+
# Usually these files are written by a python script from a template
28+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
29+
*.manifest
30+
*.spec
31+
32+
# Installer logs
33+
pip-log.txt
34+
pip-delete-this-directory.txt
35+
36+
# Unit test / coverage reports
37+
htmlcov/
38+
.tox/
39+
.coverage
40+
.coverage.*
41+
.cache
42+
nosetests.xml
43+
coverage.xml
44+
*,cover
45+
46+
# Translations
47+
*.mo
48+
*.pot
49+
50+
# Django stuff:
51+
*.log
52+
53+
# Sphinx documentation
54+
docs/_build/
55+
56+
# PyBuilder
57+
target/
58+
59+
# notebook
60+
*/.ipynb_checkpoints/*
61+
62+
# tests
63+
.pytest_cache/*

setup.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python
2+
import os
3+
import re
4+
import sys
5+
import warnings
6+
7+
from setuptools import setup, find_packages
8+
9+
# Package metadata consumed by the setup() call at the bottom of this file.

# --- identity ---------------------------------------------------------------
DISTNAME = "xbatcher"
VERSION = "0.1.0"
DESCRIPTION = "Batch generation from xarray dataset"
URL = "https://github.com/xgcm/xbatcher"
LICENSE = "Apache"

# --- authorship -------------------------------------------------------------
AUTHOR = "xbatcher Developers"
# NOTE(review): this value looks like an e-mail obfuscation placeholder rather
# than a real address -- confirm and replace with the maintainers' contact.
AUTHOR_EMAIL = "[email protected]"

# --- dependencies -----------------------------------------------------------
INSTALL_REQUIRES = ["xarray", "dask", "numpy"]
SETUP_REQUIRES = []
TESTS_REQUIRE = ["pytest >= 2.8", "coverage"]

# --- trove classifiers ------------------------------------------------------
CLASSIFIERS = [
    "Development Status :: 4 - Beta",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
    "Intended Audience :: Science/Research",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.5",
    "Programming Language :: Python :: 3.6",
    "Topic :: Scientific/Engineering",
]
32+
def readme():
    """Return the text used as the package's long description.

    Reads ``README.rst`` from the directory that contains this file when it
    exists. Until a README is added, the ``"TODO"`` placeholder is returned
    so that building the package never fails on a missing file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    readme_path = os.path.join(here, 'README.rst')
    try:
        # Explicit encoding so the result does not depend on the locale.
        with open(readme_path, encoding='utf-8') as f:
            return f.read()
    except FileNotFoundError:
        return "TODO"
36+
37+
38+
# Register the distribution with setuptools; every value comes from the
# module-level metadata constants defined above.
setup(
    name=DISTNAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=readme(),
    url=URL,
    license=LICENSE,
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    classifiers=CLASSIFIERS,
    # All packages found next to this file are included.
    packages=find_packages(),
    install_requires=INSTALL_REQUIRES,
    setup_requires=SETUP_REQUIRES,
    tests_require=TESTS_REQUIRE,
)

xbatcher/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . generators import BatchGenerator

xbatcher/generators.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import xarray as xr
2+
from collections import OrderedDict
3+
import itertools
4+
5+
def _as_xarray_dataset(ds):
    """Return ``ds`` as an ``xarray.Dataset``.

    A ``Dataset`` is passed through untouched; any other object is converted
    by calling its ``to_dataset()`` method.
    """
    if not isinstance(ds, xr.Dataset):
        return ds.to_dataset()
    return ds
11+
12+
13+
class BatchGenerator:
    """Iterate through an xarray DataArray / Dataset in fixed-size batches.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data to draw batches from. Anything that is not a ``Dataset`` is
        converted through ``_as_xarray_dataset``.
    batch_sizes : dict
        Maps a dimension name to the length of every batch along that
        dimension. Batches are produced over the Cartesian product of the
        listed dimensions, in the order given.
    overlap : dict, optional
        Maps a dimension name to the number of elements shared by two
        consecutive batches along that dimension. Dimensions that are not
        listed use an overlap of 0.

    Notes
    -----
    Only full-size batches are produced: a trailing remainder that is
    shorter than the batch size is dropped, and a batch size larger than the
    dimension yields no batches at all.
    """

    def __init__(self, ds, batch_sizes, overlap=None):
        self.ds = _as_xarray_dataset(ds)
        # should be a dict
        self.batch_sizes = OrderedDict(batch_sizes)
        self.batch_dims = list(self.batch_sizes)
        if overlap is None:
            overlap = {}
        # make sure an overlap is defined for every batched dimension
        self.overlap = {dim: overlap.get(dim, 0) for dim in self.batch_dims}

    def __iter__(self):
        """Yield ``self.ds.isel(...)`` for every combination of batch slices."""
        per_dim_slices = [self._iterate_dim(dim) for dim in self.batch_dims]
        for slices in itertools.product(*per_dim_slices):
            selector = dict(zip(self.batch_dims, slices))
            yield self.ds.isel(**selector)

    def _iterate_dim(self, dim):
        """Yield the slices that cut dimension ``dim`` into batches.

        Raises
        ------
        ValueError
            If the batch size is not positive, or if the overlap is greater
            than or equal to the batch size (the window would never advance).
        """
        dimsize = self.ds.sizes[dim]
        size = self.batch_sizes[dim]
        overlap = self.overlap[dim]
        stride = size - overlap
        if size < 1:
            raise ValueError(
                "batch size for dimension %r must be positive, got %r"
                % (dim, size)
            )
        if stride < 1:
            raise ValueError(
                "overlap (%r) must be smaller than the batch size (%r) "
                "for dimension %r" % (overlap, size, dim)
            )
        # Only start positions that leave room for a full batch are used, so
        # a batch covering the whole dimension (size == dimsize) is valid.
        for start in range(0, dimsize - size + 1, stride):
            yield slice(start, start + size)

xbatcher/tests/test_generators.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import xarray as xr
2+
import numpy as np
3+
from xbatcher import BatchGenerator
4+
import pytest
5+
6+
7+
@pytest.fixture(scope='module')
def sample_ds_1d():
    """Build a 1-D dataset of length 100 with variables ``foo`` and ``bar``.

    ``foo`` holds random floats, ``bar`` random integers drawn from
    ``[0, 10)``, and the ``x`` coordinate is ``0..99``.
    """
    size = 100
    data_vars = {
        'foo': (['x'], np.random.rand(size)),
        'bar': (['x'], np.random.randint(0, 10, size)),
    }
    coords = {'x': (['x'], np.arange(size))}
    return xr.Dataset(data_vars, coords)
14+
15+
# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
16+
# Should we enforce that each batch size always has to be the same?
17+
@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_1d(sample_ds_1d, bsize):
    """Non-overlapping batches tile the ``x`` dimension in order."""
    bg = BatchGenerator(sample_ds_1d, batch_sizes={'x': bsize})
    n_batches = 0
    for n, ds_batch in enumerate(bg):
        assert isinstance(ds_batch, xr.Dataset)
        # TODO: maybe relax this? see comment above
        assert ds_batch.sizes['x'] == bsize
        expected_slice = slice(bsize*n, bsize*(n+1))
        ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
        assert ds_batch.equals(ds_batch_expected)
        n_batches += 1
    # Guard against a vacuous pass: an empty generator would otherwise skip
    # every assertion in the loop body.
    assert n_batches == sample_ds_1d.sizes['x'] // bsize
27+
28+
@pytest.mark.parametrize("olap", [1, 4])
def test_batch_1d_overlap(sample_ds_1d, olap):
    """Overlapping batches advance by ``bsize - olap`` along ``x``."""
    bsize = 10
    bg = BatchGenerator(sample_ds_1d, batch_sizes={'x': bsize},
                        overlap={'x': olap})
    stride = bsize-olap
    n_batches = 0
    for n, ds_batch in enumerate(bg):
        assert isinstance(ds_batch, xr.Dataset)
        assert ds_batch.sizes['x'] == bsize
        expected_slice = slice(stride*n, stride*n + bsize)
        ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
        assert ds_batch.equals(ds_batch_expected)
        n_batches += 1
    # Guard against a vacuous pass: count the full windows that fit.
    expected_batches = (sample_ds_1d.sizes['x'] - bsize) // stride + 1
    assert n_batches == expected_batches

0 commit comments

Comments
 (0)