@@ -86,6 +86,21 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub struct Client {
    /// The underlying DuckDB connection (opened in-memory by `with_config`).
    connection: Connection,
    /// Configuration supplied at construction; consulted when building
    /// `read_parquet(...)` fragments (see `read_parquet_str`).
    config: Config,
}
91+
/// Configuration for a client.
///
/// Both fields are plain flags, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Config {
    /// Whether to enable the s3 credential chain, which allows s3:// url access.
    ///
    /// True by default.
    pub use_s3_credential_chain: bool,

    /// Whether to enable hive partitioning.
    ///
    /// False by default.
    pub use_hive_partitioning: bool,
}
90105
91106/// A SQL query.
@@ -109,43 +124,68 @@ impl Client {
109124 /// let client = Client::new().unwrap();
110125 /// ```
111126 pub fn new ( ) -> Result < Client > {
127+ Client :: with_config ( Config :: default ( ) )
128+ }
129+
130+ /// Creates a new client with the provided configuration.
131+ ///
132+ /// # Examples
133+ ///
134+ /// ```
135+ /// use stac_duckdb::{Client, Config};
136+ ///
137+ /// let config = Config {
138+ /// use_s3_credential_chain: true,
139+ /// use_hive_partitioning: true,
140+ /// };
141+ /// let client = Client::with_config(config);
142+ /// ```
143+ pub fn with_config ( config : Config ) -> Result < Client > {
112144 let connection = Connection :: open_in_memory ( ) ?;
113145 connection. execute ( "INSTALL spatial" , [ ] ) ?;
114146 connection. execute ( "LOAD spatial" , [ ] ) ?;
115147 connection. execute ( "INSTALL icu" , [ ] ) ?;
116148 connection. execute ( "LOAD icu" , [ ] ) ?;
117- connection. execute ( "CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)" , [ ] ) ?;
118- Ok ( Client { connection } )
149+ if config. use_s3_credential_chain {
150+ connection. execute ( "CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)" , [ ] ) ?;
151+ }
152+ Ok ( Client { connection, config } )
119153 }
120154
121155 /// Returns one or more [stac::Collection] from the items in the stac-geoparquet file.
122156 pub fn collections ( & self , href : & str ) -> Result < Vec < Collection > > {
123157 let start_datetime= if self . connection . prepare ( & format ! (
124- "SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}') ) where column_name = 'start_datetime'" ,
125- href
158+ "SELECT column_name FROM (DESCRIBE SELECT * from {} ) where column_name = 'start_datetime'" ,
159+ self . read_parquet_str ( href)
126160 ) ) ?. query ( [ ] ) ?. next ( ) ?. is_some ( ) {
127161 "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
128162 } else {
129163 "strftime(min(datetime), '%xT%X%z')"
130164 } ;
131- let end_datetime= if self . connection . prepare ( & format ! (
132- "SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}')) where column_name = 'end_datetime'" ,
133- href
134- ) ) ?. query ( [ ] ) ?. next ( ) ?. is_some ( ) {
165+ let end_datetime = if self
166+ . connection
167+ . prepare ( & format ! (
168+ "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'" ,
169+ self . read_parquet_str( href)
170+ ) ) ?
171+ . query ( [ ] ) ?
172+ . next ( ) ?
173+ . is_some ( )
174+ {
135175 "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
136176 } else {
137177 "strftime(max(datetime), '%xT%X%z')"
138178 } ;
139179 let mut statement = self . connection . prepare ( & format ! (
140- "SELECT DISTINCT collection FROM read_parquet('{}') " ,
141- href
180+ "SELECT DISTINCT collection FROM {} " ,
181+ self . read_parquet_str ( href)
142182 ) ) ?;
143183 let mut collections = Vec :: new ( ) ;
144184 for row in statement. query_map ( [ ] , |row| row. get :: < _ , String > ( 0 ) ) ? {
145185 let collection_id = row?;
146186 let mut statement = self . connection . prepare ( &
147- format ! ( "SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM read_parquet('{}') WHERE collection = $1" , start_datetime, end_datetime,
148- href
187+ format ! ( "SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1" , start_datetime, end_datetime,
188+ self . read_parquet_str ( href)
149189 ) ) ?;
150190 let row = statement. query_row ( [ & collection_id] , |row| {
151191 Ok ( (
@@ -235,8 +275,8 @@ impl Client {
235275 let fields = std:: mem:: take ( & mut search. items . fields ) ;
236276
237277 let mut statement = self . connection . prepare ( & format ! (
238- "SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}') )" ,
239- href
278+ "SELECT column_name FROM (DESCRIBE SELECT * from {} )" ,
279+ self . read_parquet_str ( href)
240280 ) ) ?;
241281 let mut columns = Vec :: new ( ) ;
242282 // Can we use SQL magic to make our query not depend on which columns are present?
@@ -354,14 +394,25 @@ impl Client {
354394 }
355395 Ok ( Query {
356396 sql : format ! (
357- "SELECT {} FROM read_parquet('{}') {}" ,
397+ "SELECT {} FROM {} {}" ,
358398 columns. join( "," ) ,
359- href,
399+ self . read_parquet_str ( href) ,
360400 suffix,
361401 ) ,
362402 params,
363403 } )
364404 }
405+
406+ fn read_parquet_str ( & self , href : & str ) -> String {
407+ if self . config . use_hive_partitioning {
408+ format ! (
409+ "read_parquet('{}', filename=true, hive_partitioning=1)" ,
410+ href
411+ )
412+ } else {
413+ format ! ( "read_parquet('{}', filename=true)" , href)
414+ }
415+ }
365416}
366417
367418/// Return this crate's version.
@@ -396,6 +447,15 @@ fn to_geoarrow_record_batch(mut record_batch: RecordBatch) -> Result<RecordBatch
396447 Ok ( record_batch)
397448}
398449
450+ impl Default for Config {
451+ fn default ( ) -> Self {
452+ Config {
453+ use_hive_partitioning : false ,
454+ use_s3_credential_chain : true ,
455+ }
456+ }
457+ }
458+
399459#[ cfg( test) ]
400460mod tests {
401461 use super :: Client ;
0 commit comments