Skip to content

Commit 17add25

Browse files
strawgategithub-actions[bot]claude
authored
Implement client ownership split for store lifecycle management (#245)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: William Easton <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 3ab83ea commit 17add25

File tree

34 files changed

+480
-396
lines changed

34 files changed

+480
-396
lines changed

key-value/key-value-aio/src/key_value/aio/stores/base.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from asyncio.locks import Lock
77
from collections import defaultdict
88
from collections.abc import Mapping, Sequence
9+
from contextlib import AsyncExitStack
910
from datetime import datetime
1011
from types import MappingProxyType, TracebackType
1112
from typing import Any, SupportsFloat
@@ -85,6 +86,7 @@ def __init__(
8586
collection_sanitization_strategy: SanitizationStrategy | None = None,
8687
default_collection: str | None = None,
8788
seed: SEED_DATA_TYPE | None = None,
89+
stable_api: bool = False,
8890
) -> None:
8991
"""Initialize the managed key-value store.
9092
@@ -97,6 +99,8 @@ def __init__(
9799
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
98100
Seeding occurs once during store initialization (when the store is first entered or when the
99101
first operation is performed on the store).
102+
stable_api: Whether this store implementation has a stable API. If False, a warning will be issued.
103+
Defaults to False.
100104
"""
101105

102106
self._setup_complete = False
@@ -113,8 +117,7 @@ def __init__(
113117
self._key_sanitization_strategy = key_sanitization_strategy or PassthroughStrategy()
114118
self._collection_sanitization_strategy = collection_sanitization_strategy or PassthroughStrategy()
115119

116-
if not hasattr(self, "_stable_api"):
117-
self._stable_api = False
120+
self._stable_api = stable_api
118121

119122
if not self._stable_api:
120123
self._warn_about_stability()
@@ -425,24 +428,74 @@ async def _get_collection_keys(self, *, collection: str, limit: int | None = Non
425428

426429

427430
class BaseContextManagerStore(BaseStore, ABC):
428-
"""An abstract base class for context manager stores."""
431+
"""An abstract base class for context manager stores.
432+
433+
Stores that accept a client parameter should pass `client_provided_by_user=True` to
434+
the constructor. This ensures the store does not manage the lifecycle of user-provided
435+
clients (i.e., does not close them).
436+
437+
The base class provides an AsyncExitStack that stores can use to register cleanup
438+
callbacks. Stores should add their cleanup operations to the exit stack as needed.
439+
The base class handles entering and exiting the exit stack.
440+
"""
441+
442+
_client_provided_by_user: bool
443+
_exit_stack: AsyncExitStack
444+
_exit_stack_entered: bool
445+
446+
def __init__(self, *, client_provided_by_user: bool = False, **kwargs: Any) -> None:
447+
"""Initialize the context manager store with client ownership configuration.
448+
449+
Args:
450+
client_provided_by_user: Whether the client was provided by the user. If True,
451+
the store will not manage the client's lifecycle (will not close it).
452+
Defaults to False.
453+
**kwargs: Additional arguments to pass to the base store constructor.
454+
"""
455+
self._client_provided_by_user = client_provided_by_user
456+
self._exit_stack = AsyncExitStack()
457+
self._exit_stack_entered = False
458+
super().__init__(**kwargs)
459+
460+
async def _ensure_exit_stack_entered(self) -> None:
461+
"""Ensure the exit stack has been entered."""
462+
if not self._exit_stack_entered:
463+
await self._exit_stack.__aenter__()
464+
self._exit_stack_entered = True
429465

430466
async def __aenter__(self) -> Self:
467+
# Enter the exit stack
468+
await self._ensure_exit_stack_entered()
431469
await self.setup()
432470
return self
433471

434472
async def __aexit__(
435473
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
436-
) -> None:
437-
await self._close()
474+
) -> bool | None:
475+
# Close the exit stack, which handles all cleanup
476+
if self._exit_stack_entered:
477+
result = await self._exit_stack.__aexit__(exc_type, exc_value, traceback)
478+
self._exit_stack_entered = False
479+
480+
return result
481+
return None
438482

439483
async def close(self) -> None:
440-
await self._close()
484+
# Close the exit stack if it has been entered
485+
if self._exit_stack_entered:
486+
await self._exit_stack.aclose()
487+
self._exit_stack_entered = False
441488

442-
@abstractmethod
443-
async def _close(self) -> None:
444-
"""Close the store."""
445-
...
489+
async def setup(self) -> None:
490+
"""Initialize the store if not already initialized.
491+
492+
This override ensures the exit stack is entered before the store's _setup()
493+
method is called, allowing stores to register cleanup callbacks during setup.
494+
"""
495+
# Ensure exit stack is entered
496+
await self._ensure_exit_stack_entered()
497+
# Call parent setup
498+
await super().setup()
446499

447500

448501
class BaseEnumerateCollectionsStore(BaseStore, AsyncEnumerateCollectionsProtocol, ABC):

key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,17 @@ def default_disk_cache_factory(collection: str) -> Cache:
9999

100100
self._cache = {}
101101

102-
self._stable_api = True
103102
self._serialization_adapter = BasicSerializationAdapter()
104103

105-
super().__init__(default_collection=default_collection)
104+
super().__init__(
105+
default_collection=default_collection,
106+
stable_api=True,
107+
)
108+
109+
@override
110+
async def _setup(self) -> None:
111+
"""Register cache cleanup."""
112+
self._exit_stack.callback(self._sync_close)
106113

107114
@override
108115
async def _setup_collection(self, *, collection: str) -> None:
@@ -146,9 +153,5 @@ def _sync_close(self) -> None:
146153
for cache in self._cache.values():
147154
cache.close()
148155

149-
@override
150-
async def _close(self) -> None:
151-
self._sync_close()
152-
153156
def __del__(self) -> None:
154157
self._sync_close()

key-value/key-value-aio/src/key_value/aio/stores/disk/store.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def __init__(
5050
"""Initialize the disk store.
5151
5252
Args:
53-
disk_cache: An existing diskcache Cache instance to use.
53+
disk_cache: An existing diskcache Cache instance to use. If provided, the store will
54+
not manage the cache's lifecycle (will not close it). The caller is responsible
55+
for managing the cache's lifecycle.
5456
directory: The directory to use for the disk store.
5557
max_size: The maximum size of the disk store.
5658
default_collection: The default collection to use if no collection is provided.
@@ -63,6 +65,9 @@ def __init__(
6365
msg = "Either disk_cache or directory must be provided"
6466
raise ValueError(msg)
6567

68+
client_provided = disk_cache is not None
69+
self._client_provided_by_user = client_provided
70+
6671
if disk_cache:
6772
self._cache = disk_cache
6873
elif directory:
@@ -75,9 +80,17 @@ def __init__(
7580
else:
7681
self._cache = Cache(directory=directory, eviction_policy="none")
7782

78-
self._stable_api = True
83+
super().__init__(
84+
default_collection=default_collection,
85+
client_provided_by_user=client_provided,
86+
stable_api=True,
87+
)
7988

80-
super().__init__(default_collection=default_collection)
89+
@override
90+
async def _setup(self) -> None:
91+
"""Register cache cleanup if we own the cache."""
92+
if not self._client_provided_by_user:
93+
self._exit_stack.callback(self._cache.close)
8194

8295
@override
8396
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
@@ -119,9 +132,6 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
119132

120133
return self._cache.delete(key=combo_key, retry=True)
121134

122-
@override
123-
async def _close(self) -> None:
124-
self._cache.close()
125-
126135
def __del__(self) -> None:
127-
self._cache.close()
136+
if not getattr(self, "_client_provided_by_user", False) and hasattr(self, "_cache"):
137+
self._cache.close()

key-value/key-value-aio/src/key_value/aio/stores/duckdb/store.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,9 @@ class DuckDBStore(BaseContextManagerStore, BaseStore):
7171
7272
Values are stored in a JSON column as native dicts, allowing direct SQL queries
7373
on the stored data for analytics and reporting.
74-
75-
Note on connection ownership: When you provide an existing connection, the store
76-
will take ownership and close it when the store is closed or garbage collected.
77-
If you need to reuse a connection, create separate DuckDB connections for each store.
7874
"""
7975

8076
_connection: duckdb.DuckDBPyConnection
81-
_is_closed: bool
82-
_owns_connection: bool
8377
_adapter: SerializationAdapter
8478
_table_name: str
8579

@@ -94,9 +88,8 @@ def __init__(
9488
) -> None:
9589
"""Initialize the DuckDB store with an existing connection.
9690
97-
Warning: The store will take ownership of the connection and close it
98-
when the store is closed or garbage collected. If you need to reuse
99-
a connection, create separate DuckDB connections for each store.
91+
Note: If you provide a connection, the store will NOT manage its lifecycle (will not
92+
close it). The caller is responsible for managing the connection's lifecycle.
10093
10194
Args:
10295
connection: An existing DuckDB connection to use.
@@ -135,7 +128,9 @@ def __init__(
135128
"""Initialize the DuckDB store.
136129
137130
Args:
138-
connection: An existing DuckDB connection to use.
131+
connection: An existing DuckDB connection to use. If provided, the store will NOT
132+
manage its lifecycle (will not close it). The caller is responsible for managing
133+
the connection's lifecycle.
139134
database_path: Path to the database file. If None or ':memory:', uses in-memory database.
140135
table_name: Name of the table to store key-value entries. Defaults to "kv_entries".
141136
default_collection: The default collection to use if no collection is provided.
@@ -145,9 +140,10 @@ def __init__(
145140
msg = "Provide only one of connection or database_path"
146141
raise ValueError(msg)
147142

143+
client_provided = connection is not None
144+
148145
if connection is not None:
149146
self._connection = connection
150-
self._owns_connection = True # We take ownership even of provided connections
151147
else:
152148
# Convert Path to string if needed
153149
if isinstance(database_path, Path):
@@ -158,19 +154,21 @@ def __init__(
158154
self._connection = duckdb.connect(":memory:")
159155
else:
160156
self._connection = duckdb.connect(database=database_path)
161-
self._owns_connection = True
162157

163-
self._is_closed = False
164158
self._adapter = DuckDBSerializationAdapter()
165159

166160
# Validate table name to prevent SQL injection
167161
if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table_name):
168162
msg = "Table name must start with a letter or underscore and contain only letters, digits, or underscores"
169163
raise ValueError(msg)
170164
self._table_name = table_name
171-
self._stable_api = False
172165

173-
super().__init__(default_collection=default_collection, seed=seed)
166+
super().__init__(
167+
default_collection=default_collection,
168+
seed=seed,
169+
client_provided_by_user=client_provided,
170+
stable_api=False,
171+
)
174172

175173
def _get_create_table_sql(self) -> str:
176174
"""Generate SQL for creating the key-value entries table.
@@ -263,6 +261,10 @@ async def _setup(self) -> None:
263261
- Metadata queries without JSON deserialization
264262
- Native JSON column support for rich querying capabilities
265263
"""
264+
# Register connection cleanup if we own the connection
265+
if not self._client_provided_by_user:
266+
self._exit_stack.callback(self._connection.close)
267+
266268
# Create the main table for storing key-value entries
267269
self._connection.execute(self._get_create_table_sql())
268270

@@ -279,10 +281,6 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
279281
Reconstructs the ManagedEntry from value column and metadata columns
280282
using the serialization adapter.
281283
"""
282-
if self._is_closed:
283-
msg = "Cannot operate on closed DuckDBStore"
284-
raise RuntimeError(msg)
285-
286284
result = self._connection.execute(
287285
self._get_select_sql(),
288286
[collection, key],
@@ -323,10 +321,6 @@ async def _put_managed_entry(
323321
Uses the serialization adapter to convert the ManagedEntry to the
324322
appropriate storage format.
325323
"""
326-
if self._is_closed:
327-
msg = "Cannot operate on closed DuckDBStore"
328-
raise RuntimeError(msg)
329-
330324
# Ensure that the value is serializable to JSON
331325
_ = managed_entry.value_as_json
332326

@@ -349,10 +343,6 @@ async def _put_managed_entry(
349343
@override
350344
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
351345
"""Delete a managed entry by key from the specified collection."""
352-
if self._is_closed:
353-
msg = "Cannot operate on closed DuckDBStore"
354-
raise RuntimeError(msg)
355-
356346
result = self._connection.execute(
357347
self._get_delete_sql(),
358348
[collection, key],
@@ -361,20 +351,3 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
361351
# Check if any rows were deleted by counting returned rows
362352
deleted_rows = result.fetchall()
363353
return len(deleted_rows) > 0
364-
365-
@override
366-
async def _close(self) -> None:
367-
"""Close the DuckDB connection."""
368-
if not self._is_closed and self._owns_connection:
369-
self._connection.close()
370-
self._is_closed = True
371-
372-
def __del__(self) -> None:
373-
"""Clean up the DuckDB connection on deletion."""
374-
try:
375-
if not self._is_closed and self._owns_connection and hasattr(self, "_connection"):
376-
self._connection.close()
377-
self._is_closed = True
378-
except Exception: # noqa: S110
379-
# Suppress errors during cleanup to avoid issues during interpreter shutdown
380-
pass

0 commit comments

Comments
 (0)