Skip to content

Commit d7bb121

Browse files
committed
ensure tests pass
1 parent 68465db commit d7bb121

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

src/zarr/api/asynchronous.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ async def create_array(
936936
shape: ChunkCoords,
937937
dtype: npt.DTypeLike,
938938
chunk_shape: ChunkCoords,
939-
shard_shape: ChunkCoords | None,
939+
shard_shape: ChunkCoords | None = None,
940940
filters: Iterable[dict[str, JSON] | Codec] = (),
941941
compressors: Iterable[dict[str, JSON] | Codec] = (),
942942
fill_value: Any | None = 0,
@@ -1016,14 +1016,18 @@ async def create_array(
10161016
_dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
10171017
config_parsed = parse_array_config(config)
10181018
if zarr_format == 2:
1019-
if shard_shape is not None or shard_shape != "auto":
1019+
if shard_shape is not None:
10201020
msg = (
10211021
'Zarr v2 arrays can only be created with `shard_shape` set to `None` or `"auto"`.'
10221022
f"Got `shard_shape={shard_shape}` instead."
10231023
)
10241024

10251025
raise ValueError(msg)
1026-
compressor, *rest = compressors
1026+
if len(tuple(compressors)) > 1:
1027+
compressor, *rest = compressors
1028+
else:
1029+
compressor = None
1030+
rest = ()
10271031
filters = (*filters, *rest)
10281032
if dimension_names is not None:
10291033
raise ValueError("Zarr v2 arrays do not support dimension names.")

tests/test_api.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ def test_read(store: Store) -> None:
7070
"""
7171
# create an array and a group
7272
_ = create_group(store=store, path="group", attributes={"node_type": "group"})
73-
_ = create_array(store=store, path="array", shape=(10, 10), attributes={"node_type": "array"})
73+
_ = create_array(
74+
store=store,
75+
path="array",
76+
shape=(10, 10),
77+
chunk_shape=(1, 1),
78+
dtype="uint8",
79+
attributes={"node_type": "array"},
80+
)
7481

7582
group_r = read(store, path="group")
7683
assert isinstance(group_r, Group)
@@ -89,7 +96,9 @@ def test_create_array(store: Store) -> None:
8996
shape = (10, 10)
9097
path = "foo"
9198
data_val = 1
92-
array_w = create_array(store, path=path, shape=shape, attributes=attrs)
99+
array_w = create_array(
100+
store, path=path, shape=shape, attributes=attrs, chunk_shape=shape, dtype="uint8"
101+
)
93102
array_w[:] = data_val
94103
assert array_w.shape == shape
95104
assert array_w.attrs == attrs
@@ -107,7 +116,13 @@ def test_read_array(store: Store) -> None:
107116
for zarr_format in (2, 3):
108117
attrs = {"zarr_format": zarr_format}
109118
node_w = create_array(
110-
store, path=path, shape=shape, attributes=attrs, zarr_format=zarr_format
119+
store,
120+
path=path,
121+
shape=shape,
122+
attributes=attrs,
123+
zarr_format=zarr_format,
124+
chunk_shape=shape,
125+
dtype="uint8",
111126
)
112127
node_w[:] = data_val
113128

@@ -1214,9 +1229,9 @@ async def test_create_array_v2(store: MemoryStore) -> None:
12141229
store=store,
12151230
dtype=dtype,
12161231
shape=(10,),
1217-
shard_shape=(4,),
1232+
shard_shape=None,
12181233
chunk_shape=(4,),
1219-
zarr_format=3,
1234+
zarr_format=2,
12201235
filters=(Delta(dtype=dtype),),
12211236
compressors=(Zstd(level=3),),
12221237
)

0 commit comments

Comments
 (0)