77 once_cell:: unsync:: OnceCell ,
88 pyo3:: {
99 exceptions:: PyAssertionError ,
10- types:: { PyBytes , PyMapping , PyModule } ,
10+ types:: { PyBytes , PyList , PyMapping , PyModule } ,
1111 Py , PyAny , PyErr , PyObject , PyResult , Python , ToPyObject ,
1212 } ,
1313 spin_sdk:: {
1414 config,
1515 http:: { Request , Response } ,
1616 key_value, outbound_http,
1717 redis:: { self , RedisParameter , RedisResult } ,
18+ sqlite,
1819 } ,
1920 std:: { collections:: HashMap , env, ops:: Deref , str, sync:: Arc } ,
2021} ;
@@ -125,6 +126,80 @@ impl Store {
125126 }
126127}
127128
129+ #[ derive( Clone ) ]
130+ #[ pyo3:: pyclass]
131+ #[ pyo3( name = "SqliteConnection" ) ]
132+ struct SqliteConnection {
133+ inner : Arc < sqlite:: Connection > ,
134+ }
135+
136+ #[ pyo3:: pymethods]
137+ impl SqliteConnection {
138+ fn execute (
139+ & self ,
140+ _py : Python < ' _ > ,
141+ query : String ,
142+ parameters : Vec < & PyAny > ,
143+ ) -> PyResult < QueryResult > {
144+ let parameters = parameters
145+ . iter ( )
146+ . map ( |v| {
147+ if let Ok ( v) = v. extract :: < i64 > ( ) {
148+ Ok ( sqlite:: ValueParam :: Integer ( v) )
149+ } else if let Ok ( v) = v. extract :: < f64 > ( ) {
150+ Ok ( sqlite:: ValueParam :: Real ( v) )
151+ } else if let Ok ( v) = v. extract :: < & str > ( ) {
152+ Ok ( sqlite:: ValueParam :: Text ( v) )
153+ } else if v. is_none ( ) {
154+ Ok ( sqlite:: ValueParam :: Null )
155+ } else if let Ok ( v) = v. downcast :: < PyBytes > ( ) {
156+ Ok ( sqlite:: ValueParam :: Blob ( v. as_bytes ( ) ) )
157+ } else {
158+ Err ( PyErr :: from ( Anyhow ( anyhow ! (
159+ "Unable to use {v:?} as a SQLite `execute` parameter \
160+ -- expected `int`, `float`, `bytes`, `string`, or `None`"
161+ ) ) ) )
162+ }
163+ } )
164+ . collect :: < PyResult < Vec < _ > > > ( ) ?;
165+ let result = self
166+ . inner
167+ . execute ( & query, & parameters)
168+ . map_err ( Anyhow :: from) ?;
169+ Ok ( QueryResult { inner : result } )
170+ }
171+ }
172+
173+ #[ derive( Clone ) ]
174+ #[ pyo3:: pyclass]
175+ #[ pyo3( name = "QueryResult" ) ]
176+ struct QueryResult {
177+ inner : sqlite:: QueryResult ,
178+ }
179+
180+ #[ pyo3:: pymethods]
181+ impl QueryResult {
182+ fn rows ( & self , py : Python < ' _ > ) -> PyResult < PyObject > {
183+ let rows = self . inner . rows . iter ( ) . map ( |r| {
184+ PyList :: new (
185+ py,
186+ r. values . iter ( ) . map ( |v| match v {
187+ sqlite:: ValueResult :: Integer ( i) => i. to_object ( py) ,
188+ sqlite:: ValueResult :: Real ( r) => r. to_object ( py) ,
189+ sqlite:: ValueResult :: Text ( s) => s. to_object ( py) ,
190+ sqlite:: ValueResult :: Blob ( b) => b. to_object ( py) ,
191+ sqlite:: ValueResult :: Null => py. None ( ) ,
192+ } ) ,
193+ )
194+ } ) ;
195+ Ok ( PyList :: new ( py, rows) . into ( ) )
196+ }
197+
198+ fn columns ( & self , py : Python < ' _ > ) -> PyResult < PyObject > {
199+ Ok ( PyList :: new ( py, self . inner . columns . iter ( ) ) . into ( ) )
200+ }
201+ }
202+
128203struct Anyhow ( Error ) ;
129204
130205impl From < Anyhow > for PyErr {
@@ -325,6 +400,27 @@ fn spin_key_value_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
325400 module. add_function ( pyo3:: wrap_pyfunction!( kv_open_default, module) ?)
326401}
327402
403+ #[ pyo3:: pyfunction]
404+ fn sqlite_open ( database : String ) -> Result < SqliteConnection , Anyhow > {
405+ Ok ( SqliteConnection {
406+ inner : Arc :: new ( sqlite:: Connection :: open ( & database) . map_err ( Anyhow :: from) ?) ,
407+ } )
408+ }
409+
410+ #[ pyo3:: pyfunction]
411+ fn sqlite_open_default ( ) -> Result < SqliteConnection , Anyhow > {
412+ Ok ( SqliteConnection {
413+ inner : Arc :: new ( sqlite:: Connection :: open_default ( ) . map_err ( Anyhow :: from) ?) ,
414+ } )
415+ }
416+
417+ #[ pyo3:: pymodule]
418+ #[ pyo3( name = "spin_sqlite" ) ]
419+ fn spin_sqlite_module ( _py : Python < ' _ > , module : & PyModule ) -> PyResult < ( ) > {
420+ module. add_function ( pyo3:: wrap_pyfunction!( sqlite_open, module) ?) ?;
421+ module. add_function ( pyo3:: wrap_pyfunction!( sqlite_open_default, module) ?)
422+ }
423+
328424#[ pyo3:: pyfunction]
329425fn config_get ( key : String ) -> Result < String , Anyhow > {
330426 config:: get ( & key) . map_err ( Anyhow :: from)
@@ -341,6 +437,7 @@ fn do_init() -> Result<()> {
341437 pyo3:: append_to_inittab!( spin_redis_module) ;
342438 pyo3:: append_to_inittab!( spin_config_module) ;
343439 pyo3:: append_to_inittab!( spin_key_value_module) ;
440+ pyo3:: append_to_inittab!( spin_sqlite_module) ;
344441
345442 pyo3:: prepare_freethreaded_python ( ) ;
346443
0 commit comments