Skip to content

Commit 52dddba

Browse files
committed
Checking args, added tests
1 parent d2c099b commit 52dddba

File tree

2 files changed

+71
-42
lines changed

2 files changed

+71
-42
lines changed

test/core/gridmapping/test_cfconv.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -204,24 +204,35 @@ def _gen_2d(self):
204204
class AddSpatialRefTest(S3Test):
205205

206206
def test_add_spatial_ref(self):
207-
original_dataset = new_cube(x_name='x',
208-
y_name='y',
209-
x_start=0,
210-
y_start=0,
211-
x_res=10,
212-
y_res=10,
213-
x_units='metres',
214-
y_units='metres',
215-
drop_bounds=True,
216-
width=100,
217-
height=100,
218-
variables=dict(A=1.3, B=8.3))
207+
self.assert_add_spatial_ref_ok(None, None)
208+
self.assert_add_spatial_ref_ok(None, ('cx', 'cy'))
209+
self.assert_add_spatial_ref_ok('crs', None)
210+
self.assert_add_spatial_ref_ok('crs', ('cx', 'cy'))
211+
212+
def assert_add_spatial_ref_ok(self, crs_var_name, xy_dim_names):
219213

220214
root = 'eurodatacube-test/xcube-eea'
221215
data_id = 'test.zarr'
222-
crs_var_name = 'spatial_ref'
223216
crs = pyproj.CRS.from_string("EPSG:3035")
224217

218+
if xy_dim_names:
219+
x_name, y_name = xy_dim_names
220+
else:
221+
x_name, y_name = 'x', 'y'
222+
223+
cube = new_cube(x_name=x_name,
224+
y_name=y_name,
225+
x_start=0,
226+
y_start=0,
227+
x_res=10,
228+
y_res=10,
229+
x_units='metres',
230+
y_units='metres',
231+
drop_bounds=True,
232+
width=100,
233+
height=100,
234+
variables=dict(A=1.3, B=8.3))
235+
225236
storage_options = dict(
226237
anon=False,
227238
client_kwargs=dict(
@@ -239,13 +250,13 @@ def test_add_spatial_ref(self):
239250
root=root,
240251
storage_options=storage_options)
241252

242-
data_store.write_data(original_dataset, data_id=data_id)
243-
opened_dataset = data_store.open_data(data_id)
244-
self.assertEqual({'A', 'B', 'x', 'y', 'time'},
245-
set(opened_dataset.variables))
253+
data_store.write_data(cube, data_id=data_id)
254+
cube = data_store.open_data(data_id)
255+
self.assertEqual({'A', 'B', 'time', x_name, y_name},
256+
set(cube.variables))
246257

247258
with self.assertRaises(ValueError) as cm:
248-
GridMapping.from_dataset(opened_dataset)
259+
GridMapping.from_dataset(cube)
249260
self.assertEqual(
250261
('cannot find any grid mapping in dataset',),
251262
cm.exception.args
@@ -254,21 +265,32 @@ def test_add_spatial_ref(self):
254265
path = f"{root}/{data_id}"
255266
group_store = fs.get_mapper(path, create=True)
256267

257-
add_spatial_ref(group_store,
258-
crs,
259-
crs_var_name=crs_var_name)
260-
261-
self.assertTrue(fs.exists(f"{path}/{crs_var_name}"))
262-
self.assertTrue(fs.exists(f"{path}/{crs_var_name}/.zarray"))
263-
self.assertTrue(fs.exists(f"{path}/{crs_var_name}/.zattrs"))
264-
265-
opened_dataset = data_store.open_data(data_id)
266-
self.assertEqual({'A', 'B', 'x', 'y', 'time', 'spatial_ref'},
267-
set(opened_dataset.variables))
268-
self.assertEqual(crs_var_name,
269-
opened_dataset.A.attrs.get('grid_mapping'))
270-
self.assertEqual(crs_var_name,
271-
opened_dataset.B.attrs.get('grid_mapping'))
272-
273-
gm = GridMapping.from_dataset(opened_dataset)
268+
expected_crs_var_name = crs_var_name or 'spatial_ref'
269+
270+
self.assertTrue(fs.exists(path))
271+
self.assertFalse(fs.exists(f"{path}/{expected_crs_var_name}"))
272+
self.assertFalse(fs.exists(f"{path}/{expected_crs_var_name}/.zarray"))
273+
self.assertFalse(fs.exists(f"{path}/{expected_crs_var_name}/.zattrs"))
274+
275+
kwargs = {}
276+
if crs_var_name is not None:
277+
kwargs.update(crs_var_name=crs_var_name)
278+
if xy_dim_names is not None:
279+
kwargs.update(xy_dim_names=xy_dim_names)
280+
add_spatial_ref(group_store, crs, **kwargs)
281+
282+
self.assertTrue(fs.exists(f"{path}/{expected_crs_var_name}"))
283+
self.assertTrue(fs.exists(f"{path}/{expected_crs_var_name}/.zarray"))
284+
self.assertTrue(fs.exists(f"{path}/{expected_crs_var_name}/.zattrs"))
285+
286+
cube = data_store.open_data(data_id)
287+
self.assertEqual({'A', 'B', 'time',
288+
x_name, y_name, expected_crs_var_name},
289+
set(cube.variables))
290+
self.assertEqual(expected_crs_var_name,
291+
cube.A.attrs.get('grid_mapping'))
292+
self.assertEqual(expected_crs_var_name,
293+
cube.B.attrs.get('grid_mapping'))
294+
295+
gm = GridMapping.from_dataset(cube)
274296
self.assertIn("LAEA Europe", gm.crs.srs)

xcube/core/gridmapping/cfconv.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# SOFTWARE.
2121

2222
import warnings
23+
from collections.abc import MutableMapping
2324
from typing import Optional, Dict, Any, Hashable, Union, Set, List, Tuple
2425

2526
import numpy as np
@@ -29,6 +30,7 @@
2930
import zarr.convenience
3031

3132
from xcube.core.schema import get_dataset_chunks
33+
from xcube.util.assertions import assert_instance
3234

3335

3436
class GridCoords:
@@ -290,36 +292,41 @@ def _find_dataset_tile_size(dataset: xr.Dataset,
290292
return None
291293

292294

293-
def add_spatial_ref(group_store: zarr.convenience.StoreLike,
295+
def add_spatial_ref(dataset_store: zarr.convenience.StoreLike,
294296
crs: pyproj.CRS,
295297
crs_var_name: Optional[str] = 'spatial_ref',
296298
xy_dim_names: Optional[Tuple[str, str]] = None):
297299
"""
298300
Helper function that allows adding a spatial reference to an
299301
existing Zarr dataset.
300302
301-
:param group_store: The dataset's existing Zarr store or path.
302-
:param crs: The coordinate reference system.
303+
:param dataset_store: The dataset's existing Zarr store or path.
304+
:param crs: The spatial coordinate reference system.
303305
:param crs_var_name: The name of the variable that will hold the
304306
spatial reference. Defaults to "spatial_ref".
305307
:param xy_dim_names: The names of the x and y dimensions.
306308
Defaults to ("x", "y").
307309
"""
310+
assert_instance(dataset_store, (MutableMapping, str), name='group_store')
311+
assert_instance(crs_var_name, str, name='crs_var_name')
312+
x_dim_name, y_dim_name = xy_dim_names or ('x', 'y')
313+
308314
spatial_attrs = crs.to_cf()
309315
spatial_attrs['_ARRAY_DIMENSIONS'] = [] # Required by xarray
310-
group = zarr.open(group_store, mode='r+')
316+
group = zarr.open(dataset_store, mode='r+')
311317
spatial_ref = group.array(crs_var_name,
312318
0,
313319
shape=(),
314320
dtype=np.uint8,
315321
fill_value=0)
316322
spatial_ref.attrs.update(**spatial_attrs)
317323

318-
yx_dim_names = list(reversed(xy_dim_names or ('x', 'y')))
319324
for item_name, item in group.items():
320325
if item_name != crs_var_name:
321326
dims = item.attrs.get('_ARRAY_DIMENSIONS')
322-
if dims and len(dims) >= 2 and dims[-2:] == yx_dim_names:
327+
if dims and len(dims) >= 2 \
328+
and dims[-2] == y_dim_name \
329+
and dims[-1] == x_dim_name:
323330
item.attrs['grid_mapping'] = crs_var_name
324331

325-
zarr.convenience.consolidate_metadata(group_store)
332+
zarr.convenience.consolidate_metadata(dataset_store)

0 commit comments

Comments
 (0)