
Commit 021ca95

add from_array with npt.ArrayLike
1 parent 97b7b9b commit 021ca95

3 files changed: 71 additions, 38 deletions

src/zarr/api/synchronous.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -895,7 +895,7 @@ def create_array(
 
 
 def from_array(
-    data: Array,
+    data: Array | npt.ArrayLike,
    store: str | StoreLike,
    *,
    name: str | None = None,
```
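
With this change, `from_array` accepts any NumPy-compatible array-like (`npt.ArrayLike`) in addition to a `zarr.Array`. A minimal usage sketch, mirroring the call in the new test below; the store path here is hypothetical:

```python
import numpy as np
import zarr

# A plain NumPy array can now be passed where previously only a
# zarr.Array was accepted.
src = np.arange(1000).reshape(10, 10, 10)
result = zarr.from_array(src, store="data/example.zarr", chunks=(10, 2, 3))
assert result.shape == (10, 10, 10)
```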

src/zarr/core/array.py

Lines changed: 60 additions & 37 deletions

```diff
@@ -3851,37 +3851,48 @@ async def from_array(
     #TODO
     """
 
-    if chunks == "keep":
-        chunks = data.chunks
-    if zarr_format is None:
-        zarr_format = data.metadata.zarr_format
-    if filters == "keep":
-        if zarr_format == data.metadata.zarr_format:
-            filters = data.filters or None
-        else:
+    if isinstance(data, Array):
+        if chunks == "keep":
+            chunks = data.chunks
+        if zarr_format is None:
+            zarr_format = data.metadata.zarr_format
+        if filters == "keep":
+            if zarr_format == data.metadata.zarr_format:
+                filters = data.filters or None
+            else:
+                filters = "auto"
+        if compressors == "keep":
+            if zarr_format == data.metadata.zarr_format:
+                compressors = data.compressors or None
+            else:
+                compressors = "auto"
+        if serializer == "keep":
+            if zarr_format == 3:
+                serializer = cast(SerializerLike, data.serializer)
+            else:
+                serializer = "auto"
+        if fill_value is None:
+            fill_value = data.fill_value
+        if order is None:
+            order = data.order
+        if chunk_key_encoding is None and zarr_format == data.metadata.zarr_format:
+            if isinstance(data.metadata, ArrayV2Metadata):
+                chunk_key_encoding = {"name": "v2", "separator": data.metadata.dimension_separator}
+            elif isinstance(data.metadata, ArrayV3Metadata):
+                chunk_key_encoding = data.metadata.chunk_key_encoding
+        if dimension_names is None and data.metadata.zarr_format == 3:
+            dimension_names = data.metadata.dimension_names
+    else:
+        if chunks == "keep":
+            chunks = "auto"
+        if zarr_format is None:
+            zarr_format = 3
+        if filters == "keep":
             filters = "auto"
-    if compressors == "keep":
-        if zarr_format == data.metadata.zarr_format:
-            compressors = data.compressors or None
-        else:
+        if compressors == "keep":
             compressors = "auto"
-    if serializer == "keep":
-        if zarr_format == 3:
-            serializer = cast(SerializerLike, data.serializer)
-        else:
+        if serializer == "keep":
             serializer = "auto"
-    if fill_value is None:
-        fill_value = data.fill_value
-    if order is None:
-        order = data.order
-    if chunk_key_encoding is None and zarr_format == data.metadata.zarr_format:
-        if isinstance(data.metadata, ArrayV2Metadata):
-            chunk_key_encoding = {"name": "v2", "separator": data.metadata.dimension_separator}
-        elif isinstance(data.metadata, ArrayV3Metadata):
-            chunk_key_encoding = data.metadata.chunk_key_encoding
-    if dimension_names is None and data.metadata.zarr_format == 3:
-        dimension_names = data.metadata.dimension_names
-
     new_array = await create_array(
         store,
         name=name,
@@ -3902,17 +3913,29 @@ async def from_array(
         overwrite=overwrite,
         config=config,
     )
+    if isinstance(data, Array):
 
-    async def _copy_region(chunk_coords: ChunkCoords | slice, _data: Array) -> None:
-        arr = await _data._async_array.getitem(chunk_coords)
-        await new_array.setitem(chunk_coords, arr)
+        async def _copy_region(chunk_coords: ChunkCoords | slice, _data: Array) -> None:
+            arr = await _data._async_array.getitem(chunk_coords)
+            await new_array.setitem(chunk_coords, arr)
 
-    # Stream data from the source array to the new array
-    await concurrent_map(
-        [(region, data) for region in new_array._iter_chunk_regions()],
-        _copy_region,
-        zarr.core.config.config.get("async.concurrency"),
-    )
+        # Stream data from the source array to the new array
+        await concurrent_map(
+            [(region, data) for region in new_array._iter_chunk_regions()],
+            _copy_region,
+            zarr.core.config.config.get("async.concurrency"),
+        )
+    else:
+
+        async def _copy_region(chunk_coords: ChunkCoords | slice, _data: npt.ArrayLike) -> None:
+            await new_array.setitem(chunk_coords, _data[chunk_coords])
+
+        # Stream data from the source array to the new array
+        await concurrent_map(
+            [(region, data) for region in new_array._iter_chunk_regions()],
+            _copy_region,
+            zarr.core.config.config.get("async.concurrency"),
+        )
     return new_array
```
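The new `isinstance(data, Array)` branch means the "keep" defaults only inherit from a genuine `zarr.Array` source; a plain array-like carries no zarr metadata, so `chunks`, `filters`, `compressors`, and `serializer` fall back to "auto" and `zarr_format` defaults to 3. A sketch of the observable behavior, assuming `zarr.create_array` and `MemoryStore` as in the zarr-python 3 API:

```python
import numpy as np
import zarr
from zarr.storage import MemoryStore

# Array source: chunks="keep" inherits the source chunk grid.
src = zarr.create_array(MemoryStore(), shape=(100, 100), dtype="f8", chunks=(25, 25))
copy = zarr.from_array(src, store=MemoryStore(), chunks="keep")
assert copy.chunks == (25, 25)

# Array-like source: there is no chunk grid to keep, so "keep" behaves like "auto".
copy2 = zarr.from_array(np.zeros((100, 100)), store=MemoryStore(), chunks="keep")
```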

tests/test_array.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -1338,3 +1338,13 @@ async def test_from_array(
     assert result.dtype == src_dtype
     assert result.attrs == new_attributes
     assert result.chunks == new_chunks
+
+
+@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True)
+@pytest.mark.parametrize("chunks", [(10, 2, 3), "keep", "auto"])
+async def test_from_numpy_array(
+    store: Store, chunks: Literal["auto", "keep"] | tuple[int, int]
+) -> None:
+    src = np.arange(1000).reshape(10, 10, 10)
+    result = zarr.from_array(src, store=store, chunks=chunks)
+    np.testing.assert_array_equal(result[:], src)
```
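
For the array-like path exercised by this test, the copy loop still fans out one task per chunk region of the destination, bounded by zarr's `async.concurrency` config, but indexes the source directly (`_data[chunk_coords]`) rather than awaiting a chunk read. An illustrative standalone analogue of that bounded fan-out, using plain asyncio; the names here are made up, not zarr API:

```python
import asyncio
import numpy as np

async def copy_regions(src: np.ndarray, dst: np.ndarray, regions, limit: int = 10) -> None:
    # A semaphore stands in for concurrent_map's async.concurrency limit.
    sem = asyncio.Semaphore(limit)

    async def _copy(region) -> None:
        async with sem:
            dst[region] = src[region]  # zarr would `await new_array.setitem(...)` here

    await asyncio.gather(*(_copy(r) for r in regions))

src = np.arange(100).reshape(10, 10)
dst = np.empty_like(src)
regions = [(slice(i, i + 5), slice(j, j + 5)) for i in (0, 5) for j in (0, 5)]
asyncio.run(copy_regions(src, dst, regions))
assert (dst == src).all()
```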
