|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from typing import TYPE_CHECKING |
| 3 | +from typing import TYPE_CHECKING, Self |
4 | 4 |
|
5 | 5 | from zarr.abc.store import ByteRangeRequest, Store
|
6 | 6 | from zarr.core.buffer import Buffer, gpu
|
@@ -41,6 +41,9 @@ async def empty(self) -> bool:
|
41 | 41 | async def clear(self) -> None:
|
42 | 42 | self._store_dict.clear()
|
43 | 43 |
|
| 44 | + def with_mode(self, mode: AccessModeLiteral) -> Self: |
| 45 | + return type(self)(store_dict=self._store_dict, mode=mode) |
| 46 | + |
44 | 47 | def __str__(self) -> str:
|
45 | 48 | return f"memory://{id(self._store_dict)}"
|
46 | 49 |
|
@@ -156,29 +159,58 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
|
156 | 159 |
|
157 | 160 | class GpuMemoryStore(MemoryStore):
|
158 | 161 | """A GPU only memory store that stores every chunk in GPU memory irrespective
|
159 |
| - of the original location. This guarantees that chunks will always be in GPU |
160 |
| - memory for downstream processing. For location agnostic use cases, it would |
161 |
| - be better to use `MemoryStore` instead. |
| 162 | + of the original location. |
| 163 | +
|
| 164 | + The dictionary of buffers to initialize this memory store with *must* be |
| 165 | + GPU Buffers. |
| 166 | +
|
| 167 | + Writing data to this store through ``.set`` will move the buffer to the GPU |
| 168 | + if necessary. |
| 169 | +
|
| 170 | + Parameters |
| 171 | + ---------- |
| 172 | + store_dict: MutableMapping, optional |
| 173 | + A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer` |
| 174 | + values. |
162 | 175 | """
|
163 | 176 |
|
164 |
| - _store_dict: MutableMapping[str, Buffer] |
| 177 | + _store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment] |
165 | 178 |
|
166 | 179 | def __init__(
|
167 | 180 | self,
|
168 |
| - store_dict: MutableMapping[str, Buffer] | None = None, |
| 181 | + store_dict: MutableMapping[str, gpu.Buffer] | None = None, |
169 | 182 | *,
|
170 | 183 | mode: AccessModeLiteral = "r",
|
171 | 184 | ) -> None:
|
172 |
| - super().__init__(mode=mode) |
173 |
| - if store_dict: |
174 |
| - self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)} |
| 185 | + super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type] |
175 | 186 |
|
176 | 187 | def __str__(self) -> str:
|
177 | 188 | return f"gpumemory://{id(self._store_dict)}"
|
178 | 189 |
|
179 | 190 | def __repr__(self) -> str:
|
180 | 191 | return f"GpuMemoryStore({str(self)!r})"
|
181 | 192 |
|
| 193 | + @classmethod |
| 194 | + def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self: |
| 195 | + """ |
| 196 | + Create a GpuMemoryStore from a dictionary of buffers at any location. |
| 197 | +
|
| 198 | + The dictionary backing the newly created ``GpuMemoryStore`` will not be |
| 199 | + the same as ``store_dict``. |
| 200 | +
|
| 201 | + Parameters |
| 202 | + ---------- |
| 203 | + store_dict: mapping |
| 204 | + A mapping of strings keys to arbitrary Buffers. The buffer data |
| 205 | + will be moved into a :class:`gpu.Buffer`. |
| 206 | +
|
| 207 | + Returns |
| 208 | + ------- |
| 209 | + GpuMemoryStore |
| 210 | + """ |
| 211 | + gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()} |
| 212 | + return cls(gpu_store_dict) |
| 213 | + |
182 | 214 | async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
|
183 | 215 | self._check_writable()
|
184 | 216 | assert isinstance(key, str)
|
|
0 commit comments