@@ -2,46 +2,59 @@ use crate::{
22 search:: { PySortby , StringOrDict , StringOrList } ,
33 Result ,
44} ;
5+ use duckdb:: Connection ;
56use pyo3:: {
67 exceptions:: PyException ,
78 prelude:: * ,
89 types:: { PyDict , PyList } ,
910 IntoPyObjectExt ,
1011} ;
1112use pyo3_arrow:: PyTable ;
12- use stac_duckdb:: { Client , Config } ;
13- use std:: sync:: Mutex ;
13+ use stac_duckdb:: Client ;
14+ use std:: { path :: PathBuf , sync:: Mutex } ;
1415
1516#[ pyclass( frozen) ]
1617pub struct DuckdbClient ( Mutex < Client > ) ;
1718
1819#[ pymethods]
1920impl DuckdbClient {
2021 #[ new]
21- #[ pyo3( signature = ( * , use_s3_credential_chain= false , use_azure_credential_chain= false , use_httpfs= false , use_hive_partitioning=false , install_extensions= true , custom_extension_repository= None , extension_directory= None ) ) ]
22+ #[ pyo3( signature = ( * , extension_directory= None , extensions= Vec :: new ( ) , install_spatial= true , use_hive_partitioning=false ) ) ]
2223 fn new (
23- use_s3_credential_chain : bool ,
24- use_azure_credential_chain : bool ,
25- use_httpfs : bool ,
24+ extension_directory : Option < PathBuf > ,
25+ extensions : Vec < String > ,
26+ install_spatial : bool ,
2627 use_hive_partitioning : bool ,
27- install_extensions : bool ,
28- custom_extension_repository : Option < String > ,
29- extension_directory : Option < String > ,
3028 ) -> Result < DuckdbClient > {
31- let config = Config {
32- use_s3_credential_chain,
33- use_azure_credential_chain,
34- use_httpfs,
35- use_hive_partitioning,
36- install_extensions,
37- custom_extension_repository,
38- extension_directory,
39- convert_wkb : true ,
40- } ;
41- let client = Client :: with_config ( config) ?;
29+ let connection = Connection :: open_in_memory ( ) ?;
30+ if let Some ( extension_directory) = extension_directory {
31+ connection. execute (
32+ "SET extension_directory = ?" ,
33+ [ extension_directory. to_string_lossy ( ) ] ,
34+ ) ?;
35+ }
36+ if install_spatial {
37+ connection. execute ( "INSTALL spatial" , [ ] ) ?;
38+ }
39+ for extension in extensions {
40+ connection. execute ( & format ! ( "LOAD '{}'" , extension) , [ ] ) ?;
41+ }
42+ connection. execute ( "LOAD spatial" , [ ] ) ?;
43+ let mut client = Client :: from ( connection) ;
44+ client. use_hive_partitioning = use_hive_partitioning;
4245 Ok ( DuckdbClient ( Mutex :: new ( client) ) )
4346 }
4447
48+ #[ pyo3( signature = ( sql, params = Vec :: new( ) ) ) ]
49+ fn execute < ' py > ( & self , sql : String , params : Vec < String > ) -> Result < usize > {
50+ let client = self
51+ . 0
52+ . lock ( )
53+ . map_err ( |err| PyException :: new_err ( err. to_string ( ) ) ) ?;
54+ let count = client. execute ( & sql, duckdb:: params_from_iter ( params) ) ?;
55+ Ok ( count)
56+ }
57+
4558 #[ pyo3( signature = ( href, * , intersects=None , ids=None , collections=None , limit=None , bbox=None , datetime=None , include=None , exclude=None , sortby=None , filter=None , query=None , * * kwargs) ) ]
4659 fn search < ' py > (
4760 & self ,
@@ -123,10 +136,11 @@ impl DuckdbClient {
123136 . 0
124137 . lock ( )
125138 . map_err ( |err| PyException :: new_err ( err. to_string ( ) ) ) ?;
126- let convert_wkb = client. config . convert_wkb ;
127- client. config . convert_wkb = false ;
139+ // FIXME this is awkward
140+ let convert_wkb = client. convert_wkb ;
141+ client. convert_wkb = false ;
128142 let result = client. search_to_arrow ( & href, search) ;
129- client. config . convert_wkb = convert_wkb;
143+ client. convert_wkb = convert_wkb;
130144 result?
131145 } ;
132146 if record_batches. is_empty ( ) {
0 commit comments