1616 from zarr .core .common import AccessModeLiteral
1717
1818
19- # T = TypeVar("T", bound=Buffer | gpu.Buffer)
20-
21-
22- # class _MemoryStore
23-
24-
2519# TODO: this store could easily be extended to wrap any MutableMapping store from v2
2620# When that is done, the `MemoryStore` will just be a store that wraps a dict.
2721class MemoryStore (Store ):
@@ -163,9 +157,13 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
163157
164158class GpuMemoryStore (MemoryStore ):
165159 """A GPU only memory store that stores every chunk in GPU memory irrespective
166- of the original location. This guarantees that chunks will always be in GPU
167- memory for downstream processing. For location agnostic use cases, it would
168- be better to use `MemoryStore` instead.
160+ of the original location.
161+
162+ The dictionary of buffers to initialize this memory store with *must* be
163+ GPU Buffers.
164+
165+ Writing data to this store through ``.set`` will move the buffer to the GPU
166+ if necessary.
169167
170168 Parameters
171169 ----------
@@ -174,7 +172,7 @@ class GpuMemoryStore(MemoryStore):
174172 values.
175173 """
176174
177- _store_dict : MutableMapping [str , Buffer ]
175+ _store_dict : MutableMapping [str , gpu . Buffer ] # type: ignore[assignment ]
178176
179177 def __init__ (
180178 self ,
@@ -190,6 +188,27 @@ def __str__(self) -> str:
190188 def __repr__ (self ) -> str :
191189 return f"GpuMemoryStore({ str (self )!r} )"
192190
191+ @classmethod
192+ def from_dict (cls , store_dict : MutableMapping [str , Buffer ]) -> Self :
193+ """
194+ Create a GpuMemoryStore from a dictionary of buffers at any location.
195+
196+ The dictionary backing the newly created ``GpuMemoryStore`` will not be
197+ the same as ``store_dict``.
198+
199+ Parameters
200+ ----------
201+ store_dict: mapping
202+ A mapping of strings keys to arbitrary Buffers. The buffer data
203+ will be moved into a :class:`gpu.Buffer`.
204+
205+ Returns
206+ -------
207+ GpuMemoryStore
208+ """
209+ gpu_store_dict = {k : gpu .Buffer .from_buffer (v ) for k , v in store_dict .items ()}
210+ return cls (gpu_store_dict )
211+
193212 async def set (self , key : str , value : Buffer , byte_range : tuple [int , int ] | None = None ) -> None :
194213 self ._check_writable ()
195214 assert isinstance (key , str )
0 commit comments