Skip to content

Commit 9f2fb20

Browse files
committed
rework copy_store to use if_exists/dry_run parameters
1 parent c51672f commit 9f2fb20

File tree

2 files changed

+125
-25
lines changed

2 files changed

+125
-25
lines changed

zarr/convenience.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from zarr.hierarchy import open_group, group as _create_group, Group
1616
from zarr.storage import contains_array, contains_group
1717
from zarr.errors import err_path_not_found
18-
from zarr.util import normalize_storage_path, TreeViewer
18+
from zarr.util import normalize_storage_path, TreeViewer, buffer_size
1919

2020

2121
# noinspection PyShadowingBuiltins
@@ -396,13 +396,14 @@ def __call__(self, *args, **kwargs):
396396

397397

398398
def copy_store(source, dest, source_path='', dest_path='', excludes=None,
399-
includes=None, flags=0, log=None):
399+
includes=None, flags=0, if_exists='raise', dry_run=False,
400+
log=None):
400401
"""Copy data directly from the `source` store to the `dest` store. Use this
401402
function when you want to copy a group or array in the most efficient way,
402403
preserving all configuration and attributes. This function is more efficient
403404
than the copy() or copy_all() functions because it avoids de-compressing and
404-
re-compressing data, rather the compressed chunk data for each array are copied
405-
directly between stores.
405+
re-compressing data, rather the compressed chunk data for each array are
406+
copied directly between stores.
406407
407408
Parameters
408409
----------
@@ -415,18 +416,27 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
415416
dest_path : str, optional
416417
Copy data into this path in the destination store.
417418
excludes : sequence of str, optional
418-
One or more regular expressions which will be matched against keys in the
419-
source store. Any matching key will not be copied.
419+
One or more regular expressions which will be matched against keys in
420+
the source store. Any matching key will not be copied.
420421
includes : sequence of str, optional
421-
One or more regular expressions which will be matched against keys in the
422-
source store and will override any excludes also matching.
422+
One or more regular expressions which will be matched against keys in
423+
the source store and will override any excludes also matching.
423424
flags : int, optional
424425
Regular expression flags used for matching excludes and includes.
426+
if_exists : {'raise', 'replace', 'skip'}, optional
427+
How to handle keys that already exist in the destination store. If
428+
'raise' then a ValueError is raised on the first key already present
429+
in the destination store. If 'replace' then any data will be replaced in
430+
the destination. If 'skip' then any existing keys will not be copied.
431+
dry_run : bool, optional
432+
If True, don't actually copy anything, just log what would have
433+
happened.
425434
log : callable, file path or file-like object, optional
426435
If provided, will be used to log progress information.
427436
428437
Examples
429438
--------
439+
430440
>>> import zarr
431441
>>> store1 = zarr.DirectoryStore('data/example.zarr')
432442
>>> root = zarr.group(store1, overwrite=True)
@@ -441,14 +451,15 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
441451
└── bar
442452
└── baz (100,) int64
443453
>>> import sys
444-
>>> store2 = zarr.ZipStore('data/example.zip', mode='w') # or any type of store
454+
>>> store2 = zarr.ZipStore('data/example.zip', mode='w')
445455
>>> zarr.copy_store(store1, store2, log=sys.stdout)
446-
.zgroup -> .zgroup
447-
foo/.zgroup -> foo/.zgroup
448-
foo/bar/.zgroup -> foo/bar/.zgroup
449-
foo/bar/baz/.zarray -> foo/bar/baz/.zarray
450-
foo/bar/baz/0 -> foo/bar/baz/0
451-
foo/bar/baz/1 -> foo/bar/baz/1
456+
copy .zgroup
457+
copy foo/.zgroup
458+
copy foo/bar/.zgroup
459+
copy foo/bar/baz/.zarray
460+
copy foo/bar/baz/0
461+
copy foo/bar/baz/1
462+
all done: 6 copy, 0 skip; 566 bytes copied
452463
>>> new_root = zarr.group(store2)
453464
>>> new_root.tree()
454465
/
@@ -481,6 +492,17 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
481492
excludes = [re.compile(e, flags) for e in excludes]
482493
includes = [re.compile(i, flags) for i in includes]
483494

495+
# check if_exists parameter
496+
valid_if_exists = ['raise', 'replace', 'skip']
497+
if if_exists not in valid_if_exists:
498+
raise ValueError('if_exists must be one of {!r}; found {!r}'
499+
.format(valid_if_exists, if_exists))
500+
501+
# setup counting variables
502+
n_copy = 0
503+
n_skip = 0
504+
n_bytes_copied = 0
505+
484506
# setup logging
485507
with _LogWriter(log) as log:
486508

@@ -508,9 +530,42 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
508530
key_suffix = source_key[len(source_path):]
509531
dest_key = dest_path + key_suffix
510532

511-
# retrieve and copy data
512-
log('{} -> {}'.format(source_key, dest_key))
513-
dest[dest_key] = source[source_key]
533+
# create a descriptive label for this operation
534+
descr = source_key
535+
if dest_key != source_key:
536+
descr = descr + ' -> ' + dest_key
537+
538+
# decide what to do
539+
do_copy = True
540+
if if_exists != 'replace':
541+
if dest_key in dest:
542+
if if_exists == 'raise':
543+
raise ValueError('key {!r} exists in destination'
544+
.format(dest_key))
545+
elif if_exists == 'skip':
546+
do_copy = False
547+
548+
# take action
549+
if do_copy:
550+
n_copy += 1
551+
log('copy {}'.format(descr))
552+
if not dry_run:
553+
data = source[source_key]
554+
n_bytes_copied += buffer_size(data)
555+
dest[dest_key] = data
556+
else:
557+
n_skip += 1
558+
log('skip {}'.format(descr))
559+
560+
# log a final message with a summary of what happened
561+
if dry_run:
562+
final_message = 'dry run: '
563+
else:
564+
final_message = 'all done: '
565+
final_message += '{} copy, {} skip'.format(n_copy, n_skip)
566+
if not dry_run:
567+
final_message += '; {:,} bytes copied'.format(n_bytes_copied)
568+
log(final_message)
514569

515570

516571
def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,

zarr/tests/test_convenience.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ def test_lazy_loader():
9090
assert_array_equal(bar, loader['bar'])
9191

9292

93-
def test_copy_store():
94-
95-
# no paths
93+
def test_copy_store_no_paths():
9694
source = dict()
9795
source['foo'] = b'xxx'
9896
source['bar'] = b'yyy'
@@ -102,7 +100,8 @@ def test_copy_store():
102100
for key in source:
103101
assert source[key] == dest[key]
104102

105-
# with source path
103+
104+
def test_copy_store_source_path():
106105
source = dict()
107106
source['foo'] = b'xxx'
108107
source['bar/baz'] = b'yyy'
@@ -119,7 +118,8 @@ def test_copy_store():
119118
else:
120119
assert key not in dest
121120

122-
# with dest path
121+
122+
def test_copy_store_dest_path():
123123
source = dict()
124124
source['foo'] = b'xxx'
125125
source['bar/baz'] = b'yyy'
@@ -133,7 +133,8 @@ def test_copy_store():
133133
dest_key = 'new/' + key
134134
assert source[key] == dest[dest_key]
135135

136-
# with source and dest path
136+
137+
def test_copy_store_source_dest_path():
137138
source = dict()
138139
source['foo'] = b'xxx'
139140
source['bar/baz'] = b'yyy'
@@ -152,7 +153,8 @@ def test_copy_store():
152153
assert key not in dest
153154
assert ('new/' + key) not in dest
154155

155-
# with excludes/includes
156+
157+
def test_copy_store_excludes_includes():
156158
source = dict()
157159
source['foo'] = b'xxx'
158160
source['bar/baz'] = b'yyy'
@@ -182,6 +184,49 @@ def test_copy_store():
182184
assert 'bar/qux' in dest
183185

184186

187+
def test_copy_store_dry_run():
188+
source = dict()
189+
source['foo'] = b'xxx'
190+
source['bar/baz'] = b'yyy'
191+
source['bar/qux'] = b'zzz'
192+
dest = dict()
193+
copy_store(source, dest, dry_run=True)
194+
assert 0 == len(dest)
195+
196+
197+
def test_copy_store_if_exists():
198+
199+
# setup
200+
source = dict()
201+
source['foo'] = b'xxx'
202+
source['bar/baz'] = b'yyy'
203+
source['bar/qux'] = b'zzz'
204+
dest = dict()
205+
dest['bar/baz'] = b'mmm'
206+
207+
# default ('raise')
208+
with pytest.raises(ValueError):
209+
copy_store(source, dest)
210+
211+
# explicit 'raise'
212+
with pytest.raises(ValueError):
213+
copy_store(source, dest, if_exists='raise')
214+
215+
# skip
216+
copy_store(source, dest, if_exists='skip')
217+
assert 3 == len(dest)
218+
assert dest['foo'] == b'xxx'
219+
assert dest['bar/baz'] == b'mmm'
220+
assert dest['bar/qux'] == b'zzz'
221+
222+
# replace
223+
copy_store(source, dest, if_exists='replace')
224+
assert 3 == len(dest)
225+
assert dest['foo'] == b'xxx'
226+
assert dest['bar/baz'] == b'yyy'
227+
assert dest['bar/qux'] == b'zzz'
228+
229+
185230
def check_copied_array(original, copied, without_attrs=False, expect_props=None):
186231

187232
# setup

0 commit comments

Comments
 (0)