
Commit 3af1306

Author: Joe Hamman
Merge pull request #25 from jhamman/loader/torch
Add pytorch dataloader
2 parents (802bbd5 + 8bcd870), commit 3af1306

14 files changed: +341 -19 lines changed

.pre-commit-config.yaml

Lines changed: 14 additions & 1 deletion
@@ -14,7 +14,7 @@ repos:
       - id: double-quote-string-fixer

   - repo: https://github.com/psf/black
-    rev: 21.12b0
+    rev: 22.1.0
     hooks:
       - id: black
         args: ["--line-length", "80", "--skip-string-normalization"]
@@ -37,3 +37,16 @@ repos:
     hooks:
       - id: prettier
         language_version: system
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.931
+    hooks:
+      - id: mypy
+        additional_dependencies: [
+          # Type stubs
+          types-setuptools,
+          types-pkg_resources,
+          # Dependencies that are typed
+          numpy,
+          xarray,
+        ]

conftest.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# type: ignore
 import pytest


dev-requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 pytest
+torch
+coverage
 pytest-cov
 adlfs
 -r requirements.txt

doc/api.rst

Lines changed: 14 additions & 6 deletions
@@ -5,12 +5,6 @@ API reference

 This page provides an auto-generated summary of Xbatcher's API.

-Core
-====
-
-.. autoclass:: xbatcher.BatchGenerator
-    :members:
-
 Dataset.batch and DataArray.batch
 =================================

@@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch

    Dataset.batch.generator
    DataArray.batch.generator
+
+Core
+====
+
+.. autoclass:: xbatcher.BatchGenerator
+    :members:
+
+Dataloaders
+===========
+.. autoclass:: xbatcher.loaders.torch.MapDataset
+    :members:
+
+.. autoclass:: xbatcher.loaders.torch.IterableDataset
+    :members:
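
The xbatcher.loaders.torch module referenced by the new "Dataloaders" section is not included in this excerpt of the diff. As a rough, non-authoritative sketch, and assuming MapDataset accepts a feature BatchGenerator and a target BatchGenerator and yields (X, y) tensor pairs, the wiring into torch.utils.data.DataLoader would look roughly like this:

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

ds = xr.Dataset(
    {
        'x': (('sample', 'feature'), np.random.random((100, 5))),
        'y': (('sample',), np.random.random(100)),
    }
)

# One generator for the features and one for the targets, batched along 'sample'
X_gen = BatchGenerator(ds['x'], input_dims={'sample': 10})
y_gen = BatchGenerator(ds['y'], input_dims={'sample': 10})

# Assumed constructor signature: MapDataset(X_generator, y_generator)
dataset = MapDataset(X_gen, y_gen)

# batch_size=None because each item from the map-style dataset is already a full batch
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for X_batch, y_batch in loader:
    print(X_batch.shape, y_batch.shape)

The eager batch caching added to BatchGenerator further down in this diff (len() and integer __getitem__) is what makes a map-style torch dataset like this possible.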

doc/conf.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 # All configuration values have a default; values that are commented out
 # serve to show the default.

+# type: ignore
+
 import os
 import sys


setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

 [isort]
 known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0

setup.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+# type: ignore
 import os

 from setuptools import find_packages, setup

xbatcher/accessors.py

Lines changed: 18 additions & 0 deletions
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
             Keyword arguments to pass to the `BatchGenerator` constructor.
         '''
         return BatchGenerator(self._obj, *args, **kwargs)
+
+
+@xr.register_dataarray_accessor('torch')
+class TorchAccessor:
+    def __init__(self, xarray_obj):
+        self._obj = xarray_obj
+
+    def to_tensor(self):
+        """Convert this DataArray to a torch.Tensor"""
+        import torch
+
+        return torch.tensor(self._obj.data)
+
+    def to_named_tensor(self):
+        """Convert this DataArray to a torch.Tensor with named dimensions"""
+        import torch
+
+        return torch.tensor(self._obj.data, names=self._obj.dims)
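
The new accessor makes tensor conversion available directly on any DataArray once the accessor module has been imported (assumed here to happen as a side effect of importing xbatcher). A minimal sketch:

import numpy as np
import xarray as xr

import xbatcher  # noqa: F401  (assumed to register the 'torch' accessor on import)

da = xr.DataArray(
    np.random.random((10, 4)), dims=('sample', 'feature'), name='foo'
)

t = da.torch.to_tensor()         # plain torch.Tensor with shape (10, 4)
nt = da.torch.to_named_tensor()  # same data, carrying the dims as tensor names
print(t.shape, nt.names)         # torch.Size([10, 4]) ('sample', 'feature')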

xbatcher/generators.py

Lines changed: 45 additions & 11 deletions
@@ -2,6 +2,7 @@

 import itertools
 from collections import OrderedDict
+from typing import Any, Dict, Hashable, Iterator

 import xarray as xr

@@ -99,12 +100,12 @@ class BatchGenerator:

     def __init__(
         self,
-        ds,
-        input_dims,
-        input_overlap={},
-        batch_dims={},
-        concat_input_dims=False,
-        preload_batch=True,
+        ds: xr.Dataset,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+        preload_batch: bool = True,
     ):

         self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch

-    def __iter__(self):
+        self._batches: Dict[
+            int, Any
+        ] = self._gen_batches()  # dict cache for batches
+        # in the future, we can make this a lru cache or similar thing (cachey?)
+
+    def __iter__(self) -> Iterator[xr.Dataset]:
+        for batch in self._batches.values():
+            yield batch
+
+    def __len__(self) -> int:
+        return len(self._batches)
+
+    def __getitem__(self, idx: int) -> xr.Dataset:
+
+        if not isinstance(idx, int):
+            raise NotImplementedError(
+                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+            )
+
+        if idx < 0:
+            idx = list(self._batches)[idx]
+
+        if idx in self._batches:
+            return self._batches[idx]
+        else:
+            raise IndexError('list index out of range')
+
+    def _gen_batches(self) -> dict:
+        # in the future, we will want to do the batch generation lazily
+        # going the eager route for now is allowing me to fill out the loader api
+        # but it is likely to perform poorly.
+        batches = []
         for ds_batch in self._iterate_batch_dims(self.ds):
             if self.preload_batch:
                 ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
                 ]
                 dsc = xr.concat(all_dsets, dim='input_batch')
                 new_input_dims = [
-                    dim + new_dim_suffix for dim in self.input_dims
+                    str(dim) + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
-                    yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                    batches.append(
+                        _maybe_stack_batch_dims(ds_input, list(self.input_dims))
                     )

+        return dict(zip(range(len(batches)), batches))
+
     def _iterate_batch_dims(self, ds):
         return _iterate_through_dataset(ds, self.batch_dims)

xbatcher/loaders/__init__.py

Whitespace-only changes.
