Skip to content

Commit 5483956

Browse files
committed
add creation from other zarr
1 parent b9699f5 commit 5483956

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

src/zarr/api/asynchronous.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@
1818
ChunkCoords,
1919
MemoryOrder,
2020
ZarrFormat,
21+
parse_dtype,
22+
concurrent_map,
2123
_warn_order_kwarg,
2224
_warn_write_empty_chunks_kwarg,
2325
parse_dtype,
@@ -551,6 +553,21 @@ async def array(
551553
The new array.
552554
"""
553555

556+
if isinstance(data, Array):
557+
chunks = kwargs.pop("chunks", None) or data.chunks
558+
new_array = await create(shape=data.shape, chunks=chunks, dtype=data.dtype, **kwargs)
559+
560+
async def _copy_chunk(chunk_coords: ChunkCoords) -> None:
561+
await new_array.setitem(chunk_coords, await data._async_array.getitem(chunk_coords))
562+
563+
# Stream data from the source array to the new array
564+
await concurrent_map(
565+
[(region,) for region in data._iter_chunk_regions()],
566+
_copy_chunk,
567+
config.get("async.concurrency"),
568+
)
569+
return new_array
570+
554571
# ensure data is array-like
555572
if not hasattr(data, "shape") or not hasattr(data, "dtype"):
556573
data = np.asanyarray(data)

src/zarr/api/synchronous.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ def tree(grp: Group, expand: bool | None = None, level: int | None = None) -> An
339339

340340

341341
# TODO: add type annotations for kwargs
342-
def array(data: npt.ArrayLike, **kwargs: Any) -> Array:
342+
def array(data: npt.ArrayLike | Array, **kwargs: Any) -> Array:
343343
"""Create an array filled with `data`.
344344
345345
Parameters

tests/test_array.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import json
33
import math
44
import pickle
5+
import time
56
from itertools import accumulate
67
from typing import Any, Literal
78

@@ -865,6 +866,52 @@ async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected:
865866
assert actual["fill_value"] == expected
866867

867868

869+
async def test_creation_from_other_zarr(tmpdir):
    """
    Test that ``zarr.array`` can build a new array from an existing zarr ``Array``.

    The source array is backed by a ``LocalStore`` and is split into several chunks.
    The array produced by ``zarr.array(src, ...)`` is compared against a reference
    copy made by allocating a fresh array and assigning ``src[:]`` into it.

    Parameters
    ----------
    tmpdir
        pytest fixture providing a temporary directory for the source store.
    """
    src = zarr.zeros(
        (2000, 20000), chunks=(1000, 1000), dtype="uint8", store=LocalStore(str(tmpdir))
    )
    src[:] = 1

    # Code path under test: create directly from another zarr array.
    copied = zarr.array(src, store=MemoryStore())

    # Reference copy: materialise the whole source in memory, then write it out.
    # The dtype is passed explicitly so the reference matches the source instead of
    # falling back to the default dtype of ``zarr.zeros``.
    reference = zarr.zeros(
        src.shape, chunks=src.chunks, dtype=src.dtype, store=MemoryStore()
    )
    reference[:] = src[:]

    # Shape, chunking and dtype are taken from the source when not overridden.
    assert copied.shape == src.shape
    assert copied.chunks == src.chunks
    assert copied.dtype == src.dtype

    # Compare every element rather than probing a single coordinate, so a chunk
    # that was skipped or misplaced during the copy is detected.
    assert np.array_equal(copied[:], reference[:])
    assert reference[123, 123] == 1
    assert copied[123, 123] == 1
888+
889+
890+
@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)])
@pytest.mark.parametrize("dtype", ["uint8", "float32"])
@pytest.mark.parametrize("array_type", ["async", "sync"])
async def test_nbytes(
    shape: tuple[int, ...], dtype: str, array_type: Literal["async", "sync"]
) -> None:
    """
    Check that ``nbytes`` on an Array or AsyncArray equals the number of elements
    multiplied by the size of one element, i.e. the capacity of the array's chunks.
    """
    arr = Array.create(store=MemoryStore(), shape=shape, dtype=dtype, fill_value=0)
    expected_nbytes = np.prod(arr.shape) * arr.dtype.itemsize
    # Pick the object whose ``nbytes`` attribute is under test for this parametrization.
    target = arr._async_array if array_type == "async" else arr
    assert target.nbytes == expected_nbytes
906+
907+
908+
async def test_scalar_array() -> None:
    """
    Check that a Python scalar passed to ``zarr.array`` yields a zero-dimensional
    array whose single value can be read back with ``...`` and ``()`` indexing.
    """
    scalar = 1.5
    arr = zarr.array(scalar)
    # A scalar input produces an array with an empty shape tuple.
    assert arr.shape == ()
    assert arr[()] == scalar
    assert arr[...] == scalar
913+
914+
868915
@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)])
869916
@pytest.mark.parametrize("dtype", ["uint8", "float32"])
870917
@pytest.mark.parametrize("array_type", ["async", "sync"])

0 commit comments

Comments
 (0)