Skip to content

Commit 36309eb

Browse files
committed
resolves #2 by adding support for axis argument to resize; also fixes bug in >1D resize
1 parent 9796503 commit 36309eb

File tree

3 files changed

+128
-24
lines changed

3 files changed

+128
-24
lines changed

zarr/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def empty(shape, chunks, dtype=None, cname=None, clevel=None, shuffle=None,
4545

4646
def zeros(shape, chunks, dtype=None, cname=None, clevel=None, shuffle=None,
4747
synchronized=True):
48-
"""Create an array filled with zeros.
48+
"""Create an array, with zero being used as the default value for
49+
uninitialised portions of the array.
4950
5051
Parameters
5152
----------
@@ -80,7 +81,8 @@ def zeros(shape, chunks, dtype=None, cname=None, clevel=None, shuffle=None,
8081

8182
def ones(shape, chunks, dtype=None, cname=None, clevel=None, shuffle=None,
8283
synchronized=True):
83-
"""Create an array filled with ones.
84+
"""Create an array, with one being used as the default value for
85+
uninitialised portions of the array.
8486
8587
Parameters
8688
----------
@@ -116,7 +118,8 @@ def ones(shape, chunks, dtype=None, cname=None, clevel=None, shuffle=None,
116118

117119
def full(shape, chunks, fill_value, dtype=None, cname=None, clevel=None,
118120
shuffle=None, synchronized=True):
119-
"""Create an array filled with `fill_value`.
121+
"""Create an array, with `fill_value` being used as the default value for
122+
uninitialised portions of the array.
120123
121124
Parameters
122125
----------
@@ -174,6 +177,8 @@ def array(data, chunks=None, dtype=None, cname=None, clevel=None,
174177
synchronized : bool, optional
175178
If True, each chunk will be protected with a lock to prevent data
176179
collision during write operations.
180+
fill_value : object
181+
Default value to use for uninitialised portions of the array.
177182
178183
Returns
179184
-------

zarr/ext.pyx

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,33 @@ cdef class Array:
583583
return r
584584

585585
def resize(self, *args):
586-
"""TODO"""
586+
"""Resize the array by growing or shrinking one or more dimensions.
587+
588+
Parameters
589+
----------
590+
args : int or sequence of ints
591+
New shape to resize to.
592+
593+
Notes
594+
-----
595+
This function can only be used to change the size of existing
596+
dimensions, it cannot add or drop dimensions.
597+
598+
N.B., this function does *not* behave in the same way as the numpy
599+
resize method on the ndarray class. Existing data are *not*
600+
reorganised in any way. Axes are simply grown or shrunk. When
601+
growing an axis, uninitialised portions of the array will appear to
602+
contain the value of the `fill_value` attribute on this array,
603+
and when shrinking an array any data beyond the new shape will be
604+
lost (although see note below).
605+
606+
N.B., because of the way the underlying chunks are organised,
607+
and in particular the fact that chunks may overhang the edge of the
608+
array, the value of uninitialised portions of this array is not
609+
guaranteed to respect the setting of the `fill_value` attribute when
610+
shrinking then regrowing an array.
611+
612+
"""
587613

588614
# normalise new shape argument
589615
if len(args) == 1:
@@ -601,15 +627,24 @@ cdef class Array:
601627
new_shape = tuple(s if n is None else n
602628
for s, n in zip(self.shape, new_shape))
603629

604-
# set new shape
605-
self.shape = new_shape
630+
# work-around Cython problems with accessing .shape attribute as tuple
631+
old_cdata = np.asarray(self.cdata)
632+
633+
# remember old cdata shape
634+
old_cdata_shape = old_cdata.shape
606635

607636
# determine the new number and arrangement of chunks
608637
new_cdata_shape = tuple(int(np.ceil(s / c))
609638
for s, c in zip(new_shape, self.chunks))
610639

611-
# resize chunks array
612-
self.cdata.resize(new_cdata_shape, refcheck=False)
640+
# setup new chunks array
641+
new_cdata = np.empty(new_cdata_shape, dtype=object)
642+
643+
# copy across any chunks to be kept
644+
cdata_overlap = tuple(
645+
slice(min(o, n)) for o, n in zip(old_cdata_shape, new_cdata_shape)
646+
)
647+
new_cdata[cdata_overlap] = old_cdata[cdata_overlap]
613648

614649
# determine function for instantiating chunks
615650
if self.synchronized:
@@ -618,33 +653,64 @@ cdef class Array:
618653
create_chunk = Chunk
619654

620655
# instantiate any new chunks as needed
621-
self.cdata.flat = [create_chunk(self.chunks, dtype=self.dtype,
622-
cname=self.cname,
623-
clevel=self.clevel,
624-
shuffle=self.shuffle,
625-
fill_value=self.fill_value)
626-
if c == 0 else c
627-
for c in self.cdata.flat]
656+
new_cdata.flat = [create_chunk(self.chunks, dtype=self.dtype,
657+
cname=self.cname,
658+
clevel=self.clevel,
659+
shuffle=self.shuffle,
660+
fill_value=self.fill_value)
661+
if c is None else c
662+
for c in new_cdata.flat]
628663

629-
def append(self, data):
630-
"""TODO"""
664+
# set new shape
665+
self.shape = new_shape
666+
667+
# set new chunks
668+
self.cdata = new_cdata
669+
670+
def append(self, data, axis=0):
671+
"""Append `data` to `axis`.
672+
673+
Parameters
674+
----------
675+
data : array_like
676+
Data to be appended.
677+
axis : int
678+
Axis along which to append.
679+
680+
Notes
681+
-----
682+
The size of all dimensions other than `axis` must match between this
683+
array and `data`.
684+
685+
"""
631686

632687
# ensure data is array-like
633688
if not hasattr(data, 'shape') or not hasattr(data, 'dtype'):
634689
data = np.asanyarray(data)
635690

636-
# ensure shapes are compatible for trailing dimensions
637-
if self.shape[1:] != data.shape[1:]:
638-
raise ValueError('shape not compatible')
691+
# ensure shapes are compatible for non-append dimensions
692+
self_shape_preserved = tuple(s for i, s in enumerate(self.shape)
693+
if i != axis)
694+
data_shape_preserved = tuple(s for i, s in enumerate(data.shape)
695+
if i != axis)
696+
if self_shape_preserved != data_shape_preserved:
697+
raise ValueError('shapes not compatible')
639698

640699
# remember old shape
641700
old_shape = self.shape
642701

643702
# determine new shape
644-
new_shape = (self.shape[0] + data.shape[0],) + self.shape[1:]
703+
new_shape = tuple(
704+
self.shape[i] if i != axis else self.shape[i] + data.shape[i]
705+
for i in range(len(self.shape))
706+
)
645707

646708
# resize
647709
self.resize(new_shape)
648710

649711
# store data
650-
self[old_shape[0]:] = data
712+
append_selection = tuple(
713+
slice(None) if i != axis else slice(old_shape[i], new_shape[i])
714+
for i in range(len(self.shape))
715+
)
716+
self[append_selection] = data

zarr/tests/test_ext_array.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,57 +147,71 @@ def test_array_2d():
147147

148148
def test_resize_1d():
149149

150-
z = Array(105, chunks=10, dtype='i4')
150+
z = Array(105, chunks=10, dtype='i4', fill_value=0)
151+
a = np.arange(105, dtype='i4')
152+
z[:] = a
151153
eq((105,), z.shape)
152154
eq((105,), z[:].shape)
153155
eq(np.dtype('i4'), z.dtype)
154156
eq(np.dtype('i4'), z[:].dtype)
155157
eq((10,), z.chunks)
158+
assert_array_equal(a, z[:])
156159

157160
z.resize(205)
158161
eq((205,), z.shape)
159162
eq((205,), z[:].shape)
160163
eq(np.dtype('i4'), z.dtype)
161164
eq(np.dtype('i4'), z[:].dtype)
162165
eq((10,), z.chunks)
166+
assert_array_equal(a, z[:105])
167+
assert_array_equal(np.zeros(100, dtype='i4'), z[105:])
163168

164169
z.resize(55)
165170
eq((55,), z.shape)
166171
eq((55,), z[:].shape)
167172
eq(np.dtype('i4'), z.dtype)
168173
eq(np.dtype('i4'), z[:].dtype)
169174
eq((10,), z.chunks)
175+
assert_array_equal(a[:55], z[:55])
170176

171177

172178
def test_resize_2d():
173179

174-
z = Array((105, 105), chunks=(10, 10), dtype='i4')
180+
z = Array((105, 105), chunks=(10, 10), dtype='i4', fill_value=0)
181+
a = np.arange(105*105, dtype='i4').reshape((105, 105))
182+
z[:] = a
175183
eq((105, 105), z.shape)
176184
eq((105, 105), z[:].shape)
177185
eq(np.dtype('i4'), z.dtype)
178186
eq(np.dtype('i4'), z[:].dtype)
179187
eq((10, 10), z.chunks)
188+
assert_array_equal(a, z[:])
180189

181190
z.resize((205, 205))
182191
eq((205, 205), z.shape)
183192
eq((205, 205), z[:].shape)
184193
eq(np.dtype('i4'), z.dtype)
185194
eq(np.dtype('i4'), z[:].dtype)
186195
eq((10, 10), z.chunks)
196+
assert_array_equal(a, z[:105, :105])
197+
assert_array_equal(np.zeros((100, 205), dtype='i4'), z[105:, :])
198+
assert_array_equal(np.zeros((205, 100), dtype='i4'), z[:, 105:])
187199

188200
z.resize((55, 55))
189201
eq((55, 55), z.shape)
190202
eq((55, 55), z[:].shape)
191203
eq(np.dtype('i4'), z.dtype)
192204
eq(np.dtype('i4'), z[:].dtype)
193205
eq((10, 10), z.chunks)
206+
assert_array_equal(a[:55, :55], z[:])
194207

195208
z.resize((55, 1))
196209
eq((55, 1), z.shape)
197210
eq((55, 1), z[:].shape)
198211
eq(np.dtype('i4'), z.dtype)
199212
eq(np.dtype('i4'), z[:].dtype)
200213
eq((10, 10), z.chunks)
214+
assert_array_equal(a[:55, :1], z[:])
201215

202216

203217
def test_append_1d():
@@ -236,3 +250,22 @@ def test_append_2d():
236250
eq(e.dtype, z.dtype)
237251
eq((10, 10), z.chunks)
238252
assert_array_equal(e, z[:])
253+
254+
255+
def test_append_2d_axis():
256+
257+
a = np.arange(105*105, dtype='i4').reshape((105, 105))
258+
z = Array(a.shape, chunks=(10, 10), dtype=a.dtype)
259+
z[:] = a
260+
eq(a.shape, z.shape)
261+
eq(a.dtype, z.dtype)
262+
eq((10, 10), z.chunks)
263+
assert_array_equal(a, z[:])
264+
265+
b = np.arange(105*105, 2*105*105, dtype='i4').reshape((105, 105))
266+
e = np.append(a, b, axis=1)
267+
z.append(b, axis=1)
268+
eq(e.shape, z.shape)
269+
eq(e.dtype, z.dtype)
270+
eq((10, 10), z.chunks)
271+
assert_array_equal(e, z[:])

0 commit comments

Comments
 (0)