Skip to content

Commit 926c71a

Browse files
committed
fixup
1 parent 3fd8c46 commit 926c71a

File tree

3 files changed

+54
-16
lines changed

3 files changed

+54
-16
lines changed

src/zarr/store/memory.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@
1616
from zarr.core.common import AccessModeLiteral
1717

1818

19-
# T = TypeVar("T", bound=Buffer | gpu.Buffer)
20-
21-
22-
# class _MemoryStore
23-
24-
2519
# TODO: this store could easily be extended to wrap any MutableMapping store from v2
2620
# When that is done, the `MemoryStore` will just be a store that wraps a dict.
2721
class MemoryStore(Store):
@@ -163,9 +157,13 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
163157

164158
class GpuMemoryStore(MemoryStore):
165159
"""A GPU only memory store that stores every chunk in GPU memory irrespective
166-
of the original location. This guarantees that chunks will always be in GPU
167-
memory for downstream processing. For location agnostic use cases, it would
168-
be better to use `MemoryStore` instead.
160+
of the original location.
161+
162+
The dictionary of buffers to initialize this memory store with *must* be
163+
GPU Buffers.
164+
165+
Writing data to this store through ``.set`` will move the buffer to the GPU
166+
if necessary.
169167
170168
Parameters
171169
----------
@@ -174,7 +172,7 @@ class GpuMemoryStore(MemoryStore):
174172
values.
175173
"""
176174

177-
_store_dict: MutableMapping[str, Buffer]
175+
_store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment]
178176

179177
def __init__(
180178
self,
@@ -190,6 +188,27 @@ def __str__(self) -> str:
190188
def __repr__(self) -> str:
191189
return f"GpuMemoryStore({str(self)!r})"
192190

191+
@classmethod
192+
def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
193+
"""
194+
Create a GpuMemoryStore from a dictionary of buffers at any location.
195+
196+
The dictionary backing the newly created ``GpuMemoryStore`` will not be
197+
the same as ``store_dict``.
198+
199+
Parameters
200+
----------
201+
store_dict: mapping
202+
A mapping of strings keys to arbitrary Buffers. The buffer data
203+
will be moved into a :class:`gpu.Buffer`.
204+
205+
Returns
206+
-------
207+
GpuMemoryStore
208+
"""
209+
gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()}
210+
return cls(gpu_store_dict)
211+
193212
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
194213
self._check_writable()
195214
assert isinstance(key, str)

src/zarr/testing/store.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,23 @@ async def test_with_mode(self, store: S) -> None:
267267
assert isinstance(clone, type(store))
268268

269269
# earlier writes are visible
270-
assert self.get(clone, "key").to_bytes() == data
270+
result = await clone.get("key", default_buffer_prototype())
271+
assert result is not None
272+
assert result.to_bytes() == data
271273

272-
# writes to original after with_mode is visible
274+
# # writes to original after with_mode is visible
273275
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
274-
assert self.get(clone, "key-2").to_bytes() == data
276+
result = await clone.get("key-2", default_buffer_prototype())
277+
assert result is not None
278+
assert result.to_bytes() == data
275279

276-
if mode == "w":
280+
if mode == "a":
277281
# writes to clone is visible in the original
278-
self.set(store, "key-3", self.buffer_cls.from_bytes(data))
279-
assert self.get(clone, "key-3").to_bytes() == data
282+
await clone.set("key-3", self.buffer_cls.from_bytes(data))
283+
result = await clone.get("key-3", default_buffer_prototype())
284+
assert result is not None
285+
assert result.to_bytes() == data
286+
287+
else:
288+
with pytest.raises(ValueError):
289+
await clone.set("key-3", self.buffer_cls.from_bytes(data))

tests/v3/test_store/test_memory.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,12 @@ def test_dict_reference(self, store: GpuMemoryStore) -> None:
9090
store_dict = {}
9191
result = GpuMemoryStore(store_dict=store_dict)
9292
assert result._store_dict is store_dict
93+
94+
def test_from_dict(self):
95+
d = {
96+
"a": gpu.Buffer.from_bytes(b"aaaa"),
97+
"b": cpu.Buffer.from_bytes(b"bbbb"),
98+
}
99+
result = GpuMemoryStore.from_dict(d)
100+
for v in result._store_dict.values():
101+
assert type(v) is gpu.Buffer

0 commit comments

Comments
 (0)