11import asyncio
22import sqlite3
3- import threading
43from collections .abc import Callable , Sequence
54from datetime import datetime
65from pathlib import Path
7- from typing import TYPE_CHECKING , TypeVar
6+ from typing import TYPE_CHECKING , Any , TypeAlias , TypeVar
87from zoneinfo import ZoneInfo
98
109from typing_extensions import override
1110
12- from jobify ._internal .common .constants import JobStatus
11+ from jobify ._internal .common .constants import EMPTY , JobStatus
1312from jobify ._internal .storage .base import (
1413 ScheduledJob ,
1514 Storage ,
2120
2221 from jobify ._internal .common .types import LoopFactory
2322
23+
2424CREATE_SCHEDULED_TABLE_QUERY = """
2525CREATE TABLE IF NOT EXISTS {} (
2626 job_id TEXT PRIMARY KEY,
5959
6060ReturnT = TypeVar ("ReturnT" )
6161
62+ _Callback = Callable [[sqlite3 .Connection ], ReturnT ]
63+ _AsyncQueue : TypeAlias = asyncio .Queue [
64+ tuple [_Callback [Any ], asyncio .Future [Any ]],
65+ ]
66+ _STOP : Any = object ()
67+
6268
6369class SQLiteStorage (Storage ):
6470 def __init__ (
@@ -67,6 +73,7 @@ def __init__(
6773 * ,
6874 table_name : str = "jobify_schedules" ,
6975 timeout : float = 20.0 ,
76+ max_queue_size : int = 1024 ,
7077 ) -> None :
7178 validate_table_name (table_name )
7279 self .database : Path = (
@@ -77,8 +84,10 @@ def __init__(
7784 self .tz : ZoneInfo = ZoneInfo ("UTC" )
7885 self .getloop : LoopFactory = asyncio ._get_running_loop
7986 self .threadpool : ThreadPoolExecutor | None = None
80- self ._conn : sqlite3 .Connection | None = None
81- self ._lock : threading .Lock = threading .Lock ()
87+ self .max_queue_size : int = max_queue_size
88+
89+ self ._queue : _AsyncQueue = EMPTY
90+ self ._worker_task : asyncio .Task [None ] | None = None
8291
8392 self .create_scheduled_table_query : str = (
8493 CREATE_SCHEDULED_TABLE_QUERY .format (table_name )
@@ -93,20 +102,36 @@ def __init__(
93102 table_name ,
94103 )
95104
96- @property
97- def conn (self ) -> sqlite3 .Connection :
98- if self ._conn is None :
99- msg = "Database not initialized. Call startup() first."
100- raise RuntimeError (msg )
101- return self ._conn
105+ async def _worker (self , conn : sqlite3 .Connection ) -> None :
106+ loop = self .getloop ()
107+ while True :
108+ item = await self ._queue .get ()
109+
110+ if item is _STOP :
111+ self ._queue .task_done ()
112+ break
113+
114+ callback , future = item
115+ try :
116+ result = await loop .run_in_executor (
117+ self .threadpool ,
118+ callback ,
119+ conn ,
120+ )
121+ except Exception as exc : # noqa: BLE001
122+ future .set_exception (exc )
123+ else :
124+ future .set_result (result )
125+ finally :
126+ self ._queue .task_done ()
102127
103- async def _to_thread (self , func : Callable [[], ReturnT ]) -> ReturnT :
104- def thread_safe () -> ReturnT :
105- with self ._lock :
106- return func ()
128+ conn .close ()
107129
130+ async def _execute (self , callback : _Callback [ReturnT ]) -> ReturnT :
108131 loop = self .getloop ()
109- return await loop .run_in_executor (self .threadpool , thread_safe )
132+ future : asyncio .Future [ReturnT ] = loop .create_future ()
133+ await self ._queue .put ((callback , future ))
134+ return await future
110135
111136 @override
112137 async def startup (self ) -> None :
@@ -119,19 +144,24 @@ async def startup(self) -> None:
119144 _ = conn .execute ("PRAGMA synchronous=NORMAL;" )
120145 _ = conn .execute (self .create_scheduled_table_query )
121146 conn .commit ()
122- self ._conn = conn
147+ self ._queue = asyncio .Queue (self .max_queue_size )
148+ self ._worker_task = asyncio .create_task (self ._worker (conn ))
123149
124150 @override
125151 async def shutdown (self ) -> None :
126- if self ._conn is not None :
127- with self ._lock :
128- self ._conn .close ()
129- self ._conn = None
152+ if self ._queue is not EMPTY :
153+ await self ._queue .put (_STOP )
154+ await self ._queue .join ()
155+ self ._queue = EMPTY
156+
157+ if self ._worker_task is not None :
158+ await self ._worker_task
159+ self ._worker_task = None
130160
131161 @override
132162 async def get_schedules (self ) -> list [ScheduledJob ]:
133- def get () -> list [ScheduledJob ]:
134- cursor = self . conn .execute (self .select_schedules_query )
163+ def get (conn : sqlite3 . Connection ) -> list [ScheduledJob ]:
164+ cursor = conn .execute (self .select_schedules_query )
135165 return [
136166 ScheduledJob (
137167 job_id = row [0 ],
@@ -145,35 +175,35 @@ def get() -> list[ScheduledJob]:
145175 for row in cursor .fetchall ()
146176 ]
147177
148- return await self ._to_thread (get )
178+ return await self ._execute (get )
149179
150180 @override
151181 async def add_schedule (self , * scheduled : ScheduledJob ) -> None :
152- def insert_many () -> None :
153- with self . conn as conn :
154- _ = conn . executemany (
155- self . insert_schedule_query ,
156- [
157- (
158- sch .job_id ,
159- sch .name ,
160- sch .message ,
161- sch .status ,
162- sch . next_run_at . isoformat (),
163- )
164- for sch in scheduled
165- ],
166- )
182+ def insert_many (conn : sqlite3 . Connection ) -> None :
183+ _ = conn . executemany (
184+ self . insert_schedule_query ,
185+ [
186+ (
187+ sch . job_id ,
188+ sch .name ,
189+ sch .message ,
190+ sch .status ,
191+ sch .next_run_at . isoformat () ,
192+ )
193+ for sch in scheduled
194+ ],
195+ )
196+ conn . commit ( )
167197
168- return await self ._to_thread (insert_many )
198+ return await self ._execute (insert_many )
169199
170200 @override
171201 async def delete_schedule (self , job_id : str ) -> None :
172- def delete () -> None :
173- with self .conn as conn :
174- _ = conn .execute ( self . delete_schedule_query , ( job_id ,) )
202+ def delete (conn : sqlite3 . Connection ) -> None :
203+ _ = conn . execute ( self .delete_schedule_query , ( job_id ,))
204+ conn .commit ( )
175205
176- return await self ._to_thread (delete )
206+ return await self ._execute (delete )
177207
178208 @override
179209 async def delete_schedule_many (self , job_ids : Sequence [str ]) -> None :
@@ -182,17 +212,17 @@ async def delete_schedule_many(self, job_ids: Sequence[str]) -> None:
182212
183213 job_ids = sorted (job_ids )
184214
185- def delete_many () -> None :
215+ def delete_many (conn : sqlite3 . Connection ) -> None :
186216 batch_size = 500
187- with self . conn as conn :
188- for i in range ( 0 , len ( job_ids ), batch_size ):
189- batch = job_ids [ i : i + batch_size ]
190- query = DELETE_SCHEDULE_MANY_QUERY . format_map (
191- {
192- "table_name " : self . table_name ,
193- "placeholder" : "," . join ( "?" * len ( batch )),
194- }
195- )
196- _ = conn .execute ( query , batch )
217+ for i in range ( 0 , len ( job_ids ), batch_size ) :
218+ batch = job_ids [ i : i + batch_size ]
219+ query = DELETE_SCHEDULE_MANY_QUERY . format_map (
220+ {
221+ "table_name" : self . table_name ,
222+ "placeholder " : "," . join ( "?" * len ( batch )) ,
223+ }
224+ )
225+ _ = conn . execute ( query , batch )
226+ conn .commit ( )
197227
198- return await self ._to_thread (delete_many )
228+ return await self ._execute (delete_many )
0 commit comments