Skip to content

Commit 2ef51f7

Browse files
committed
fixup
1 parent b390d7d commit 2ef51f7

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/zarr/store/memory.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
from zarr.core.common import AccessModeLiteral
1818

1919

20+
# T = TypeVar("T", bound=Buffer | gpu.Buffer)
21+
22+
23+
# class _MemoryStore
24+
25+
2026
# TODO: this store could easily be extended to wrap any MutableMapping store from v2
2127
# When that is done, the `MemoryStore` will just be a store that wraps a dict.
2228
class MemoryStore(Store):
@@ -161,19 +167,23 @@ class GpuMemoryStore(MemoryStore):
161167
of the original location. This guarantees that chunks will always be in GPU
162168
memory for downstream processing. For location agnostic use cases, it would
163169
be better to use `MemoryStore` instead.
170+
171+
Parameters
172+
----------
173+
store_dict: MutableMapping, optional
174+
A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer`
175+
values.
164176
"""
165177

166178
_store_dict: MutableMapping[str, Buffer]
167179

168180
def __init__(
169181
self,
170-
store_dict: MutableMapping[str, Buffer] | None = None,
182+
store_dict: MutableMapping[str, gpu.Buffer] | None = None,
171183
*,
172184
mode: AccessModeLiteral = "r",
173185
) -> None:
174-
super().__init__(mode=mode)
175-
if store_dict:
176-
self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)}
186+
super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type]
177187

178188
def __str__(self) -> str:
179189
return f"gpumemory://{id(self._store_dict)}"

tests/v3/test_store/test_memory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,8 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:
8080

8181
def test_list_prefix(self, store: GpuMemoryStore) -> None:
8282
assert True
83+
84+
def test_dict_reference(self, store: GpuMemoryStore) -> None:
85+
store_dict = {}
86+
result = GpuMemoryStore(store_dict=store_dict)
87+
assert result._store_dict is store_dict

0 commit comments

Comments
 (0)