11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Self
3+ from typing import TYPE_CHECKING
44
55from zarr .abc .store import ByteRangeRequest , Store
66from zarr .core .buffer import Buffer , gpu
99
1010if TYPE_CHECKING :
1111 from collections .abc import AsyncGenerator , Iterable , MutableMapping
12+ from typing import Self
1213
1314 from zarr .core .buffer import BufferPrototype
1415 from zarr .core .common import AccessModeLiteral
@@ -26,9 +27,9 @@ class MemoryStore(Store):
2627
2728 def __init__ (
2829 self ,
29- path : str = "" ,
3030 store_dict : MutableMapping [str , Buffer ] | None = None ,
3131 * ,
32+ path : str = "" ,
3233 mode : AccessModeLiteral = "r" ,
3334 ) -> None :
3435 super ().__init__ (mode = mode , path = path )
@@ -68,9 +69,9 @@ async def get(
6869 ) -> Buffer | None :
6970 if not self ._is_open :
7071 await self ._open ()
71- assert isinstance ( key , str )
72+
7273 try :
73- value = self ._store_dict [key ]
74+ value = self ._store_dict [self . resolve_key ( key ) ]
7475 start , length = _normalize_interval_index (value , byte_range )
7576 return prototype .buffer .from_buffer (value [start : start + length ])
7677 except KeyError :
@@ -88,45 +89,46 @@ async def _get(key: str, byte_range: ByteRangeRequest) -> Buffer | None:
8889 return await concurrent_map (key_ranges , _get , limit = None )
8990
9091 async def exists (self , key : str ) -> bool :
91- return key in self ._store_dict
92+ return self . resolve_key ( key ) in self ._store_dict
9293
9394 async def set (self , key : str , value : Buffer , byte_range : tuple [int , int ] | None = None ) -> None :
9495 self ._check_writable ()
9596 await self ._ensure_open ()
9697 assert isinstance (key , str )
9798 if not isinstance (value , Buffer ):
9899 raise TypeError (f"Expected Buffer. Got { type (value )} ." )
99-
100+ key_abs = self . resolve_key ( key )
100101 if byte_range is not None :
101- buf = self ._store_dict [key ]
102+ buf = self ._store_dict [key_abs ]
102103 buf [byte_range [0 ] : byte_range [1 ]] = value
103- self ._store_dict [key ] = buf
104+ self ._store_dict [key_abs ] = buf
104105 else :
105- self ._store_dict [key ] = value
106+ self ._store_dict [key_abs ] = value
106107
107108 async def set_if_not_exists (self , key : str , default : Buffer ) -> None :
108109 self ._check_writable ()
109110 await self ._ensure_open ()
110- self ._store_dict .setdefault (key , default )
111+ self ._store_dict .setdefault (self . resolve_key ( key ) , default )
111112
112113 async def delete (self , key : str ) -> None :
113114 self ._check_writable ()
114115 try :
115- del self ._store_dict [key ]
116+ del self ._store_dict [self . resolve_key ( key ) ]
116117 except KeyError :
117118 pass # Q(JH): why not raise?
118119
119120 async def set_partial_values (self , key_start_values : Iterable [tuple [str , int , bytes ]]) -> None :
120121 raise NotImplementedError
121122
122123 async def list (self ) -> AsyncGenerator [str , None ]:
123- for key in self ._store_dict :
124- yield key
124+ async for result in self .list_prefix ( "" ) :
125+ yield result
125126
126127 async def list_prefix (self , prefix : str ) -> AsyncGenerator [str , None ]:
128+ prefix_abs = self .resolve_key (prefix )
127129 for key in self ._store_dict :
128- if key .startswith (prefix ):
129- yield key .removeprefix (prefix )
130+ if key .startswith (prefix_abs ):
131+ yield key .removeprefix (prefix_abs ). lstrip ( "/" )
130132
131133 async def list_dir (self , prefix : str ) -> AsyncGenerator [str , None ]:
132134 """
@@ -141,8 +143,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
141143 -------
142144 AsyncGenerator[str, None]
143145 """
144- if prefix .endswith ("/" ):
145- prefix = prefix [:- 1 ]
146+ prefix = self .resolve_key (prefix )
146147
147148 if prefix == "" :
148149 keys_unique = {k .split ("/" )[0 ] for k in self ._store_dict }
0 commit comments