Skip to content

Commit 0525e69

Browse files
committed
implement overwrite properly; resolves #71
1 parent ba11db5 commit 0525e69

File tree

6 files changed

+201
-213
lines changed

6 files changed

+201
-213
lines changed

zarr/creation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from zarr.core import Array
1010
from zarr.storage import DirectoryStore, init_array, contains_array, \
11-
contains_group, default_compressor
11+
contains_group, default_compressor, _require_parent_group
1212
from zarr.codecs import codec_registry
1313

1414

zarr/hierarchy.py

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,10 @@ def create_group(self, name, overwrite=False):
436436
def _create_group_nosync(self, name, overwrite=False):
437437
path = self._item_path(name)
438438

439-
# create intermediate groups
440-
self._require_parent_group(path, overwrite=overwrite)
441-
442439
# create terminal group
443440
init_group(self._store, path=path, chunk_store=self._chunk_store,
444441
overwrite=overwrite)
442+
445443
return Group(self._store, path=path, read_only=self._read_only,
446444
chunk_store=self._chunk_store,
447445
synchronizer=self._synchronizer)
@@ -450,13 +448,15 @@ def create_groups(self, *names, **kwargs):
450448
"""Convenience method to create multiple groups in a single call."""
451449
return tuple(self.create_group(name, **kwargs) for name in names)
452450

453-
def require_group(self, name):
451+
def require_group(self, name, overwrite=False):
454452
"""Obtain a sub-group, creating one if it doesn't exist.
455453
456454
Parameters
457455
----------
458456
name : string
459457
Group name.
458+
overwrite : bool, optional
459+
Overwrite any existing array with given `name` if present.
460460
461461
Returns
462462
-------
@@ -473,20 +473,17 @@ def require_group(self, name):
473473
474474
"""
475475

476-
return self._write_op(self._require_group_nosync, name)
477-
478-
def _require_group_nosync(self, name):
476+
return self._write_op(self._require_group_nosync, name,
477+
overwrite=overwrite)
479478

479+
def _require_group_nosync(self, name, overwrite=False):
480480
path = self._item_path(name)
481481

482-
# require all intermediate groups
483-
segments = path.split('/')
484-
for i in range(len(segments) + 1):
485-
p = '/'.join(segments[:i])
486-
if contains_array(self._store, p):
487-
raise KeyError(name)
488-
elif not contains_group(self._store, p):
489-
init_group(self._store, path=p, chunk_store=self._chunk_store)
482+
# create terminal group if necessary
483+
if not contains_group(self._store, path):
484+
init_group(store=self._store, path=path,
485+
chunk_store=self._chunk_store,
486+
overwrite=overwrite)
490487

491488
return Group(self._store, path=path, read_only=self._read_only,
492489
chunk_store=self._chunk_store,
@@ -496,16 +493,6 @@ def require_groups(self, *names):
496493
"""Convenience method to require multiple groups in a single call."""
497494
return tuple(self.require_group(name) for name in names)
498495

499-
def _require_parent_group(self, path, overwrite=False):
500-
segments = path.split('/')
501-
for i in range(len(segments)):
502-
p = '/'.join(segments[:i])
503-
if contains_array(self._store, p):
504-
init_group(self._store, path=p,
505-
chunk_store=self._chunk_store, overwrite=overwrite)
506-
elif not contains_group(self._store, p):
507-
init_group(self._store, path=p, chunk_store=self._chunk_store)
508-
509496
def create_dataset(self, name, data=None, shape=None, chunks=None,
510497
dtype=None, compressor='default', fill_value=None,
511498
order='C', synchronizer=None, filters=None,
@@ -569,7 +556,6 @@ def _create_dataset_nosync(self, name, data=None, shape=None, chunks=None,
569556
filters=None, overwrite=False, **kwargs):
570557

571558
path = self._item_path(name)
572-
self._require_parent_group(path, overwrite=overwrite)
573559

574560
# determine synchronizer
575561
if synchronizer is None:
@@ -647,8 +633,6 @@ def create(self, name, **kwargs):
647633

648634
def _create_nosync(self, name, **kwargs):
649635
path = self._item_path(name)
650-
overwrite = kwargs.get('overwrite', False)
651-
self._require_parent_group(path, overwrite=overwrite)
652636
kwargs.setdefault('synchronizer', self._synchronizer)
653637
return create(store=self._store, path=path,
654638
chunk_store=self._chunk_store, **kwargs)
@@ -660,8 +644,6 @@ def empty(self, name, **kwargs):
660644

661645
def _empty_nosync(self, name, **kwargs):
662646
path = self._item_path(name)
663-
overwrite = kwargs.get('overwrite', False)
664-
self._require_parent_group(path, overwrite=overwrite)
665647
kwargs.setdefault('synchronizer', self._synchronizer)
666648
return empty(store=self._store, path=path,
667649
chunk_store=self._chunk_store, **kwargs)
@@ -673,8 +655,6 @@ def zeros(self, name, **kwargs):
673655

674656
def _zeros_nosync(self, name, **kwargs):
675657
path = self._item_path(name)
676-
overwrite = kwargs.get('overwrite', False)
677-
self._require_parent_group(path, overwrite=overwrite)
678658
kwargs.setdefault('synchronizer', self._synchronizer)
679659
return zeros(store=self._store, path=path,
680660
chunk_store=self._chunk_store, **kwargs)
@@ -686,8 +666,6 @@ def ones(self, name, **kwargs):
686666

687667
def _ones_nosync(self, name, **kwargs):
688668
path = self._item_path(name)
689-
overwrite = kwargs.get('overwrite', False)
690-
self._require_parent_group(path, overwrite=overwrite)
691669
kwargs.setdefault('synchronizer', self._synchronizer)
692670
return ones(store=self._store, path=path,
693671
chunk_store=self._chunk_store, **kwargs)
@@ -699,8 +677,6 @@ def full(self, name, fill_value, **kwargs):
699677

700678
def _full_nosync(self, name, fill_value, **kwargs):
701679
path = self._item_path(name)
702-
overwrite = kwargs.get('overwrite', False)
703-
self._require_parent_group(path, overwrite=overwrite)
704680
kwargs.setdefault('synchronizer', self._synchronizer)
705681
return full(store=self._store, path=path,
706682
chunk_store=self._chunk_store,
@@ -713,8 +689,6 @@ def array(self, name, data, **kwargs):
713689

714690
def _array_nosync(self, name, data, **kwargs):
715691
path = self._item_path(name)
716-
overwrite = kwargs.get('overwrite', False)
717-
self._require_parent_group(path, overwrite=overwrite)
718692
kwargs.setdefault('synchronizer', self._synchronizer)
719693
return array(data, store=self._store, path=path,
720694
chunk_store=self._chunk_store, **kwargs)
@@ -726,8 +700,6 @@ def empty_like(self, name, data, **kwargs):
726700

727701
def _empty_like_nosync(self, name, data, **kwargs):
728702
path = self._item_path(name)
729-
overwrite = kwargs.get('overwrite', False)
730-
self._require_parent_group(path, overwrite=overwrite)
731703
kwargs.setdefault('synchronizer', self._synchronizer)
732704
return empty_like(data, store=self._store, path=path,
733705
chunk_store=self._chunk_store, **kwargs)
@@ -739,8 +711,6 @@ def zeros_like(self, name, data, **kwargs):
739711

740712
def _zeros_like_nosync(self, name, data, **kwargs):
741713
path = self._item_path(name)
742-
overwrite = kwargs.get('overwrite', False)
743-
self._require_parent_group(path, overwrite=overwrite)
744714
kwargs.setdefault('synchronizer', self._synchronizer)
745715
return zeros_like(data, store=self._store, path=path,
746716
chunk_store=self._chunk_store, **kwargs)
@@ -752,8 +722,6 @@ def ones_like(self, name, data, **kwargs):
752722

753723
def _ones_like_nosync(self, name, data, **kwargs):
754724
path = self._item_path(name)
755-
overwrite = kwargs.get('overwrite', False)
756-
self._require_parent_group(path, overwrite=overwrite)
757725
kwargs.setdefault('synchronizer', self._synchronizer)
758726
return ones_like(data, store=self._store, path=path,
759727
chunk_store=self._chunk_store, **kwargs)
@@ -765,8 +733,6 @@ def full_like(self, name, data, **kwargs):
765733

766734
def _full_like_nosync(self, name, data, **kwargs):
767735
path = self._item_path(name)
768-
overwrite = kwargs.get('overwrite', False)
769-
self._require_parent_group(path, overwrite=overwrite)
770736
kwargs.setdefault('synchronizer', self._synchronizer)
771737
return full_like(data, store=self._store, path=path,
772738
chunk_store=self._chunk_store, **kwargs)
@@ -819,12 +785,8 @@ def group(store=None, overwrite=False, chunk_store=None, synchronizer=None):
819785
store = DictStore()
820786

821787
# require group
822-
if overwrite:
823-
init_group(store, overwrite=True, chunk_store=chunk_store)
824-
elif contains_array(store):
825-
raise ValueError('store contains an array')
826-
elif not contains_group(store):
827-
init_group(store, chunk_store=chunk_store)
788+
if overwrite or not contains_group(store):
789+
init_group(store, overwrite=overwrite, chunk_store=chunk_store)
828790

829791
return Group(store, read_only=False, chunk_store=chunk_store,
830792
synchronizer=synchronizer)
@@ -887,7 +849,7 @@ def open_group(path, mode='a', synchronizer=None):
887849
elif mode == 'a':
888850
if contains_array(store):
889851
raise ValueError('store contains array')
890-
elif not contains_group(store):
852+
if not contains_group(store):
891853
init_group(store)
892854

893855
elif mode in ['w-', 'x']:

zarr/storage.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,19 @@ def getsize(store, path=None):
122122
return -1
123123

124124

125+
def _require_parent_group(path, store, chunk_store, overwrite):
126+
path = normalize_storage_path(path)
127+
if path:
128+
segments = path.split('/')
129+
for i in range(len(segments)):
130+
p = '/'.join(segments[:i])
131+
if contains_array(store, p):
132+
_init_group_metadata(store, path=p, chunk_store=chunk_store,
133+
overwrite=overwrite)
134+
elif not contains_group(store, p):
135+
_init_group_metadata(store, path=p, chunk_store=chunk_store)
136+
137+
125138
def init_array(store, shape, chunks, dtype=None, compressor='default',
126139
fill_value=None, order='C', overwrite=False, path=None,
127140
chunk_store=None, filters=None):
@@ -195,11 +208,12 @@ def init_array(store, shape, chunks, dtype=None, compressor='default',
195208
196209
Initialize an array using a storage path::
197210
211+
>>> store = dict()
198212
>>> init_array(store, shape=100000000, chunks=1000000, dtype='i1',
199-
... path='foo/bar')
213+
... path='foo')
200214
>>> sorted(store.keys())
201-
['.zarray', '.zattrs', 'foo/bar/.zarray', 'foo/bar/.zattrs']
202-
>>> print(str(store['foo/bar/.zarray'], 'ascii'))
215+
['.zattrs', '.zgroup', 'foo/.zarray', 'foo/.zattrs']
216+
>>> print(str(store['foo/.zarray'], 'ascii'))
203217
{
204218
"chunks": [
205219
1000000
@@ -231,6 +245,10 @@ def init_array(store, shape, chunks, dtype=None, compressor='default',
231245
# normalize path
232246
path = normalize_storage_path(path)
233247

248+
# ensure parent group initialized
249+
_require_parent_group(path, store=store, chunk_store=chunk_store,
250+
overwrite=overwrite)
251+
234252
# guard conditions
235253
if overwrite:
236254
# attempt to delete any pre-existing items in store
@@ -304,7 +322,18 @@ def init_group(store, overwrite=False, path=None, chunk_store=None):
304322

305323
# normalize path
306324
path = normalize_storage_path(path)
307-
325+
326+
# ensure parent group initialized
327+
_require_parent_group(path, store=store, chunk_store=chunk_store,
328+
overwrite=overwrite)
329+
330+
# initialise metadata
331+
_init_group_metadata(store=store, overwrite=overwrite, path=path,
332+
chunk_store=chunk_store)
333+
334+
335+
def _init_group_metadata(store, overwrite=False, path=None, chunk_store=None):
336+
308337
# guard conditions
309338
if overwrite:
310339
# attempt to delete any pre-existing items in store

0 commit comments

Comments
 (0)