Skip to content

Commit ba11db5

Browse files
committed
add overwrite kwarg; resolves #71
1 parent 96c34bb commit ba11db5

File tree

6 files changed

+103
-73
lines changed

6 files changed

+103
-73
lines changed

docs/release.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
Release notes
22
=============
33

4+
* Added ``overwrite`` keyword argument to array and group creation methods
5+
on the :class:`zarr.hierarchy.Group` class.
6+
47
.. _release_2.0.1:
58

69
2.0.1

zarr/creation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def create(shape, chunks=None, dtype=None, compressor='default',
3838
synchronizer : object, optional
3939
Array synchronizer.
4040
overwrite : bool, optional
41-
If True, delete all pre-existing data in `store` before creating the
42-
array.
41+
If True, delete all pre-existing data in `store` at `path` before
42+
creating the array.
4343
path : string, optional
4444
Path under which array is stored.
4545
chunk_store : MutableMapping, optional

zarr/hierarchy.py

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from zarr.attrs import Attributes
1010
from zarr.core import Array
1111
from zarr.storage import contains_array, contains_group, init_group, \
12-
DictStore, DirectoryStore, group_meta_key, attrs_key, listdir
12+
DictStore, DirectoryStore, group_meta_key, attrs_key, listdir, rmdir
1313
from zarr.creation import array, create, empty, zeros, ones, full, \
1414
empty_like, zeros_like, ones_like, full_like
1515
from zarr.util import normalize_storage_path, normalize_shape
@@ -406,13 +406,15 @@ def _write_op(self, f, *args, **kwargs):
406406
with self._synchronizer[group_meta_key]:
407407
return f(*args, **kwargs)
408408

409-
def create_group(self, name):
409+
def create_group(self, name, overwrite=False):
410410
"""Create a sub-group.
411411
412412
Parameters
413413
----------
414414
name : string
415415
Group name.
416+
overwrite : bool, optional
417+
If True, overwrite any existing array with the given name.
416418
417419
Returns
418420
-------
@@ -428,35 +430,25 @@ def create_group(self, name):
428430
429431
"""
430432

431-
return self._write_op(self._create_group_nosync, name)
432-
433-
def _create_group_nosync(self, name):
433+
return self._write_op(self._create_group_nosync, name,
434+
overwrite=overwrite)
434435

436+
def _create_group_nosync(self, name, overwrite=False):
435437
path = self._item_path(name)
436438

437-
# require intermediate groups
438-
segments = path.split('/')
439-
for i in range(len(segments)):
440-
p = '/'.join(segments[:i])
441-
if contains_array(self._store, p):
442-
raise KeyError(name)
443-
elif not contains_group(self._store, p):
444-
init_group(self._store, path=p, chunk_store=self._chunk_store)
439+
# create intermediate groups
440+
self._require_parent_group(path, overwrite=overwrite)
445441

446442
# create terminal group
447-
if contains_array(self._store, path):
448-
raise KeyError(name)
449-
if contains_group(self._store, path):
450-
raise KeyError(name)
451-
else:
452-
init_group(self._store, path=path, chunk_store=self._chunk_store)
453-
return Group(self._store, path=path, read_only=self._read_only,
454-
chunk_store=self._chunk_store,
455-
synchronizer=self._synchronizer)
443+
init_group(self._store, path=path, chunk_store=self._chunk_store,
444+
overwrite=overwrite)
445+
return Group(self._store, path=path, read_only=self._read_only,
446+
chunk_store=self._chunk_store,
447+
synchronizer=self._synchronizer)
456448

457-
def create_groups(self, *names):
449+
def create_groups(self, *names, **kwargs):
458450
"""Convenience method to create multiple groups in a single call."""
459-
return tuple(self.create_group(name) for name in names)
451+
return tuple(self.create_group(name, **kwargs) for name in names)
460452

461453
def require_group(self, name):
462454
"""Obtain a sub-group, creating one if it doesn't exist.
@@ -504,18 +496,20 @@ def require_groups(self, *names):
504496
"""Convenience method to require multiple groups in a single call."""
505497
return tuple(self.require_group(name) for name in names)
506498

507-
def _require_parent_group(self, path):
499+
def _require_parent_group(self, path, overwrite=False):
508500
segments = path.split('/')
509501
for i in range(len(segments)):
510502
p = '/'.join(segments[:i])
511503
if contains_array(self._store, p):
512-
raise KeyError(path)
504+
init_group(self._store, path=p,
505+
chunk_store=self._chunk_store, overwrite=overwrite)
513506
elif not contains_group(self._store, p):
514507
init_group(self._store, path=p, chunk_store=self._chunk_store)
515508

516509
def create_dataset(self, name, data=None, shape=None, chunks=None,
517510
dtype=None, compressor='default', fill_value=None,
518-
order='C', synchronizer=None, filters=None, **kwargs):
511+
order='C', synchronizer=None, filters=None,
512+
overwrite=False, **kwargs):
519513
"""Create an array.
520514
521515
Parameters
@@ -540,7 +534,10 @@ def create_dataset(self, name, data=None, shape=None, chunks=None,
540534
synchronizer : zarr.sync.ArraySynchronizer, optional
541535
Array synchronizer.
542536
filters : sequence of Codecs, optional
543-
Sequence of filters to use to encode chunk data prior to compression.
537+
Sequence of filters to use to encode chunk data prior to
538+
compression.
539+
overwrite : bool, optional
540+
If True, replace any existing array or group with the given name.
544541
545542
Returns
546543
-------
@@ -564,21 +561,15 @@ def create_dataset(self, name, data=None, shape=None, chunks=None,
564561
shape=shape, chunks=chunks, dtype=dtype,
565562
compressor=compressor, fill_value=fill_value,
566563
order=order, synchronizer=synchronizer,
567-
filters=filters, **kwargs)
564+
filters=filters, overwrite=overwrite, **kwargs)
568565

569566
def _create_dataset_nosync(self, name, data=None, shape=None, chunks=None,
570567
dtype=None, compressor='default',
571568
fill_value=None, order='C', synchronizer=None,
572-
filters=None, **kwargs):
569+
filters=None, overwrite=False, **kwargs):
573570

574571
path = self._item_path(name)
575-
self._require_parent_group(path)
576-
577-
# guard conditions
578-
if contains_array(self._store, path):
579-
raise KeyError(name)
580-
if contains_group(self._store, path):
581-
raise KeyError(name)
572+
self._require_parent_group(path, overwrite=overwrite)
582573

583574
# determine synchronizer
584575
if synchronizer is None:
@@ -590,15 +581,16 @@ def _create_dataset_nosync(self, name, data=None, shape=None, chunks=None,
590581
compressor=compressor, fill_value=fill_value,
591582
order=order, synchronizer=synchronizer,
592583
store=self._store, path=path,
593-
chunk_store=self._chunk_store, filters=filters, **kwargs)
584+
chunk_store=self._chunk_store, filters=filters,
585+
overwrite=overwrite, **kwargs)
594586

595587
else:
596588
a = create(shape=shape, chunks=chunks, dtype=dtype,
597589
compressor=compressor, fill_value=fill_value,
598590
order=order, synchronizer=synchronizer,
599591
store=self._store, path=path,
600592
chunk_store=self._chunk_store, filters=filters,
601-
**kwargs)
593+
overwrite=overwrite, **kwargs)
602594

603595
return a
604596

@@ -655,7 +647,8 @@ def create(self, name, **kwargs):
655647

656648
def _create_nosync(self, name, **kwargs):
657649
path = self._item_path(name)
658-
self._require_parent_group(path)
650+
overwrite = kwargs.get('overwrite', False)
651+
self._require_parent_group(path, overwrite=overwrite)
659652
kwargs.setdefault('synchronizer', self._synchronizer)
660653
return create(store=self._store, path=path,
661654
chunk_store=self._chunk_store, **kwargs)
@@ -667,7 +660,8 @@ def empty(self, name, **kwargs):
667660

668661
def _empty_nosync(self, name, **kwargs):
669662
path = self._item_path(name)
670-
self._require_parent_group(path)
663+
overwrite = kwargs.get('overwrite', False)
664+
self._require_parent_group(path, overwrite=overwrite)
671665
kwargs.setdefault('synchronizer', self._synchronizer)
672666
return empty(store=self._store, path=path,
673667
chunk_store=self._chunk_store, **kwargs)
@@ -679,7 +673,8 @@ def zeros(self, name, **kwargs):
679673

680674
def _zeros_nosync(self, name, **kwargs):
681675
path = self._item_path(name)
682-
self._require_parent_group(path)
676+
overwrite = kwargs.get('overwrite', False)
677+
self._require_parent_group(path, overwrite=overwrite)
683678
kwargs.setdefault('synchronizer', self._synchronizer)
684679
return zeros(store=self._store, path=path,
685680
chunk_store=self._chunk_store, **kwargs)
@@ -691,7 +686,8 @@ def ones(self, name, **kwargs):
691686

692687
def _ones_nosync(self, name, **kwargs):
693688
path = self._item_path(name)
694-
self._require_parent_group(path)
689+
overwrite = kwargs.get('overwrite', False)
690+
self._require_parent_group(path, overwrite=overwrite)
695691
kwargs.setdefault('synchronizer', self._synchronizer)
696692
return ones(store=self._store, path=path,
697693
chunk_store=self._chunk_store, **kwargs)
@@ -703,7 +699,8 @@ def full(self, name, fill_value, **kwargs):
703699

704700
def _full_nosync(self, name, fill_value, **kwargs):
705701
path = self._item_path(name)
706-
self._require_parent_group(path)
702+
overwrite = kwargs.get('overwrite', False)
703+
self._require_parent_group(path, overwrite=overwrite)
707704
kwargs.setdefault('synchronizer', self._synchronizer)
708705
return full(store=self._store, path=path,
709706
chunk_store=self._chunk_store,
@@ -716,7 +713,8 @@ def array(self, name, data, **kwargs):
716713

717714
def _array_nosync(self, name, data, **kwargs):
718715
path = self._item_path(name)
719-
self._require_parent_group(path)
716+
overwrite = kwargs.get('overwrite', False)
717+
self._require_parent_group(path, overwrite=overwrite)
720718
kwargs.setdefault('synchronizer', self._synchronizer)
721719
return array(data, store=self._store, path=path,
722720
chunk_store=self._chunk_store, **kwargs)
@@ -728,7 +726,8 @@ def empty_like(self, name, data, **kwargs):
728726

729727
def _empty_like_nosync(self, name, data, **kwargs):
730728
path = self._item_path(name)
731-
self._require_parent_group(path)
729+
overwrite = kwargs.get('overwrite', False)
730+
self._require_parent_group(path, overwrite=overwrite)
732731
kwargs.setdefault('synchronizer', self._synchronizer)
733732
return empty_like(data, store=self._store, path=path,
734733
chunk_store=self._chunk_store, **kwargs)
@@ -740,7 +739,8 @@ def zeros_like(self, name, data, **kwargs):
740739

741740
def _zeros_like_nosync(self, name, data, **kwargs):
742741
path = self._item_path(name)
743-
self._require_parent_group(path)
742+
overwrite = kwargs.get('overwrite', False)
743+
self._require_parent_group(path, overwrite=overwrite)
744744
kwargs.setdefault('synchronizer', self._synchronizer)
745745
return zeros_like(data, store=self._store, path=path,
746746
chunk_store=self._chunk_store, **kwargs)
@@ -752,7 +752,8 @@ def ones_like(self, name, data, **kwargs):
752752

753753
def _ones_like_nosync(self, name, data, **kwargs):
754754
path = self._item_path(name)
755-
self._require_parent_group(path)
755+
overwrite = kwargs.get('overwrite', False)
756+
self._require_parent_group(path, overwrite=overwrite)
756757
kwargs.setdefault('synchronizer', self._synchronizer)
757758
return ones_like(data, store=self._store, path=path,
758759
chunk_store=self._chunk_store, **kwargs)
@@ -764,7 +765,8 @@ def full_like(self, name, data, **kwargs):
764765

765766
def _full_like_nosync(self, name, data, **kwargs):
766767
path = self._item_path(name)
767-
self._require_parent_group(path)
768+
overwrite = kwargs.get('overwrite', False)
769+
self._require_parent_group(path, overwrite=overwrite)
768770
kwargs.setdefault('synchronizer', self._synchronizer)
769771
return full_like(data, store=self._store, path=path,
770772
chunk_store=self._chunk_store, **kwargs)

zarr/storage.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@ def init_array(store, shape, chunks, dtype=None, compressor='default',
238238
if chunk_store is not None and chunk_store != store:
239239
rmdir(chunk_store, path)
240240
elif contains_array(store, path):
241-
raise ValueError('store contains an array')
241+
raise KeyError('path %r contains an array' % path)
242242
elif contains_group(store, path):
243-
raise ValueError('store contains a group')
243+
raise KeyError('path %r contains a group' % path)
244244

245245
# normalize metadata
246246
shape = normalize_shape(shape)
@@ -312,9 +312,9 @@ def init_group(store, overwrite=False, path=None, chunk_store=None):
312312
if chunk_store is not None and chunk_store != store:
313313
rmdir(chunk_store, path)
314314
elif contains_array(store, path):
315-
raise ValueError('store contains an array')
315+
raise KeyError('path %r contains an array' % path)
316316
elif contains_group(store, path):
317-
raise ValueError('store contains a group')
317+
raise KeyError('path %r contains a group' % path)
318318

319319
# initialize metadata
320320
# N.B., currently no metadata properties are needed, however there may
@@ -492,7 +492,7 @@ def getsize(self, path=None):
492492
parent, key = self._get_parent(path)
493493
value = parent[key]
494494
except KeyError:
495-
raise ValueError('path not found: %r' % path)
495+
raise KeyError('path %r not found' % path)
496496
else:
497497
value = self.root
498498

@@ -673,7 +673,7 @@ def getsize(self, path=None):
673673
size += os.path.getsize(child_fs_path)
674674
return size
675675
else:
676-
raise ValueError('path not found: %r' % path)
676+
raise KeyError('path %r not found' % path)
677677

678678

679679
# noinspection PyPep8Naming
@@ -806,7 +806,7 @@ def getsize(self, path=None):
806806
info = zf.getinfo(path)
807807
return info.compress_size
808808
except KeyError:
809-
raise ValueError('path not found: %r' % path)
809+
raise KeyError('path %r not found' % path)
810810
else:
811811
return 0
812812

zarr/tests/test_hierarchy.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,31 @@ def test_create_errors(self):
342342
with assert_raises(PermissionError):
343343
g.require_dataset('zzz', shape=100, chunks=10)
344344

345+
def test_create_overwrite(self):
346+
try:
347+
for method_name in 'create_dataset', 'create', 'empty', 'zeros', \
348+
'ones':
349+
g = self.create_group()
350+
getattr(g, method_name)('foo', shape=100, chunks=10)
351+
352+
# overwrite array with array
353+
d = getattr(g, method_name)('foo', shape=200, chunks=20,
354+
overwrite=True)
355+
eq((200,), d.shape)
356+
# overwrite array with group
357+
g2 = g.create_group('foo', overwrite=True)
358+
eq(0, len(g2))
359+
# overwrite group with array
360+
d = getattr(g, method_name)('foo', shape=300, chunks=30,
361+
overwrite=True)
362+
eq((300,), d.shape)
363+
# overwrite array with group
364+
d = getattr(g, method_name)('foo/bar', shape=400, chunks=40,
365+
overwrite=True)
366+
assert_is_instance(g['foo'], Group)
367+
except NotImplementedError:
368+
pass
369+
345370
def test_getitem_contains_iterators(self):
346371
# setup
347372
g1 = self.create_group()

0 commit comments

Comments
 (0)