77from elastic_transport import SerializationError as ElasticsearchSerializationError
88from key_value .shared .errors import DeserializationError , SerializationError
99from key_value .shared .utils .managed_entry import ManagedEntry
10+ from key_value .shared .utils .sanitization import AlwaysHashStrategy , HashFragmentMode , HybridSanitizationStrategy
1011from key_value .shared .utils .sanitize import (
1112 ALPHANUMERIC_CHARACTERS ,
1213 LOWERCASE_ALPHABET ,
1314 NUMBERS ,
14- sanitize_string ,
15+ UPPERCASE_ALPHABET ,
1516)
1617from key_value .shared .utils .serialization import SerializationAdapter
1718from key_value .shared .utils .time_to_live import now_as_epoch
@@ -145,7 +146,7 @@ class ElasticsearchStore(
145146
146147 _native_storage : bool
147148
148- _adapter : SerializationAdapter
149+ _serializer : SerializationAdapter
149150
150151 @overload
151152 def __init__ (
@@ -207,12 +208,31 @@ def __init__(
207208 LessCapableJsonSerializer .install_default_serializer (client = self ._client )
208209 LessCapableNdjsonSerializer .install_serializer (client = self ._client )
209210
210- self ._index_prefix = index_prefix
211+ self ._index_prefix = index_prefix . lower ()
211212 self ._native_storage = native_storage
212213 self ._is_serverless = False
213- self ._adapter = ElasticsearchSerializationAdapter (native_storage = native_storage )
214214
215- super ().__init__ (default_collection = default_collection )
215+ # We have 240 characters to work with
216+ # We need to account for the index prefix and the hyphen.
217+ max_index_length = MAX_INDEX_LENGTH - (len (self ._index_prefix ) + 1 )
218+
219+ self ._serializer = ElasticsearchSerializationAdapter (native_storage = native_storage )
220+
221+ # We allow uppercase through the sanitizer so we can lowercase them instead of them
222+ # all turning into underscores.
223+ collection_sanitization = HybridSanitizationStrategy (
224+ replacement_character = "_" ,
225+ max_length = max_index_length ,
226+ allowed_characters = UPPERCASE_ALPHABET + ALLOWED_INDEX_CHARACTERS ,
227+ hash_fragment_mode = HashFragmentMode .ALWAYS ,
228+ )
229+ key_sanitization = AlwaysHashStrategy ()
230+
231+ super ().__init__ (
232+ default_collection = default_collection ,
233+ collection_sanitization_strategy = collection_sanitization ,
234+ key_sanitization_strategy = key_sanitization ,
235+ )
216236
217237 @override
218238 async def _setup (self ) -> None :
@@ -222,32 +242,22 @@ async def _setup(self) -> None:
222242
223243 @override
async def _setup_collection(self, *, collection: str) -> None:
    """Ensure the Elasticsearch index for ``collection`` exists.

    Resolves the index name for the collection, asks Elasticsearch whether
    that index is already present, and creates it with ``DEFAULT_MAPPING``
    and empty settings when it is not.

    Args:
        collection: Name of the collection whose index should be set up.
    """
    target_index = self._get_index_name(collection=collection)

    # Both requests are issued with 404 responses ignored.
    index_exists = await self._client.options(ignore_status=404).indices.exists(index=target_index)

    if not index_exists:
        _ = await self._client.options(ignore_status=404).indices.create(
            index=target_index,
            mappings=DEFAULT_MAPPING,
            settings={},
        )
231251
232- def _sanitize_index_name (self , collection : str ) -> str :
233- return sanitize_string (
234- value = self ._index_prefix + "-" + collection ,
235- replacement_character = "_" ,
236- max_length = MAX_INDEX_LENGTH ,
237- allowed_characters = ALLOWED_INDEX_CHARACTERS ,
238- )
252+ def _get_index_name (self , collection : str ) -> str :
253+ return self ._index_prefix + "-" + self ._sanitize_collection (collection = collection ).lower ()
239254
240- def _sanitize_document_id (self , key : str ) -> str :
241- return sanitize_string (
242- value = key ,
243- replacement_character = "_" ,
244- max_length = MAX_KEY_LENGTH ,
245- allowed_characters = ALLOWED_KEY_CHARACTERS ,
246- )
255+ def _get_document_id (self , key : str ) -> str :
256+ return self ._sanitize_key (key = key )
247257
248258 def _get_destination (self , * , collection : str , key : str ) -> tuple [str , str ]:
249- index_name : str = self ._sanitize_index_name (collection = collection )
250- document_id : str = self ._sanitize_document_id (key = key )
259+ index_name : str = self ._get_index_name (collection = collection )
260+ document_id : str = self ._get_document_id (key = key )
251261
252262 return index_name , document_id
253263
@@ -263,7 +273,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
263273 return None
264274
265275 try :
266- return self ._adapter .load_dict (data = source )
276+ return self ._serializer .load_dict (data = source )
267277 except DeserializationError :
268278 return None
269279
@@ -273,8 +283,8 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
273283 return []
274284
275285 # Use mget for efficient batch retrieval
276- index_name = self ._sanitize_index_name (collection = collection )
277- document_ids = [self ._sanitize_document_id (key = key ) for key in keys ]
286+ index_name = self ._get_index_name (collection = collection )
287+ document_ids = [self ._get_document_id (key = key ) for key in keys ]
278288 docs = [{"_id" : document_id } for document_id in document_ids ]
279289
280290 elasticsearch_response = await self ._client .options (ignore_status = 404 ).mget (index = index_name , docs = docs )
@@ -296,7 +306,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
296306 continue
297307
298308 try :
299- entries_by_id [doc_id ] = self ._adapter .load_dict (data = source )
309+ entries_by_id [doc_id ] = self ._serializer .load_dict (data = source )
300310 except DeserializationError as e :
301311 logger .error (
302312 "Failed to deserialize Elasticsearch document in batch operation" ,
@@ -324,10 +334,10 @@ async def _put_managed_entry(
324334 collection : str ,
325335 managed_entry : ManagedEntry ,
326336 ) -> None :
327- index_name : str = self ._sanitize_index_name (collection = collection )
328- document_id : str = self ._sanitize_document_id (key = key )
337+ index_name : str = self ._get_index_name (collection = collection )
338+ document_id : str = self ._get_document_id (key = key )
329339
330- document : dict [str , Any ] = self ._adapter .dump_dict (entry = managed_entry )
340+ document : dict [str , Any ] = self ._serializer .dump_dict (entry = managed_entry )
331341
332342 try :
333343 _ = await self ._client .index (
@@ -358,14 +368,14 @@ async def _put_managed_entries(
358368
359369 operations : list [dict [str , Any ]] = []
360370
361- index_name : str = self ._sanitize_index_name (collection = collection )
371+ index_name : str = self ._get_index_name (collection = collection )
362372
363373 for key , managed_entry in zip (keys , managed_entries , strict = True ):
364- document_id : str = self ._sanitize_document_id (key = key )
374+ document_id : str = self ._get_document_id (key = key )
365375
366376 index_action : dict [str , Any ] = new_bulk_action (action = "index" , index = index_name , document_id = document_id )
367377
368- document : dict [str , Any ] = self ._adapter .dump_dict (entry = managed_entry )
378+ document : dict [str , Any ] = self ._serializer .dump_dict (entry = managed_entry )
369379
370380 operations .extend ([index_action , document ])
371381
@@ -379,8 +389,8 @@ async def _put_managed_entries(
379389
380390 @override
381391 async def _delete_managed_entry (self , * , key : str , collection : str ) -> bool :
382- index_name : str = self ._sanitize_index_name (collection = collection )
383- document_id : str = self ._sanitize_document_id (key = key )
392+ index_name : str = self ._get_index_name (collection = collection )
393+ document_id : str = self ._get_document_id (key = key )
384394
385395 elasticsearch_response : ObjectApiResponse [Any ] = await self ._client .options (ignore_status = 404 ).delete (
386396 index = index_name , id = document_id
@@ -428,7 +438,7 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
428438 limit = min (limit or DEFAULT_PAGE_SIZE , PAGE_LIMIT )
429439
430440 result : ObjectApiResponse [Any ] = await self ._client .options (ignore_status = 404 ).search (
431- index = self ._sanitize_index_name (collection = collection ),
441+ index = self ._get_index_name (collection = collection ),
432442 fields = [{"key" : None }],
433443 body = {
434444 "query" : {
@@ -483,7 +493,7 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
483493 @override
484494 async def _delete_collection (self , * , collection : str ) -> bool :
485495 result : ObjectApiResponse [Any ] = await self ._client .options (ignore_status = 404 ).delete_by_query (
486- index = self ._sanitize_index_name (collection = collection ),
496+ index = self ._get_index_name (collection = collection ),
487497 body = {
488498 "query" : {
489499 "term" : {
0 commit comments