1+ from importlib import reload
2+
13import numpy as np
24import pytest
35import xarray as xr
46
57from xbatcher import BatchGenerator
6- from xbatcher .loaders .torch import IterableDataset , MapDataset
8+ from xbatcher .loaders .torch import IterableDataset , MapDataset , to_tensor
79
810torch = pytest .importorskip ('torch' )
911
1012
11- @pytest .fixture (scope = 'module' )
12- def ds_xy ():
13+ def test_import_torch_failure (monkeypatch ):
14+ import sys
15+
16+ import xbatcher .loaders
17+
18+ monkeypatch .setitem (sys .modules , 'torch' , None )
19+
20+ with pytest .raises (ImportError ) as excinfo :
21+ reload (xbatcher .loaders .torch )
22+
23+ assert 'install PyTorch to proceed' in str (excinfo .value )
24+
25+
26+ def test_import_dask_failure (monkeypatch ):
27+ import sys
28+
29+ import xbatcher .loaders
30+
31+ monkeypatch .setitem (sys .modules , 'dask' , None )
32+ reload (xbatcher .loaders .torch )
33+
34+ assert xbatcher .loaders .torch .dask is None
35+
36+
37+ @pytest .fixture (scope = 'module' , params = [True , False ])
38+ def ds_xy (request ):
1339 n_samples = 100
1440 n_features = 5
1541 ds = xr .Dataset (
@@ -21,17 +47,62 @@ def ds_xy():
2147 'y' : (['sample' ], np .random .random (n_samples )),
2248 },
2349 )
50+
51+ if request .param :
52+ ds = ds .chunk ({'sample' : 10 })
53+
2454 return ds
2555
2656
57+ @pytest .mark .parametrize ('x_var' , ['x' , ['x' ]])
58+ def test_map_dataset_without_y (ds_xy , x_var ) -> None :
59+ x = ds_xy [x_var ]
60+
61+ x_gen = BatchGenerator (x , {'sample' : 10 })
62+
63+ dataset = MapDataset (x_gen )
64+
65+ # test __getitem__
66+ x_batch = dataset [0 ]
67+ assert x_batch .shape == (10 , 5 ) # type: ignore[union-attr]
68+ assert isinstance (x_batch , torch .Tensor )
69+
70+ idx = torch .tensor ([0 ])
71+ x_batch = dataset [idx ]
72+ assert x_batch .shape == (10 , 5 )
73+ assert isinstance (x_batch , torch .Tensor )
74+
75+ with pytest .raises (NotImplementedError ):
76+ idx = torch .tensor ([0 , 1 ])
77+ x_batch = dataset [idx ]
78+
79+ # test __len__
80+ assert len (dataset ) == len (x_gen )
81+
82+ # test integration with torch DataLoader
83+ loader = torch .utils .data .DataLoader (dataset , batch_size = None )
84+
85+ for x_batch in loader :
86+ assert x_batch .shape == (10 , 5 ) # type: ignore[union-attr]
87+ assert isinstance (x_batch , torch .Tensor )
88+
89+ # Check that array shape of last item in generator is same as the batch image
90+ assert tuple (x_gen [- 1 ].sizes .values ()) == x_batch .shape # type: ignore[union-attr]
91+ # Check that array values from last item in generator and batch are the same
92+ gen_array = (
93+ x_gen [- 1 ].to_array ().squeeze () if hasattr (x_gen [- 1 ], 'to_array' ) else x_gen [- 1 ]
94+ )
95+ np .testing .assert_array_equal (gen_array , x_batch ) # type: ignore
96+
97+
2798@pytest .mark .parametrize (
2899 ('x_var' , 'y_var' ),
29100 [
30101 ('x' , 'y' ), # xr.DataArray
31102 (['x' ], ['y' ]), # xr.Dataset
32103 ],
33104)
34- def test_map_dataset (ds_xy , x_var , y_var ):
105+ def test_map_dataset (ds_xy , x_var , y_var ) -> None :
35106 x = ds_xy [x_var ]
36107 y = ds_xy [y_var ]
37108
@@ -73,7 +144,7 @@ def test_map_dataset(ds_xy, x_var, y_var):
73144 gen_array = (
74145 x_gen [- 1 ].to_array ().squeeze () if hasattr (x_gen [- 1 ], 'to_array' ) else x_gen [- 1 ]
75146 )
76- np .testing .assert_array_equal (gen_array , x_batch )
147+ np .testing .assert_array_equal (gen_array , x_batch ) # type: ignore
77148
78149
79150@pytest .mark .parametrize (
@@ -83,18 +154,18 @@ def test_map_dataset(ds_xy, x_var, y_var):
83154 (['x' ], ['y' ]), # xr.Dataset
84155 ],
85156)
86- def test_map_dataset_with_transform (ds_xy , x_var , y_var ):
157+ def test_map_dataset_with_transform (ds_xy , x_var , y_var ) -> None :
87158 x = ds_xy [x_var ]
88159 y = ds_xy [y_var ]
89160
90161 x_gen = BatchGenerator (x , {'sample' : 10 })
91162 y_gen = BatchGenerator (y , {'sample' : 10 })
92163
93164 def x_transform (batch ):
94- return batch * 0 + 1
165+ return to_tensor ( batch * 0 + 1 )
95166
96167 def y_transform (batch ):
97- return batch * 0 - 1
168+ return to_tensor ( batch * 0 - 1 )
98169
99170 dataset = MapDataset (
100171 x_gen , y_gen , transform = x_transform , target_transform = y_transform
0 commit comments