|
3 | 3 |
|
4 | 4 | import pytest |
5 | 5 |
|
6 | | -import zarr.api.asynchronous |
7 | 6 | from zarr.abc.store import AccessMode, Store |
8 | 7 | from zarr.core.buffer import Buffer, default_buffer_prototype |
| 8 | +from zarr.core.sync import _collect_aiterator |
9 | 9 | from zarr.store._utils import _normalize_interval_index |
10 | 10 | from zarr.testing.utils import assert_bytes_equal |
11 | 11 |
|
@@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: |
123 | 123 | observed = self.get(store, key) |
124 | 124 | assert_bytes_equal(observed, data_buf) |
125 | 125 |
|
| 126 | + async def test_set_many(self, store: S) -> None: |
| 127 | + """ |
| 128 | + Test that a dict of key : value pairs can be inserted into the store via the |
| 129 | + `_set_many` method. |
| 130 | + """ |
| 131 | + keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] |
| 132 | + data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys] |
| 133 | + store_dict = dict(zip(keys, data_buf, strict=True)) |
| 134 | + await store._set_many(store_dict.items()) |
| 135 | + for k, v in store_dict.items(): |
| 136 | + assert self.get(store, k).to_bytes() == v.to_bytes() |
| 137 | + |
126 | 138 | @pytest.mark.parametrize( |
127 | 139 | "key_ranges", |
128 | 140 | ( |
@@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None: |
185 | 197 | assert await store.empty() |
186 | 198 |
|
187 | 199 | async def test_list(self, store: S) -> None: |
188 | | - assert [k async for k in store.list()] == [] |
189 | | - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) |
190 | | - keys = [k async for k in store.list()] |
191 | | - assert keys == ["foo/zarr.json"], keys |
192 | | - |
193 | | - expected = ["foo/zarr.json"] |
194 | | - for i in range(10): |
195 | | - key = f"foo/c/{i}" |
196 | | - expected.append(key) |
197 | | - await store.set( |
198 | | - f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little")) |
199 | | - ) |
| 200 | + assert await _collect_aiterator(store.list()) == () |
| 201 | + prefix = "foo" |
| 202 | + data = self.buffer_cls.from_bytes(b"") |
| 203 | + store_dict = { |
| 204 | + prefix + "/zarr.json": data, |
| 205 | + **{prefix + f"/c/{idx}": data for idx in range(10)}, |
| 206 | + } |
| 207 | + await store._set_many(store_dict.items()) |
| 208 | + expected_sorted = sorted(store_dict.keys()) |
| 209 | + observed = await _collect_aiterator(store.list()) |
| 210 | + observed_sorted = sorted(observed) |
| 211 | + assert observed_sorted == expected_sorted |
200 | 212 |
|
201 | | - @pytest.mark.xfail |
202 | 213 | async def test_list_prefix(self, store: S) -> None: |
203 | | - # TODO: we currently don't use list_prefix anywhere |
204 | | - raise NotImplementedError |
| 214 | + """ |
| 215 | + Test that the `list_prefix` method works as intended. Given a prefix, it should return |
| 216 | + all the keys in storage that start with this prefix. Keys should be returned with the shared |
| 217 | + prefix removed. |
| 218 | + """ |
| 219 | + prefixes = ("", "a/", "a/b/", "a/b/c/") |
| 220 | + data = self.buffer_cls.from_bytes(b"") |
| 221 | + fname = "zarr.json" |
| 222 | + store_dict = {p + fname: data for p in prefixes} |
| 223 | + |
| 224 | + await store._set_many(store_dict.items()) |
| 225 | + |
| 226 | + for prefix in prefixes: |
| 227 | + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix)))) |
| 228 | + expected: tuple[str, ...] = () |
| 229 | + for key in store_dict.keys(): |
| 230 | + if key.startswith(prefix): |
| 231 | + expected += (key.removeprefix(prefix),) |
| 232 | + expected = tuple(sorted(expected)) |
| 233 | + assert observed == expected |
205 | 234 |
|
206 | 235 | async def test_list_dir(self, store: S) -> None: |
207 | | - out = [k async for k in store.list_dir("")] |
208 | | - assert out == [] |
209 | | - assert [k async for k in store.list_dir("foo")] == [] |
210 | | - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) |
211 | | - await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group |
212 | | - await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group |
213 | | - await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
214 | | - await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
215 | | - await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
216 | | - |
217 | | - keys_expected = ["foo", "group-0"] |
218 | | - keys_observed = [k async for k in store.list_dir("")] |
219 | | - assert set(keys_observed) == set(keys_expected) |
220 | | - |
221 | | - keys_expected = ["zarr.json"] |
222 | | - keys_observed = [k async for k in store.list_dir("foo")] |
223 | | - |
224 | | - assert len(keys_observed) == len(keys_expected), keys_observed |
225 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
226 | | - |
227 | | - keys_observed = [k async for k in store.list_dir("foo/")] |
228 | | - assert len(keys_expected) == len(keys_observed), keys_observed |
229 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
230 | | - |
231 | | - keys_observed = [k async for k in store.list_dir("group-0")] |
232 | | - keys_expected = ["zarr.json", "group-1"] |
233 | | - |
234 | | - assert len(keys_observed) == len(keys_expected), keys_observed |
235 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
236 | | - |
237 | | - keys_observed = [k async for k in store.list_dir("group-0/")] |
238 | | - assert len(keys_expected) == len(keys_observed), keys_observed |
239 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
| 236 | + root = "foo" |
| 237 | + store_dict = { |
| 238 | + root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), |
| 239 | + root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), |
| 240 | + } |
240 | 241 |
|
241 | | - keys_observed = [k async for k in store.list_dir("group-0/group-1")] |
242 | | - keys_expected = ["zarr.json", "a1", "a2", "a3"] |
| 242 | + assert await _collect_aiterator(store.list_dir("")) == () |
| 243 | + assert await _collect_aiterator(store.list_dir(root)) == () |
243 | 244 |
|
244 | | - assert len(keys_observed) == len(keys_expected), keys_observed |
245 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
| 245 | + await store._set_many(store_dict.items()) |
246 | 246 |
|
247 | | - keys_observed = [k async for k in store.list_dir("group-0/group-1")] |
248 | | - assert len(keys_expected) == len(keys_observed), keys_observed |
249 | | - assert set(keys_observed) == set(keys_expected), keys_observed |
| 247 | + keys_observed = await _collect_aiterator(store.list_dir(root)) |
| 248 | + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} |
250 | 249 |
|
251 | | - async def test_set_get(self, store_kwargs: dict[str, Any]) -> None: |
252 | | - kwargs = {**store_kwargs, **{"mode": "w"}} |
253 | | - store = self.store_cls(**kwargs) |
254 | | - await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,)) |
255 | | - keys = [x async for x in store.list()] |
256 | | - assert keys == ["a/zarr.json"] |
| 250 | + assert sorted(keys_observed) == sorted(keys_expected) |
257 | 251 |
|
258 | | - # no errors |
259 | | - await zarr.api.asynchronous.open_array(store=store, path="a", mode="r") |
260 | | - await zarr.api.asynchronous.open_array(store=store, path="a", mode="a") |
| 252 | + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) |
| 253 | + assert sorted(keys_expected) == sorted(keys_observed) |
0 commit comments