Skip to content

Commit fd8d96a

Browse files
committed
refactor tests; add overwrite option
1 parent f46bb1c commit fd8d96a

File tree

2 files changed

+172
-88
lines changed

2 files changed

+172
-88
lines changed

zarr/convenience.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
510510

511511

512512
def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
513-
**create_kws):
513+
overwrite=False, **create_kws):
514514
"""Copy the `source` array or group into the `dest` group.
515515
516516
Parameters
@@ -527,6 +527,8 @@ def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
527527
Do not copy user attributes.
528528
log : callable, file path or file-like object, optional
529529
If provided, will be used to log progress information.
530+
overwrite : bool, optional
531+
If True, replace any objects in the destination.
530532
**create_kws
531533
Passed through to the create_dataset method when copying an array/dataset.
532534
@@ -562,10 +564,10 @@ def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
562564
# setup logging
563565
with _LogWriter(log) as log:
564566
_copy(log, source, dest, name=name, root=True, shallow=shallow,
565-
without_attrs=without_attrs, **create_kws)
567+
without_attrs=without_attrs, overwrite=overwrite, **create_kws)
566568

567569

568-
def _copy(log, source, dest, name, root, shallow, without_attrs, **create_kws):
570+
def _copy(log, source, dest, name, root, shallow, without_attrs, overwrite, **create_kws):
569571

570572
# are we copying to/from h5py?
571573
source_h5py = source.__module__.startswith('h5py.')
@@ -581,6 +583,15 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, **create_kws):
581583
if hasattr(source, 'shape'):
582584
# copy a dataset/array
583585

586+
# check if already exists
587+
if name in dest:
588+
if overwrite:
589+
log('delete {}/{}'.format(dest.name, name))
590+
del dest[name]
591+
else:
592+
raise ValueError('object {!r} already exists in destination '
593+
'{!r}'.format(name, dest.name))
594+
584595
# setup creation keyword arguments
585596
kws = create_kws.copy()
586597

@@ -602,7 +613,7 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, **create_kws):
602613
# zarr -> h5py; use some vaguely sensible defaults
603614
kws.setdefault('compression', 'gzip')
604615
kws.setdefault('compression_opts', 1)
605-
kws.setdefault('shuffle', True)
616+
kws.setdefault('shuffle', False)
606617
else:
607618
# zarr -> zarr; preserve compression options by default
608619
kws.setdefault('compressor', source.compressor)
@@ -621,8 +632,17 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, **create_kws):
621632
elif root or not shallow:
622633
# copy a group
623634

624-
# creat new group in destination
625-
grp = dest.create_group(name)
635+
# check if an array is in the way
636+
if name in dest and hasattr(dest[name], 'shape'):
637+
if overwrite:
638+
log('delete {}/{}'.format(dest.name, name))
639+
del dest[name]
640+
else:
641+
raise ValueError('an array {!r} already exists in destination '
642+
'{!r}'.format(name, dest.name))
643+
644+
# require group in destination
645+
grp = dest.require_group(name)
626646
log('{} -> {}'.format(source.name, grp.name))
627647

628648
# copy attributes
@@ -632,7 +652,7 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, **create_kws):
632652
# recurse
633653
for k in source.keys():
634654
_copy(log, source[k], grp, name=k, root=False, shallow=shallow,
635-
without_attrs=without_attrs, **create_kws)
655+
without_attrs=without_attrs, overwrite=overwrite, **create_kws)
636656

637657

638658
def tree(grp, expand=False, level=None):
@@ -685,7 +705,8 @@ def tree(grp, expand=False, level=None):
685705
return TreeViewer(grp, expand=expand, level=level)
686706

687707

688-
def copy_all(source, dest, shallow=False, without_attrs=False, log=None, **create_kws):
708+
def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
709+
overwrite=False, **create_kws):
689710
"""Copy all children of the `source` group into the `dest` group.
690711
691712
Parameters
@@ -700,6 +721,8 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None, **creat
700721
Do not copy user attributes.
701722
log : callable, file path or file-like object, optional
702723
If provided, will be used to log progress information.
724+
overwrite : bool, optional
725+
If True, replace any objects in the destination.
703726
**create_kws
704727
Passed through to the create_dataset method when copying an array/dataset.
705728
@@ -738,4 +761,4 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None, **creat
738761
with _LogWriter(log) as log:
739762
for k in source.keys():
740763
_copy(log, source[k], dest, name=k, root=False, shallow=shallow,
741-
without_attrs=without_attrs, **create_kws)
764+
without_attrs=without_attrs, overwrite=overwrite, **create_kws)

zarr/tests/test_convenience.py

Lines changed: 140 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from nose.tools import assert_raises
99
import numpy as np
1010
from numpy.testing import assert_array_equal
11-
from numcodecs import Zlib
11+
from numcodecs import Zlib, Adler32
1212
import pytest
1313

1414

@@ -182,111 +182,175 @@ def test_copy_store():
182182
assert 'bar/qux' in dest
183183

184184

185+
def check_copied_array(original, copied, without_attrs=False, expect_props=None):
    """Assert that `copied` is a faithful copy of the array `original`.

    Parameters
    ----------
    original : zarr array or h5py dataset
        The source of the copy.
    copied : zarr array or h5py dataset
        The object created in the destination by the copy.
    without_attrs : bool, optional
        If True, assert that none of the user attributes of `original` are
        present on `copied`; otherwise assert that both have equal attributes.
    expect_props : dict, optional
        Property values expected on `copied`. Entries given here take
        precedence over the values that would otherwise be read from
        `original`. The dict passed in is not modified.

    Raises
    ------
    AssertionError
        If a property, the data or the attributes do not match.

    """

    # setup; an object is treated as h5py if its module is inside the h5py
    # package, anything else is treated as zarr
    source_h5py = original.__module__.startswith('h5py.')
    dest_h5py = copied.__module__.startswith('h5py.')
    zarr_to_zarr = not (source_h5py or dest_h5py)
    h5py_to_h5py = source_h5py and dest_h5py
    zarr_to_h5py = not source_h5py and dest_h5py
    h5py_to_zarr = source_h5py and not dest_h5py
    if expect_props is None:
        expect_props = dict()
    else:
        # work on a copy so the defaults added below do not leak to the caller
        expect_props = expect_props.copy()

    # common properties in zarr and h5py
    for p in 'dtype', 'shape', 'chunks':
        expect_props.setdefault(p, getattr(original, p))

    # zarr-specific properties
    if zarr_to_zarr:
        for p in 'compressor', 'filters', 'order', 'fill_value':
            expect_props.setdefault(p, getattr(original, p))

    # h5py-specific properties
    if h5py_to_h5py:
        for p in ('maxshape', 'compression', 'compression_opts', 'shuffle',
                  'scaleoffset', 'fletcher32', 'fillvalue'):
            expect_props.setdefault(p, getattr(original, p))

    # common properties with some name differences
    if h5py_to_zarr:
        expect_props.setdefault('fill_value', original.fillvalue)
    if zarr_to_h5py:
        expect_props.setdefault('fillvalue', original.fill_value)

    # compare properties, reporting which one differs on failure
    for k, v in expect_props.items():
        actual = getattr(copied, k)
        assert v == actual, \
            'property {!r}: expected {!r}, found {!r}'.format(k, v, actual)

    # compare data
    assert_array_equal(original[:], copied[:])

    # compare attrs
    if without_attrs:
        for k in original.attrs.keys():
            assert k not in copied.attrs, \
                'attribute {!r} should not have been copied'.format(k)
    else:
        assert sorted(original.attrs.items()) == sorted(copied.attrs.items())
233+
234+
235+
def check_copied_group(original, copied, without_attrs=False, expect_props=None,
                       shallow=False):
    """Assert that `copied` is a faithful copy of the group `original`.

    Children are told apart by the presence of a ``shape`` attribute: those
    that have one are checked as arrays, the others are checked recursively
    as groups.

    Parameters
    ----------
    original : group
        The source of the copy.
    copied : group
        The group created in the destination by the copy.
    without_attrs : bool, optional
        If True, assert that none of the user attributes of `original` are
        present on `copied`; otherwise assert that both have equal attributes.
        Applied to every array and sub-group that is visited.
    expect_props : dict, optional
        Passed through to `check_copied_array` for every array child. The
        dict passed in is not modified.
    shallow : bool, optional
        If True, assert that sub-groups of `original` are absent from `copied`
        instead of descending into them.

    Raises
    ------
    AssertionError
        If a child is missing or unexpectedly present, or if any array or
        attribute comparison fails.

    """

    # setup
    if expect_props is None:
        expect_props = dict()
    else:
        expect_props = expect_props.copy()

    # compare children
    for k, v in original.items():
        if hasattr(v, 'shape'):
            # arrays are copied in both deep and shallow mode
            assert k in copied, 'array {!r} missing from copy'.format(k)
            check_copied_array(v, copied[k], without_attrs=without_attrs,
                               expect_props=expect_props)
        elif shallow:
            assert k not in copied, \
                'group {!r} should not be present in a shallow copy'.format(k)
        else:
            assert k in copied, 'group {!r} missing from copy'.format(k)
            check_copied_group(v, copied[k], without_attrs=without_attrs,
                               shallow=shallow, expect_props=expect_props)

    # compare attrs
    if without_attrs:
        for k in original.attrs.keys():
            assert k not in copied.attrs, \
                'attribute {!r} should not have been copied'.format(k)
    else:
        assert sorted(original.attrs.items()) == sorted(copied.attrs.items())
263+
264+
185265
def _test_copy(new_source, new_dest):
186266

187267
source = new_source()
268+
dest = new_dest()
269+
# source_h5py = source.__module__.startswith('h5py.')
270+
dest_h5py = dest.__module__.startswith('h5py.')
271+
272+
# setup source
188273
foo = source.create_group('foo')
189274
foo.attrs['experiment'] = 'weird science'
190275
baz = foo.create_dataset('bar/baz', data=np.arange(100), chunks=(50,))
191276
baz.attrs['units'] = 'metres'
192-
spam = source.create_dataset('spam', data=np.arange(100, 200), chunks=(30,))
277+
source.create_dataset('spam', data=np.arange(100, 200), chunks=(30,))
193278

194279
# copy array with default options
195-
dest = new_dest()
196280
copy(source['foo/bar/baz'], dest)
197-
a = dest['baz'] # defaults to use source name
198-
assert a.dtype == baz.dtype
199-
assert a.shape == baz.shape
200-
assert a.chunks == baz.chunks
201-
if hasattr(a, 'compressor') and hasattr(baz, 'compressor'):
202-
assert a.compressor == baz.compressor
203-
assert_array_equal(a[:], baz[:])
204-
assert a.attrs['units'] == 'metres'
281+
check_copied_array(source['foo/bar/baz'], dest['baz'])
205282

206283
# copy array with name
207284
dest = new_dest()
208285
copy(source['foo/bar/baz'], dest, name='qux')
209286
assert 'baz' not in dest
210-
a = dest['qux']
211-
assert a.dtype == baz.dtype
212-
assert a.shape == baz.shape
213-
assert a.chunks == baz.chunks
214-
if hasattr(a, 'compressor') and hasattr(baz, 'compressor'):
215-
assert a.compressor == baz.compressor
216-
assert_array_equal(a[:], baz[:])
217-
assert a.attrs['units'] == 'metres'
287+
check_copied_array(source['foo/bar/baz'], dest['qux'])
218288

219289
# copy array, provide creation options
220290
dest = new_dest()
221291
compressor = Zlib(9)
222-
if isinstance(dest, Group):
223-
copy(source['foo/bar/baz'], dest, without_attrs=True, compressor=compressor,
224-
chunks=True)
225-
else:
226-
copy(source['foo/bar/baz'], dest, without_attrs=True, compression='gzip',
227-
compression_opts=9, chunks=True)
228-
a = dest['baz']
229-
assert a.dtype == baz.dtype
230-
assert a.shape == baz.shape
231-
assert a.chunks != baz.chunks # autochunking was requested
232-
if hasattr(a, 'compressor'):
233-
assert compressor == a.compressor
234-
if hasattr(baz, 'compressor'):
235-
assert a.compressor != baz.compressor
292+
create_kws = dict(chunks=(10,))
293+
if dest_h5py:
294+
create_kws.update(compression='gzip', compression_opts=9, shuffle=True,
295+
fletcher32=True, fillvalue=42)
236296
else:
237-
assert a.compression == 'gzip'
238-
assert a.compression_opts == 9
239-
assert_array_equal(a[:], baz[:])
240-
assert 'units' not in a.attrs
297+
create_kws.update(compressor=compressor, fill_value=42, order='F',
298+
filters=[Adler32()])
299+
copy(source['foo/bar/baz'], dest, without_attrs=True, **create_kws)
300+
check_copied_array(source['foo/bar/baz'], dest['baz'], without_attrs=True,
301+
expect_props=create_kws)
302+
303+
# copy array, dest array in the way
304+
dest = new_dest()
305+
dest.create_dataset('baz', shape=(10,))
306+
with pytest.raises(ValueError):
307+
copy(source['foo/bar/baz'], dest)
308+
assert (10,) == dest['baz'].shape
309+
copy(source['foo/bar/baz'], dest, overwrite=True)
310+
check_copied_array(source['foo/bar/baz'], dest['baz'])
311+
312+
# copy array, dest group in the way
313+
dest = new_dest()
314+
dest.create_group('baz')
315+
with pytest.raises(ValueError):
316+
copy(source['foo/bar/baz'], dest)
317+
assert not hasattr(dest['baz'], 'shape')
318+
copy(source['foo/bar/baz'], dest, overwrite=True)
319+
check_copied_array(source['foo/bar/baz'], dest['baz'])
241320

242321
# copy group, default options
243322
dest = new_dest()
244323
copy(source['foo'], dest)
245-
g = dest['foo'] # defaults to use source name
246-
assert g.attrs['experiment'] == 'weird science'
247-
a = g['bar/baz']
248-
assert a.dtype == baz.dtype
249-
assert a.shape == baz.shape
250-
assert a.chunks == baz.chunks
251-
if hasattr(a, 'compressor') and hasattr(baz, 'compressor'):
252-
assert a.compressor == baz.compressor
253-
assert_array_equal(a[:], baz[:])
254-
assert a.attrs['units'] == 'metres'
324+
check_copied_group(source['foo'], dest['foo'])
255325

256326
# copy group, non-default options
257327
dest = new_dest()
258328
copy(source['foo'], dest, name='qux', without_attrs=True)
259329
assert 'foo' not in dest
260-
g = dest['qux']
261-
assert 'experiment' not in g.attrs
262-
a = g['bar/baz']
263-
assert a.dtype == baz.dtype
264-
assert a.shape == baz.shape
265-
assert a.chunks == baz.chunks
266-
if hasattr(a, 'compressor') and hasattr(baz, 'compressor'):
267-
assert a.compressor == baz.compressor
268-
assert_array_equal(a[:], baz[:])
269-
assert 'units' not in a.attrs
330+
check_copied_group(source['foo'], dest['qux'], without_attrs=True)
270331

271332
# copy group, shallow
272333
dest = new_dest()
273334
copy(source, dest, name='eggs', shallow=True)
274-
assert 'eggs' in dest
275-
eggs = dest['eggs']
276-
assert 'spam' in eggs
277-
a = eggs['spam']
278-
assert a.dtype == spam.dtype
279-
assert a.shape == spam.shape
280-
assert a.chunks == spam.chunks
281-
if hasattr(a, 'compressor') and hasattr(spam, 'compressor'):
282-
assert a.compressor == spam.compressor
283-
assert_array_equal(a[:], spam[:])
284-
assert 'foo' not in eggs
285-
assert 'bar' not in eggs
286-
287-
288-
def test_copy_zarr_zarr():
289-
# zarr -> zarr
335+
check_copied_group(source, dest['eggs'], shallow=True)
336+
337+
# copy group, dest groups exist
338+
dest = new_dest()
339+
dest.create_group('foo/bar')
340+
copy(source['foo'], dest)
341+
check_copied_group(source['foo'], dest['foo'])
342+
343+
# copy group, dest array in the way
344+
dest = new_dest()
345+
dest.create_dataset('foo/bar', shape=(10,))
346+
with pytest.raises(ValueError):
347+
copy(source['foo'], dest)
348+
assert dest['foo/bar'].shape == (10,)
349+
copy(source['foo'], dest, overwrite=True)
350+
check_copied_group(source['foo'], dest['foo'])
351+
352+
353+
def test_copy_zarr_to_zarr():
290354
_test_copy(group, group)
291355

292356

@@ -305,18 +369,15 @@ def temp_h5f():
305369

306370

307371
@pytest.mark.skipif(not have_h5py, reason='h5py not installed')
308-
def test_copy_h5py_zarr():
309-
# h5py -> zarr
372+
def test_copy_h5py_to_zarr():
310373
_test_copy(temp_h5f, group)
311374

312375

313376
@pytest.mark.skipif(not have_h5py, reason='h5py not installed')
314-
def test_copy_zarr_h5py():
315-
# zarr -> h5py
377+
def test_copy_zarr_to_h5py():
316378
_test_copy(group, temp_h5f)
317379

318380

319381
@pytest.mark.skipif(not have_h5py, reason='h5py not installed')
320-
def test_copy_h5py_h5py():
321-
# zarr -> h5py
382+
def test_copy_h5py_to_h5py():
322383
_test_copy(temp_h5f, temp_h5f)

0 commit comments

Comments
 (0)