Skip to content

Commit 144b9aa

Browse files
committed
feat: add config args to duckdb client
1 parent b1bd12d commit 144b9aa

File tree

4 files changed

+61
-61
lines changed

4 files changed

+61
-61
lines changed

Cargo.lock

Lines changed: 37 additions & 58 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/duckdb.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::{
55
types::{PyDict, PyList},
66
};
77
use stac_api::python::{StringOrDict, StringOrList};
8-
use stac_duckdb::Client;
8+
use stac_duckdb::{Client, Config};
99
use std::sync::Mutex;
1010

1111
#[pyclass(frozen)]
@@ -14,8 +14,13 @@ pub struct DuckdbClient(Mutex<Client>);
1414
#[pymethods]
1515
impl DuckdbClient {
1616
#[new]
17-
fn new() -> Result<DuckdbClient> {
18-
let client = Client::new()?;
17+
#[pyo3(signature = (use_s3_credential_chain=true, use_hive_partitioning=false))]
18+
fn new(use_s3_credential_chain: bool, use_hive_partitioning: bool) -> Result<DuckdbClient> {
19+
let config = Config {
20+
use_s3_credential_chain,
21+
use_hive_partitioning,
22+
};
23+
let client = Client::with_config(config)?;
1924
Ok(DuckdbClient(Mutex::new(client)))
2025
}
2126

stacrs.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@ from typing import Any, Optional, Tuple
33
class DuckdbClient:
44
"""A client for querying stac-geoparquet with DuckDB."""
55

6+
def __init__(
7+
self, use_s3_credential_chain: bool = True, use_hive_partitioning: bool = False
8+
) -> None:
9+
"""Creates a new duckdb client.
10+
11+
Args:
12+
use_s3_credential_chain: If true, configures DuckDB to correctly
13+
handle s3:// urls.
14+
use_hive_partitioning: If true, enables queries on hive partitioned
15+
geoparquet files.
16+
"""
17+
618
def search(
719
self,
820
href: str,

tests/test_duckdb.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ def test_search_offset(client: DuckdbClient) -> None:
3333
def test_get_collections(client: DuckdbClient) -> None:
3434
collections = client.get_collections("data/100-sentinel-2-items.parquet")
3535
assert len(collections) == 1
36+
37+
38+
def test_init_with_config() -> None:
39+
DuckdbClient(use_s3_credential_chain=True, use_hive_partitioning=True)

0 commit comments

Comments
 (0)