Skip to content

Commit f777b90

Browse files
committed
rework copy to use if_exists/dry_run parameters
1 parent 9f2fb20 commit f777b90

File tree

2 files changed

+207
-84
lines changed

2 files changed

+207
-84
lines changed

zarr/convenience.py

Lines changed: 204 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import io
66
import re
77
import itertools
8-
9-
10-
import numpy as np
8+
import operator
119

1210

1311
from zarr.core import Array
@@ -569,7 +567,7 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
569567

570568

571569
def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
572-
overwrite=False, **create_kws):
570+
if_exists='raise', dry_run=False, **create_kws):
573571
"""Copy the `source` array or group into the `dest` group.
574572
575573
Parameters
@@ -586,8 +584,17 @@ def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
586584
Do not copy user attributes.
587585
log : callable, file path or file-like object, optional
588586
If provided, will be used to log progress information.
589-
overwrite : bool, optional
590-
If True, replace any objects in the destination.
587+
if_exists : {'raise', 'replace', 'skip', 'skip_initialized'}, optional
588+
How to handle arrays that already exist in the destination group. If
589+
'raise' then a ValueError is raised on the first array already present
590+
in the destination group. If 'replace' then any array will be
591+
replaced in the destination. If 'skip' then any existing arrays will
592+
not be copied. If 'skip_initialized' then any existing arrays with
593+
all chunks initialized will not be copied (not available when copying to
594+
h5py).
595+
dry_run : bool, optional
596+
If True, don't actually copy anything, just log what would have
597+
happened.
591598
**create_kws
592599
Passed through to the create_dataset method when copying an array/dataset.
593600
@@ -609,9 +616,10 @@ def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
609616
>>> dest = zarr.group()
610617
>>> import sys
611618
>>> zarr.copy(source['foo'], dest, log=sys.stdout)
612-
/foo -> /foo
613-
/foo/bar -> /foo/bar
614-
/foo/bar/baz -> /foo/bar/baz
619+
copy /foo
620+
copy /foo/bar
621+
copy /foo/bar/baz (100,) int64
622+
all done: 3 copy, 0 skip; 800 bytes copied
615623
>>> dest.tree() # N.B., no spam
616624
/
617625
└── foo
@@ -622,16 +630,44 @@ def copy(source, dest, name=None, shallow=False, without_attrs=False, log=None,
622630

623631
# setup logging
624632
with _LogWriter(log) as log:
625-
_copy(log, source, dest, name=name, root=True, shallow=shallow,
626-
without_attrs=without_attrs, overwrite=overwrite, **create_kws)
627633

634+
# do the copying
635+
n_copy, n_skip, n_bytes_copied = _copy(
636+
log, source, dest, name=name, root=True, shallow=shallow,
637+
without_attrs=without_attrs, if_exists=if_exists, dry_run=dry_run,
638+
**create_kws
639+
)
628640

629-
def _copy(log, source, dest, name, root, shallow, without_attrs, overwrite, **create_kws):
641+
# log a final message with a summary of what was done
642+
if dry_run:
643+
final_message = 'dry run: '
644+
else:
645+
final_message = 'all done: '
646+
final_message += '{} copy, {} skip'.format(n_copy, n_skip)
647+
if not dry_run:
648+
final_message += '; {:,} bytes copied'.format(n_bytes_copied)
649+
log(final_message)
650+
651+
652+
def _copy(log, source, dest, name, root, shallow, without_attrs, if_exists,
653+
dry_run, **create_kws):
654+
655+
# setup counting variables
656+
n_copy = n_skip = n_bytes_copied = 0
630657

631658
# are we copying to/from h5py?
632659
source_h5py = source.__module__.startswith('h5py.')
633660
dest_h5py = dest.__module__.startswith('h5py.')
634661

662+
# check if_exists parameter
663+
valid_if_exists = ['raise', 'replace', 'skip', 'skip_initialized']
664+
if if_exists not in valid_if_exists:
665+
raise ValueError('if_exists must be one of {!r}; found {!r}'
666+
.format(valid_if_exists, if_exists))
667+
if dest_h5py and if_exists == 'skip_initialized':
668+
raise ValueError('{!r} can only be used then copying to zarr'
669+
.format(if_exists))
670+
635671
# determine name to copy to
636672
if name is None:
637673
name = source.name.split('/')[-1]
@@ -642,82 +678,139 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, overwrite, **cr
642678
if hasattr(source, 'shape'):
643679
# copy a dataset/array
644680

645-
# check if already exists
646-
if name in dest:
647-
if overwrite:
648-
log('delete {}/{}'.format(dest.name, name))
649-
del dest[name]
650-
else:
681+
# check if already exists, decide what to do
682+
do_copy = True
683+
exists = name in dest
684+
if exists:
685+
if if_exists == 'raise':
651686
raise ValueError('an object {!r} already exists in destination '
652687
'{!r}'.format(name, dest.name))
688+
elif if_exists == 'skip':
689+
do_copy = False
690+
elif if_exists == 'skip_initialized':
691+
ds = dest[name]
692+
if ds.n_initialized == ds.n_chunks:
693+
do_copy = False
694+
695+
# log a message about what we're going to do
696+
if do_copy:
697+
n_copy += 1
698+
message = 'copy {} {} {}'.format(source.name, source.shape,
699+
source.dtype)
700+
else:
701+
n_skip += 1
702+
message = 'skip {} {} {}'.format(source.name, source.shape,
703+
source.dtype)
704+
log(message)
653705

654-
# setup creation keyword arguments
655-
kws = create_kws.copy()
656-
657-
# setup chunks option, preserve by default
658-
kws.setdefault('chunks', source.chunks)
706+
# take action
707+
if do_copy and not dry_run:
659708

660-
# setup compression options
661-
if source_h5py:
662-
if dest_h5py:
663-
# h5py -> h5py; preserve compression options by default
664-
kws.setdefault('compression', source.compression)
665-
kws.setdefault('compression_opts', source.compression_opts)
666-
kws.setdefault('shuffle', source.shuffle)
667-
else:
668-
# h5py -> zarr; use zarr default compression options
669-
pass
670-
else:
671-
if dest_h5py:
672-
# zarr -> h5py; use some vaguely sensible defaults
673-
kws.setdefault('chunks', True)
674-
kws.setdefault('compression', 'gzip')
675-
kws.setdefault('compression_opts', 1)
676-
kws.setdefault('shuffle', False)
677-
else:
678-
# zarr -> zarr; preserve compression options by default
679-
kws.setdefault('compressor', source.compressor)
709+
# clear the way
710+
if exists:
711+
del dest[name]
680712

681-
# create new dataset in destination
682-
ds = dest.create_dataset(name, shape=source.shape, dtype=source.dtype, **kws)
713+
# setup creation keyword arguments
714+
kws = create_kws.copy()
683715

684-
# copy data - N.B., go chunk by chunk to avoid loading everything into memory
685-
log('{} -> {}'.format(source.name, ds.name))
686-
shape = ds.shape
687-
chunks = ds.chunks
688-
chunk_offsets = [range(0, s, c) for s, c in zip(shape, chunks)]
689-
for offset in itertools.product(*chunk_offsets):
690-
sel = tuple(slice(o, min(s, o + c)) for o, s, c in zip(offset, shape, chunks))
691-
ds[sel] = source[sel]
716+
# setup chunks option, preserve by default
717+
kws.setdefault('chunks', source.chunks)
692718

693-
# copy attributes
694-
if not without_attrs:
695-
ds.attrs.update(source.attrs)
719+
# setup compression options
720+
if source_h5py:
721+
if dest_h5py:
722+
# h5py -> h5py; preserve compression options by default
723+
kws.setdefault('compression', source.compression)
724+
kws.setdefault('compression_opts', source.compression_opts)
725+
kws.setdefault('shuffle', source.shuffle)
726+
else:
727+
# h5py -> zarr; use zarr default compression options
728+
pass
729+
else:
730+
if dest_h5py:
731+
# zarr -> h5py; use some vaguely sensible defaults
732+
kws.setdefault('chunks', True)
733+
kws.setdefault('compression', 'gzip')
734+
kws.setdefault('compression_opts', 1)
735+
kws.setdefault('shuffle', False)
736+
else:
737+
# zarr -> zarr; preserve compression options by default
738+
kws.setdefault('compressor', source.compressor)
739+
740+
# create new dataset in destination
741+
ds = dest.create_dataset(name, shape=source.shape,
742+
dtype=source.dtype, **kws)
743+
744+
# copy data - N.B., go chunk by chunk to avoid loading everything
745+
# into memory
746+
shape = ds.shape
747+
chunks = ds.chunks
748+
chunk_offsets = [range(0, s, c) for s, c in zip(shape, chunks)]
749+
for offset in itertools.product(*chunk_offsets):
750+
sel = tuple(slice(o, min(s, o + c))
751+
for o, s, c in zip(offset, shape, chunks))
752+
ds[sel] = source[sel]
753+
n_bytes_copied += ds.size * ds.dtype.itemsize
754+
755+
# copy attributes
756+
if not without_attrs:
757+
ds.attrs.update(source.attrs)
696758

697759
elif root or not shallow:
698760
# copy a group
699761

700762
# check if an array is in the way
701-
if name in dest and hasattr(dest[name], 'shape'):
702-
if overwrite:
703-
log('delete {}/{}'.format(dest.name, name))
704-
del dest[name]
705-
else:
763+
exists_array = name in dest and hasattr(dest[name], 'shape')
764+
if exists_array:
765+
if if_exists == 'raise':
706766
raise ValueError('an array {!r} already exists in destination '
707767
'{!r}'.format(name, dest.name))
768+
elif if_exists == 'skip':
769+
n_skip += 1
770+
log('skip {}'.format(source.name))
771+
return
708772

709-
# require group in destination
710-
grp = dest.require_group(name)
711-
log('{} -> {}'.format(source.name, grp.name))
773+
# log action
774+
n_copy += 1
775+
log('copy {}'.format(source.name))
712776

713-
# copy attributes
714-
if not without_attrs:
715-
grp.attrs.update(source.attrs)
777+
if not dry_run:
716778

717-
# recurse
718-
for k in source.keys():
719-
_copy(log, source[k], grp, name=k, root=False, shallow=shallow,
720-
without_attrs=without_attrs, overwrite=overwrite, **create_kws)
779+
# clear the way
780+
if exists_array:
781+
del dest[name]
782+
783+
# require group in destination
784+
grp = dest.require_group(name)
785+
786+
# copy attributes
787+
if not without_attrs:
788+
grp.attrs.update(source.attrs)
789+
790+
# recurse
791+
for k in source.keys():
792+
c, s, b = _copy(
793+
log, source[k], grp, name=k, root=False, shallow=shallow,
794+
without_attrs=without_attrs, if_exists=if_exists,
795+
dry_run=dry_run, **create_kws)
796+
n_copy += c
797+
n_skip += s
798+
n_bytes_copied += b
799+
800+
elif name in dest:
801+
# dry run
802+
grp = dest[name]
803+
# recurse
804+
for k in source.keys():
805+
c, s, b = _copy(
806+
log, source[k], grp, name=k, root=False, shallow=shallow,
807+
without_attrs=without_attrs, if_exists=if_exists,
808+
dry_run=dry_run, **create_kws)
809+
n_copy += c
810+
n_skip += s
811+
n_bytes_copied += b
812+
813+
return n_copy, n_skip, n_bytes_copied
721814

722815

723816
def tree(grp, expand=False, level=None):
@@ -771,7 +864,7 @@ def tree(grp, expand=False, level=None):
771864

772865

773866
def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
774-
overwrite=False, **create_kws):
867+
if_exists='raise', dry_run=False, **create_kws):
775868
"""Copy all children of the `source` group into the `dest` group.
776869
777870
Parameters
@@ -786,10 +879,20 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
786879
Do not copy user attributes.
787880
log : callable, file path or file-like object, optional
788881
If provided, will be used to log progress information.
789-
overwrite : bool, optional
790-
If True, replace any objects in the destination.
882+
if_exists : {'raise', 'replace', 'skip', 'skip_initialized'}, optional
883+
How to handle arrays that already exist in the destination group. If
884+
'raise' then a ValueError is raised on the first array already present
885+
in the destination group. If 'replace' then any array will be
886+
replaced in the destination. If 'skip' then any existing arrays will
887+
not be copied. If 'skip_initialized' then any existing arrays with
888+
all chunks initialized will not be copied (not available when copying to
889+
h5py).
890+
dry_run : bool, optional
891+
If True, don't actually copy anything, just log what would have
892+
happened.
791893
**create_kws
792-
Passed through to the create_dataset method when copying an array/dataset.
894+
Passed through to the create_dataset method when copying an
895+
array/dataset.
793896
794897
Examples
795898
--------
@@ -809,10 +912,11 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
809912
>>> dest = zarr.group()
810913
>>> import sys
811914
>>> zarr.copy_all(source, dest, log=sys.stdout)
812-
/foo -> /foo
813-
/foo/bar -> /foo/bar
814-
/foo/bar/baz -> /foo/bar/baz
815-
/spam -> /spam
915+
copy /foo
916+
copy /foo/bar
917+
copy /foo/bar/baz (100,) int64
918+
copy /spam (100,) int64
919+
all done: 4 copy, 0 skip; 1,600 bytes copied
816920
>>> dest.tree()
817921
/
818922
├── foo
@@ -822,8 +926,27 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
822926
823927
"""
824928

929+
# setup counting variables
930+
n_copy = n_skip = n_bytes_copied = 0
931+
825932
# setup logging
826933
with _LogWriter(log) as log:
934+
827935
for k in source.keys():
828-
_copy(log, source[k], dest, name=k, root=False, shallow=shallow,
829-
without_attrs=without_attrs, overwrite=overwrite, **create_kws)
936+
c, s, b = _copy(
937+
log, source[k], dest, name=k, root=False, shallow=shallow,
938+
without_attrs=without_attrs, if_exists=if_exists,
939+
dry_run=dry_run, **create_kws)
940+
n_copy += c
941+
n_skip += s
942+
n_bytes_copied += b
943+
944+
# log a final message with a summary of what was done
945+
if dry_run:
946+
final_message = 'dry run: '
947+
else:
948+
final_message = 'all done: '
949+
final_message += '{} copy, {} skip'.format(n_copy, n_skip)
950+
if not dry_run:
951+
final_message += '; {:,} bytes copied'.format(n_bytes_copied)
952+
log(final_message)

0 commit comments

Comments
 (0)