Skip to content

Commit 5272fce

Browse files
feat: Add SanitizationStrategy (#208)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 782d067 commit 5272fce

File tree

30 files changed

+887
-289
lines changed

30 files changed

+887
-289
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from key_value.shared.errors import StoreSetupError
1515
from key_value.shared.type_checking.bear_spray import bear_enforce
1616
from key_value.shared.utils.managed_entry import ManagedEntry
17+
from key_value.shared.utils.sanitization import PassthroughStrategy, SanitizationStrategy
1718
from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
1819
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
1920
from typing_extensions import Self, override
@@ -69,6 +70,8 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
6970
_setup_collection_complete: defaultdict[str, bool]
7071

7172
_serialization_adapter: SerializationAdapter
73+
_key_sanitization_strategy: SanitizationStrategy
74+
_collection_sanitization_strategy: SanitizationStrategy
7275

7376
_seed: FROZEN_SEED_DATA_TYPE
7477

@@ -78,13 +81,17 @@ def __init__(
7881
self,
7982
*,
8083
serialization_adapter: SerializationAdapter | None = None,
84+
key_sanitization_strategy: SanitizationStrategy | None = None,
85+
collection_sanitization_strategy: SanitizationStrategy | None = None,
8186
default_collection: str | None = None,
8287
seed: SEED_DATA_TYPE | None = None,
8388
) -> None:
8489
"""Initialize the managed key-value store.
8590
8691
Args:
8792
serialization_adapter: The serialization adapter to use for the store.
93+
key_sanitization_strategy: The sanitization strategy to use for keys.
94+
collection_sanitization_strategy: The sanitization strategy to use for collections.
8895
default_collection: The default collection to use if no collection is provided.
8996
Defaults to "default_collection".
9097
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
@@ -103,6 +110,9 @@ def __init__(
103110

104111
self._serialization_adapter = serialization_adapter or BasicSerializationAdapter()
105112

113+
self._key_sanitization_strategy = key_sanitization_strategy or PassthroughStrategy()
114+
self._collection_sanitization_strategy = collection_sanitization_strategy or PassthroughStrategy()
115+
106116
if not hasattr(self, "_stable_api"):
107117
self._stable_api = False
108118

@@ -117,6 +127,17 @@ async def _setup(self) -> None:
117127
async def _setup_collection(self, *, collection: str) -> None:
118128
"""Initialize the collection (called once before first use of the collection)."""
119129

130+
def _sanitize_collection_and_key(self, collection: str, key: str) -> tuple[str, str]:
131+
return self._sanitize_collection(collection=collection), self._sanitize_key(key=key)
132+
133+
def _sanitize_collection(self, collection: str) -> str:
134+
self._collection_sanitization_strategy.validate(value=collection)
135+
return self._collection_sanitization_strategy.sanitize(value=collection)
136+
137+
def _sanitize_key(self, key: str) -> str:
138+
self._key_sanitization_strategy.validate(value=key)
139+
return self._key_sanitization_strategy.sanitize(value=key)
140+
120141
async def _seed_store(self) -> None:
121142
"""Seed the store with the data from the seed."""
122143
for collection, items in self._seed.items():

key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from elastic_transport import SerializationError as ElasticsearchSerializationError
88
from key_value.shared.errors import DeserializationError, SerializationError
99
from key_value.shared.utils.managed_entry import ManagedEntry
10+
from key_value.shared.utils.sanitization import AlwaysHashStrategy, HashFragmentMode, HybridSanitizationStrategy
1011
from key_value.shared.utils.sanitize import (
1112
ALPHANUMERIC_CHARACTERS,
1213
LOWERCASE_ALPHABET,
1314
NUMBERS,
14-
sanitize_string,
15+
UPPERCASE_ALPHABET,
1516
)
1617
from key_value.shared.utils.serialization import SerializationAdapter
1718
from key_value.shared.utils.time_to_live import now_as_epoch
@@ -145,7 +146,7 @@ class ElasticsearchStore(
145146

146147
_native_storage: bool
147148

148-
_adapter: SerializationAdapter
149+
_serializer: SerializationAdapter
149150

150151
@overload
151152
def __init__(
@@ -207,12 +208,31 @@ def __init__(
207208
LessCapableJsonSerializer.install_default_serializer(client=self._client)
208209
LessCapableNdjsonSerializer.install_serializer(client=self._client)
209210

210-
self._index_prefix = index_prefix
211+
self._index_prefix = index_prefix.lower()
211212
self._native_storage = native_storage
212213
self._is_serverless = False
213-
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
214214

215-
super().__init__(default_collection=default_collection)
215+
# We have 240 characters to work with
216+
# We need to account for the index prefix and the hyphen.
217+
max_index_length = MAX_INDEX_LENGTH - (len(self._index_prefix) + 1)
218+
219+
self._serializer = ElasticsearchSerializationAdapter(native_storage=native_storage)
220+
221+
# We allow uppercase through the sanitizer so we can lowercase them instead of them
222+
# all turning into underscores.
223+
collection_sanitization = HybridSanitizationStrategy(
224+
replacement_character="_",
225+
max_length=max_index_length,
226+
allowed_characters=UPPERCASE_ALPHABET + ALLOWED_INDEX_CHARACTERS,
227+
hash_fragment_mode=HashFragmentMode.ALWAYS,
228+
)
229+
key_sanitization = AlwaysHashStrategy()
230+
231+
super().__init__(
232+
default_collection=default_collection,
233+
collection_sanitization_strategy=collection_sanitization,
234+
key_sanitization_strategy=key_sanitization,
235+
)
216236

217237
@override
218238
async def _setup(self) -> None:
@@ -222,32 +242,22 @@ async def _setup(self) -> None:
222242

223243
@override
224244
async def _setup_collection(self, *, collection: str) -> None:
225-
index_name = self._sanitize_index_name(collection=collection)
245+
index_name = self._get_index_name(collection=collection)
226246

227247
if await self._client.options(ignore_status=404).indices.exists(index=index_name):
228248
return
229249

230250
_ = await self._client.options(ignore_status=404).indices.create(index=index_name, mappings=DEFAULT_MAPPING, settings={})
231251

232-
def _sanitize_index_name(self, collection: str) -> str:
233-
return sanitize_string(
234-
value=self._index_prefix + "-" + collection,
235-
replacement_character="_",
236-
max_length=MAX_INDEX_LENGTH,
237-
allowed_characters=ALLOWED_INDEX_CHARACTERS,
238-
)
252+
def _get_index_name(self, collection: str) -> str:
253+
return self._index_prefix + "-" + self._sanitize_collection(collection=collection).lower()
239254

240-
def _sanitize_document_id(self, key: str) -> str:
241-
return sanitize_string(
242-
value=key,
243-
replacement_character="_",
244-
max_length=MAX_KEY_LENGTH,
245-
allowed_characters=ALLOWED_KEY_CHARACTERS,
246-
)
255+
def _get_document_id(self, key: str) -> str:
256+
return self._sanitize_key(key=key)
247257

248258
def _get_destination(self, *, collection: str, key: str) -> tuple[str, str]:
249-
index_name: str = self._sanitize_index_name(collection=collection)
250-
document_id: str = self._sanitize_document_id(key=key)
259+
index_name: str = self._get_index_name(collection=collection)
260+
document_id: str = self._get_document_id(key=key)
251261

252262
return index_name, document_id
253263

@@ -263,7 +273,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
263273
return None
264274

265275
try:
266-
return self._adapter.load_dict(data=source)
276+
return self._serializer.load_dict(data=source)
267277
except DeserializationError:
268278
return None
269279

@@ -273,8 +283,8 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
273283
return []
274284

275285
# Use mget for efficient batch retrieval
276-
index_name = self._sanitize_index_name(collection=collection)
277-
document_ids = [self._sanitize_document_id(key=key) for key in keys]
286+
index_name = self._get_index_name(collection=collection)
287+
document_ids = [self._get_document_id(key=key) for key in keys]
278288
docs = [{"_id": document_id} for document_id in document_ids]
279289

280290
elasticsearch_response = await self._client.options(ignore_status=404).mget(index=index_name, docs=docs)
@@ -296,7 +306,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
296306
continue
297307

298308
try:
299-
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
309+
entries_by_id[doc_id] = self._serializer.load_dict(data=source)
300310
except DeserializationError as e:
301311
logger.error(
302312
"Failed to deserialize Elasticsearch document in batch operation",
@@ -324,10 +334,10 @@ async def _put_managed_entry(
324334
collection: str,
325335
managed_entry: ManagedEntry,
326336
) -> None:
327-
index_name: str = self._sanitize_index_name(collection=collection)
328-
document_id: str = self._sanitize_document_id(key=key)
337+
index_name: str = self._get_index_name(collection=collection)
338+
document_id: str = self._get_document_id(key=key)
329339

330-
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
340+
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)
331341

332342
try:
333343
_ = await self._client.index(
@@ -358,14 +368,14 @@ async def _put_managed_entries(
358368

359369
operations: list[dict[str, Any]] = []
360370

361-
index_name: str = self._sanitize_index_name(collection=collection)
371+
index_name: str = self._get_index_name(collection=collection)
362372

363373
for key, managed_entry in zip(keys, managed_entries, strict=True):
364-
document_id: str = self._sanitize_document_id(key=key)
374+
document_id: str = self._get_document_id(key=key)
365375

366376
index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
367377

368-
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
378+
document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)
369379

370380
operations.extend([index_action, document])
371381

@@ -379,8 +389,8 @@ async def _put_managed_entries(
379389

380390
@override
381391
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
382-
index_name: str = self._sanitize_index_name(collection=collection)
383-
document_id: str = self._sanitize_document_id(key=key)
392+
index_name: str = self._get_index_name(collection=collection)
393+
document_id: str = self._get_document_id(key=key)
384394

385395
elasticsearch_response: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete(
386396
index=index_name, id=document_id
@@ -428,7 +438,7 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
428438
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
429439

430440
result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).search(
431-
index=self._sanitize_index_name(collection=collection),
441+
index=self._get_index_name(collection=collection),
432442
fields=[{"key": None}],
433443
body={
434444
"query": {
@@ -483,7 +493,7 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
483493
@override
484494
async def _delete_collection(self, *, collection: str) -> bool:
485495
result: ObjectApiResponse[Any] = await self._client.options(ignore_status=404).delete_by_query(
486-
index=self._sanitize_index_name(collection=collection),
496+
index=self._get_index_name(collection=collection),
487497
body={
488498
"query": {
489499
"term": {

key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from key_value.shared.utils.compound import compound_key
44
from key_value.shared.utils.managed_entry import ManagedEntry
5-
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string
5+
from key_value.shared.utils.sanitization import HybridSanitizationStrategy
6+
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS
67
from typing_extensions import override
78

89
from key_value.aio.stores.base import BaseStore
@@ -15,11 +16,9 @@
1516
raise ImportError(msg) from e
1617

1718
DEFAULT_KEYCHAIN_SERVICE = "py-key-value"
18-
MAX_KEY_LENGTH = 256
19-
ALLOWED_KEY_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
2019

21-
MAX_COLLECTION_LENGTH = 256
22-
ALLOWED_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
20+
MAX_KEY_COLLECTION_LENGTH = 256
21+
ALLOWED_KEY_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
2322

2423

2524
class KeyringStore(BaseStore):
@@ -48,25 +47,19 @@ def __init__(
4847
"""
4948
self._service_name = service_name
5049

51-
super().__init__(default_collection=default_collection)
52-
53-
def _sanitize_collection_name(self, collection: str) -> str:
54-
return sanitize_string(
55-
value=collection,
56-
max_length=MAX_COLLECTION_LENGTH,
57-
allowed_characters=ALLOWED_COLLECTION_CHARACTERS,
50+
sanitization_strategy = HybridSanitizationStrategy(
51+
replacement_character="_", max_length=MAX_KEY_COLLECTION_LENGTH, allowed_characters=ALLOWED_KEY_COLLECTION_CHARACTERS
5852
)
5953

60-
def _sanitize_key(self, key: str) -> str:
61-
return sanitize_string(
62-
value=key,
63-
max_length=MAX_KEY_LENGTH,
64-
allowed_characters=ALLOWED_KEY_CHARACTERS,
54+
super().__init__(
55+
default_collection=default_collection,
56+
collection_sanitization_strategy=sanitization_strategy,
57+
key_sanitization_strategy=sanitization_strategy,
6558
)
6659

6760
@override
6861
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
69-
sanitized_collection = self._sanitize_collection_name(collection=collection)
62+
sanitized_collection = self._sanitize_collection(collection=collection)
7063
sanitized_key = self._sanitize_key(key=key)
7164

7265
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
@@ -83,7 +76,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
8376

8477
@override
8578
async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None:
86-
sanitized_collection = self._sanitize_collection_name(collection=collection)
79+
sanitized_collection = self._sanitize_collection(collection=collection)
8780
sanitized_key = self._sanitize_key(key=key)
8881

8982
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
@@ -94,7 +87,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
9487

9588
@override
9689
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
97-
sanitized_collection = self._sanitize_collection_name(collection=collection)
90+
sanitized_collection = self._sanitize_collection(collection=collection)
9891
sanitized_key = self._sanitize_key(key=key)
9992

10093
combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)

key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from key_value.shared.utils.compound import compound_key
66
from key_value.shared.utils.managed_entry import ManagedEntry
7+
from key_value.shared.utils.sanitization import HashExcessLengthStrategy
78
from typing_extensions import override
89

910
from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore
@@ -46,7 +47,12 @@ def __init__(
4647
"""
4748
self._client = client or Client(host=host, port=port)
4849

49-
super().__init__(default_collection=default_collection)
50+
sanitization_strategy = HashExcessLengthStrategy(max_length=MAX_KEY_LENGTH)
51+
52+
super().__init__(
53+
default_collection=default_collection,
54+
key_sanitization_strategy=sanitization_strategy,
55+
)
5056

5157
def sanitize_key(self, key: str) -> str:
5258
if len(key) > MAX_KEY_LENGTH:

0 commit comments

Comments
 (0)