diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_encryption.py b/key-value/key-value-aio/tests/stores/wrappers/test_encryption.py index 6c6e5f5b..8ba60f17 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_encryption.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_encryption.py @@ -175,3 +175,70 @@ def test_key_generation(): assert key_str_one != key_str_two assert key_str_one != key_str_three assert key_str_two != key_str_three + + +def test_fernet_with_source_material_and_salt(memory_store: MemoryStore): + """Test that FernetEncryptionWrapper works with source_material and salt.""" + wrapper = FernetEncryptionWrapper( + key_value=memory_store, + source_material="my-secret-key", + salt="my-unique-salt", + ) + assert wrapper is not None + + +def test_fernet_cannot_provide_fernet_with_source_material(memory_store: MemoryStore): + """Test that providing both fernet and source_material raises ValueError.""" + fernet = Fernet(key=Fernet.generate_key()) + with pytest.raises(ValueError, match="Cannot provide fernet together with source_material or salt"): + FernetEncryptionWrapper( + key_value=memory_store, + fernet=fernet, + source_material="test", + ) + + +def test_fernet_cannot_provide_fernet_with_salt(memory_store: MemoryStore): + """Test that providing both fernet and salt raises ValueError.""" + fernet = Fernet(key=Fernet.generate_key()) + with pytest.raises(ValueError, match="Cannot provide fernet together with source_material or salt"): + FernetEncryptionWrapper( + key_value=memory_store, + fernet=fernet, + salt="test", + ) + + +def test_fernet_must_provide_source_material(memory_store: MemoryStore): + """Test that not providing fernet or source_material raises ValueError.""" + with pytest.raises(ValueError, match="Must provide either fernet or source_material"): + FernetEncryptionWrapper(key_value=memory_store) + + +def test_fernet_must_provide_salt_with_source_material(memory_store: MemoryStore): + """Test that providing source_material without salt raises ValueError.""" + with pytest.raises(ValueError, match="Must provide a salt"): + FernetEncryptionWrapper( + key_value=memory_store, + source_material="test-source", + ) + + +def test_fernet_empty_source_material(memory_store: MemoryStore): + """Test that empty source_material raises ValueError.""" + with pytest.raises(ValueError, match="Must provide either fernet or source_material"): + FernetEncryptionWrapper( + key_value=memory_store, + source_material=" ", + salt="test", + ) + + +def test_fernet_empty_salt(memory_store: MemoryStore): + """Test that empty salt raises ValueError.""" + with pytest.raises(ValueError, match="Must provide a salt"): + FernetEncryptionWrapper( + key_value=memory_store, + source_material="test-source", + salt=" ", + ) diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_fallback.py b/key-value/key-value-aio/tests/stores/wrappers/test_fallback.py index 80bab234..b6a497a4 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_fallback.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_fallback.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, SupportsFloat import pytest @@ -17,11 +17,43 @@ async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any msg = "Primary store unavailable" raise ConnectionError(msg) + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + msg = "Primary store unavailable" + raise ConnectionError(msg) + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + msg = "Primary store unavailable" + raise ConnectionError(msg) + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + msg = "Primary store unavailable" + raise ConnectionError(msg) + @override async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None): msg = "Primary store unavailable" raise ConnectionError(msg) + @override + async def put_many( + self, keys: Sequence[str], values: Sequence[Mapping[str, Any]], *, collection: str | None = None, ttl: SupportsFloat | None = None + ) -> None: + msg = "Primary store unavailable" + raise ConnectionError(msg) + + @override + async def delete(self, key: str, *, collection: str | None = None) -> bool: + msg = "Primary store unavailable" + raise ConnectionError(msg) + + @override + async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int: + msg = "Primary store unavailable" + raise ConnectionError(msg) + class TestFallbackWrapper(BaseStoreTests): @override @@ -77,3 +109,81 @@ async def test_write_to_fallback_enabled(self): # Verify it was written to fallback result = await fallback_store.get(collection="test", key="test") assert result == {"test": "value"} + + async def test_fallback_get_many(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) + + # Put data in fallback store + await fallback_store.put(collection="test", key="k1", value={"v": "1"}) + await fallback_store.put(collection="test", key="k2", value={"v": "2"}) + + # Should fall back for get_many + result = await wrapper.get_many(collection="test", keys=["k1", "k2"]) + assert result == [{"v": "1"}, {"v": "2"}] + + async def test_fallback_ttl(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) + + # Put data in fallback store with TTL + await fallback_store.put(collection="test", key="test", value={"v": "1"}, ttl=100) + + # Should fall back for ttl + value, ttl = await wrapper.ttl(collection="test", key="test") + assert value == {"v": "1"} + assert ttl is not None + + async def test_fallback_ttl_many(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) + + # Put data in fallback store + await fallback_store.put(collection="test", key="k1", value={"v": "1"}, ttl=100) + await fallback_store.put(collection="test", key="k2", value={"v": "2"}, ttl=200) + + # Should fall back for ttl_many + results = await wrapper.ttl_many(collection="test", keys=["k1", "k2"]) + assert results[0][0] == {"v": "1"} + assert results[1][0] == {"v": "2"} + + async def test_fallback_put_many_enabled(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=True) + + # Should fall back for put_many + await wrapper.put_many(collection="test", keys=["k1", "k2"], values=[{"v": "1"}, {"v": "2"}]) + + # Verify in fallback + assert await fallback_store.get(collection="test", key="k1") == {"v": "1"} + assert await fallback_store.get(collection="test", key="k2") == {"v": "2"} + + async def test_fallback_delete_enabled(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=True) + + # Put data in fallback + await fallback_store.put(collection="test", key="test", value={"v": "1"}) + + # Should fall back for delete + result = await wrapper.delete(collection="test", key="test") + assert result is True + assert await fallback_store.get(collection="test", key="test") is None + + async def test_fallback_delete_many_enabled(self): + primary_store = FailingStore() + fallback_store = MemoryStore() + wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=True) + + # Put data in fallback + await fallback_store.put(collection="test", key="k1", value={"v": "1"}) + await fallback_store.put(collection="test", key="k2", value={"v": "2"}) + + # Should fall back for delete_many + result = await wrapper.delete_many(collection="test", keys=["k1", "k2"]) + assert result == 2 diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_passthrough_cache.py b/key-value/key-value-aio/tests/stores/wrappers/test_passthrough_cache.py index 9bba9239..8f9a2660 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_passthrough_cache.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_passthrough_cache.py @@ -28,3 +28,80 @@ async def cache_store(self, memory_store: MemoryStore) -> MemoryStore: async def store(self, primary_store: DiskStore, cache_store: MemoryStore) -> PassthroughCacheWrapper: primary_store._cache.clear() # pyright: ignore[reportPrivateUsage] return PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + async def test_ttl_caches_from_primary(self): + """Test that ttl retrieves from primary and caches the result.""" + primary_store = MemoryStore() + cache_store = MemoryStore() + wrapper = PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + # Put data in primary with TTL + await primary_store.put(collection="test", key="test", value={"v": "1"}, ttl=100) + + # Call ttl - should get from primary and cache it + value, ttl = await wrapper.ttl(collection="test", key="test") + assert value == {"v": "1"} + assert ttl is not None + + # Verify it's now in cache + cached_value = await cache_store.get(collection="test", key="test") + assert cached_value == {"v": "1"} + + async def test_ttl_returns_cached_value(self): + """Test that ttl returns cached value when available.""" + primary_store = MemoryStore() + cache_store = MemoryStore() + wrapper = PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + # Put data only in cache + await cache_store.put(collection="test", key="test", value={"v": "cached"}, ttl=100) + + # Call ttl - should return cached value + value, ttl = await wrapper.ttl(collection="test", key="test") + assert value == {"v": "cached"} + assert ttl is not None + + async def test_ttl_returns_none_for_missing(self): + """Test that ttl returns (None, None) for missing entries.""" + primary_store = MemoryStore() + cache_store = MemoryStore() + wrapper = PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + # Call ttl for non-existent key + value, ttl = await wrapper.ttl(collection="test", key="missing") + assert value is None + assert ttl is None + + async def test_ttl_many_caches_from_primary(self): + """Test that ttl_many retrieves from primary and caches results.""" + primary_store = MemoryStore() + cache_store = MemoryStore() + wrapper = PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + # Put data in primary with TTL + await primary_store.put(collection="test", key="k1", value={"v": "1"}, ttl=100) + await primary_store.put(collection="test", key="k2", value={"v": "2"}, ttl=200) + + # Call ttl_many - should get from primary and cache + results = await wrapper.ttl_many(collection="test", keys=["k1", "k2"]) + assert results[0][0] == {"v": "1"} + assert results[1][0] == {"v": "2"} + + # Verify in cache + assert await cache_store.get(collection="test", key="k1") == {"v": "1"} + assert await cache_store.get(collection="test", key="k2") == {"v": "2"} + + async def test_ttl_many_returns_cached_values(self): + """Test that ttl_many returns cached values when available.""" + primary_store = MemoryStore() + cache_store = MemoryStore() + wrapper = PassthroughCacheWrapper(primary_key_value=primary_store, cache_key_value=cache_store) + + # Put data in cache + await cache_store.put(collection="test", key="k1", value={"v": "cached1"}, ttl=100) + await cache_store.put(collection="test", key="k2", value={"v": "cached2"}, ttl=200) + + # Call ttl_many - should return cached values + results = await wrapper.ttl_many(collection="test", keys=["k1", "k2"]) + assert results[0][0] == {"v": "cached1"} + assert results[1][0] == {"v": "cached2"} diff --git a/key-value/key-value-shared/tests/utils/test_serialization.py b/key-value/key-value-shared/tests/utils/test_serialization.py index 1699a386..b350ab22 100644 --- a/key-value/key-value-shared/tests/utils/test_serialization.py +++ b/key-value/key-value-shared/tests/utils/test_serialization.py @@ -3,8 +3,9 @@ import pytest from inline_snapshot import snapshot +from key_value.shared.errors import DeserializationError, SerializationError from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.serialization import BasicSerializationAdapter +from key_value.shared.utils.serialization import BasicSerializationAdapter, key_must_be, parse_datetime_str FIXED_DATETIME_ONE = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) FIXED_DATETIME_ONE_ISOFORMAT = FIXED_DATETIME_ONE.isoformat() @@ -80,3 +81,92 @@ def test_entry_two(self, adapter: BasicSerializationAdapter): assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO) assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO) + + def test_dump_dict_with_key_and_collection(self, adapter: BasicSerializationAdapter): + """Test dump_dict includes key and collection when provided.""" + result = adapter.dump_dict(entry=TEST_ENTRY_ONE, key="my-key", collection="my-collection") + assert result["key"] == "my-key" + assert result["collection"] == "my-collection" + + def test_dump_dict_with_datetime_format(self): + """Test dump_dict with datetime format instead of isoformat.""" + adapter = BasicSerializationAdapter(date_format="datetime") + result = adapter.dump_dict(entry=TEST_ENTRY_ONE) + assert result["created_at"] == FIXED_DATETIME_ONE + assert result["expires_at"] == FIXED_DATETIME_ONE_PLUS_10_SECONDS + + def test_load_dict_with_datetime_format(self): + """Test load_dict with datetime format instead of isoformat.""" + adapter = BasicSerializationAdapter(date_format="datetime") + data = { + "created_at": FIXED_DATETIME_ONE, + "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS, + "value": TEST_DATA_ONE, + } + result = adapter.load_dict(data=data) + assert result.created_at == FIXED_DATETIME_ONE + assert result.expires_at == FIXED_DATETIME_ONE_PLUS_10_SECONDS + + def test_dump_json_with_datetime_format_raises_error(self): + """Test dump_json raises error when date_format is datetime.""" + adapter = BasicSerializationAdapter(date_format="datetime") + with pytest.raises(SerializationError, match="dump_json is incompatible"): + adapter.dump_json(entry=TEST_ENTRY_ONE) + + def test_load_dict_with_string_value(self, adapter: BasicSerializationAdapter): + """Test load_dict with value as JSON string.""" + data = { + "created_at": FIXED_DATETIME_ONE_ISOFORMAT, + "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT, + "value": '{"key": "value"}', + } + result = adapter.load_dict(data=data) + assert result.value == {"key": "value"} + + def test_load_dict_missing_value_raises_error(self, adapter: BasicSerializationAdapter): + """Test load_dict raises error when value is missing.""" + data = { + "created_at": FIXED_DATETIME_ONE_ISOFORMAT, + "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT, + } + with pytest.raises(DeserializationError, match="Value field not found"): + adapter.load_dict(data=data) + + def test_load_dict_invalid_value_type_raises_error(self, adapter: BasicSerializationAdapter): + """Test load_dict raises error when value is not string or dict.""" + data = { + "created_at": FIXED_DATETIME_ONE_ISOFORMAT, + "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT, + "value": 12345, + } + with pytest.raises(DeserializationError, match="Value field is not a string or dictionary"): + adapter.load_dict(data=data) + + +class TestKeyMustBe: + def test_key_missing(self): + """Test key_must_be returns None when key is missing.""" + result = key_must_be({"other": "value"}, key="missing", expected_type=str) + assert result is None + + def test_key_wrong_type(self): + """Test key_must_be raises TypeError when type is wrong.""" + with pytest.raises(TypeError, match="created_at must be a str"): + key_must_be({"created_at": 12345}, key="created_at", expected_type=str) + + def test_key_correct_type(self): + """Test key_must_be returns value when type is correct.""" + result = key_must_be({"created_at": "2025-01-01"}, key="created_at", expected_type=str) + assert result == "2025-01-01" + + +class TestParseDatetimeStr: + def test_valid_datetime(self): + """Test parse_datetime_str with valid datetime string.""" + result = parse_datetime_str("2025-01-01T00:00:00+00:00") + assert result == FIXED_DATETIME_ONE + + def test_invalid_datetime(self): + """Test parse_datetime_str raises error for invalid string.""" + with pytest.raises(DeserializationError, match="Invalid datetime string"): + parse_datetime_str("not-a-datetime") diff --git a/key-value/key-value-shared/tests/utils/test_time_to_live.py b/key-value/key-value-shared/tests/utils/test_time_to_live.py index 0f69ee9e..5d6aea28 100644 --- a/key-value/key-value-shared/tests/utils/test_time_to_live.py +++ b/key-value/key-value-shared/tests/utils/test_time_to_live.py @@ -5,7 +5,7 @@ import pytest from key_value.shared.errors.key_value import InvalidTTLError -from key_value.shared.utils.time_to_live import prepare_ttl +from key_value.shared.utils.time_to_live import prepare_entry_timestamps, prepare_ttl, try_parse_datetime_str FIXED_DATETIME = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -58,3 +58,49 @@ def test_prepare_ttl(t: Any, expected: int | float | None): def test_prepare_ttl_invalid(t: Any): with pytest.raises(InvalidTTLError): prepare_ttl(t) + + +def test_prepare_ttl_zero(): + """Test that zero TTL raises InvalidTTLError.""" + with pytest.raises(InvalidTTLError): + prepare_ttl(0) + + +def test_prepare_ttl_negative(): + """Test that negative TTL raises InvalidTTLError.""" + with pytest.raises(InvalidTTLError): + prepare_ttl(-100) + + +class TestTryParseDatetimeStr: + def test_valid_datetime_string(self): + """Test parsing valid datetime string.""" + result = try_parse_datetime_str("2025-01-01T00:00:00+00:00") + assert result == FIXED_DATETIME + + def test_invalid_datetime_string(self): + """Test parsing invalid datetime string returns None.""" + result = try_parse_datetime_str("not-a-datetime") + assert result is None + + def test_non_string_returns_none(self): + """Test non-string values return None.""" + assert try_parse_datetime_str(12345) is None + assert try_parse_datetime_str(None) is None + assert try_parse_datetime_str({"key": "value"}) is None + + +class TestPrepareEntryTimestamps: + def test_with_ttl(self): + """Test prepare_entry_timestamps with a TTL.""" + created_at, ttl, expires_at = prepare_entry_timestamps(ttl=100) + assert ttl == 100.0 + assert expires_at is not None + assert expires_at > created_at + + def test_without_ttl(self): + """Test prepare_entry_timestamps without a TTL.""" + created_at, ttl, expires_at = prepare_entry_timestamps(ttl=None) + assert ttl is None + assert expires_at is None + assert created_at is not None