Skip to content

Commit 8186c39

Browse files
authored
refactor: simplify extension handling in duckdb (#105)
cc @ceholden just FYSA since you helped work on this, I'm refactoring to Do Less™ here, but expose the ability to just execute any DuckDB query for custom setup, etc. I also provide a way to pass in extension files in init, which we can hopefully use to point to pre-fetched stuff in our lambda.
1 parent b907143 commit 8186c39

File tree

8 files changed

+117
-57
lines changed

8 files changed

+117
-57
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,5 @@ cython_debug/
164164
#.idea/
165165
.vscode/
166166
docs/generated
167+
168+
tests/duckdb-extensions

Cargo.lock

Lines changed: 13 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ duckdb-bundled = ["stac-duckdb/bundled"]
1313

1414
[dependencies]
1515
clap = "4.5.31"
16+
duckdb = { version = "1.2.2", features = ["serde_json"] }
1617
geoarrow-array = { git = "https://github.com/geoarrow/geoarrow-rs/", rev = "17bf33e4cf78b060afa08ca9560dc4efd73c2c76" }
1718
geojson = "0.24.1"
1819
pyo3 = { version = "0.24.1", features = ["extension-module"] }

python/rustac/rustac.pyi

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The power of Rust for the Python STAC ecosystem."""
22

3+
from pathlib import Path
34
from typing import Any, AsyncIterator, Literal, Optional, Tuple
45

56
import arro3.core
@@ -15,28 +16,28 @@ class DuckdbClient:
1516
def __init__(
1617
self,
1718
*,
18-
use_s3_credential_chain: bool = True,
19-
use_azure_credential_chain: bool = True,
20-
use_httpfs: bool = True,
19+
extension_directory: Path | None = None,
20+
extensions: list[str] | None = None,
21+
install_spatial: bool = True,
2122
use_hive_partitioning: bool = False,
22-
install_extensions: bool = True,
23-
custom_extension_repository: str | None = None,
24-
extension_directory: str | None = None,
2523
) -> None:
2624
"""Creates a new duckdb client.
2725
2826
Args:
29-
use_s3_credential_chain: If true, configures DuckDB to correctly
30-
handle s3:// urls.
31-
use_azure_credential_chain: If true, configures DuckDB to correctly
32-
handle azure urls.
33-
use_httpfs: If true, configures DuckDB to correctly handle https
34-
urls.
35-
use_hive_partitioning: If true, enables queries on hive partitioned
36-
geoparquet files.
37-
install_extensions: If true, installs extensions before loading them.
38-
custom_extension_repository: A custom extension repository to use.
3927
extension_directory: A non-standard extension directory to use.
28+
extensions: A list of extensions to LOAD on client initialization.
29+
install_spatial: Whether to install the spatial extension on client initialization.
30+
use_hive_partitioning: Whether to use hive partitioning for geoparquet queries.
31+
"""
32+
33+
def execute(self, sql: str, params: list[str] | None = None) -> int:
34+
"""Execute an SQL command.
35+
36+
This can be useful for configuring AWS credentials, for example.
37+
38+
Args:
39+
sql: The SQL to execute
40+
params: The parameters to pass in to the execution
4041
"""
4142

4243
def search(

scripts/test

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@
22

33
set -e
44

5-
uv run maturin dev --uv -E arrow
65
uv run pytest "$@"
76
uv run rustac translate spec-examples/v1.1.0/simple-item.json /dev/null

src/duckdb.rs

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,59 @@ use crate::{
22
search::{PySortby, StringOrDict, StringOrList},
33
Result,
44
};
5+
use duckdb::Connection;
56
use pyo3::{
67
exceptions::PyException,
78
prelude::*,
89
types::{PyDict, PyList},
910
IntoPyObjectExt,
1011
};
1112
use pyo3_arrow::PyTable;
12-
use stac_duckdb::{Client, Config};
13-
use std::sync::Mutex;
13+
use stac_duckdb::Client;
14+
use std::{path::PathBuf, sync::Mutex};
1415

1516
#[pyclass(frozen)]
1617
pub struct DuckdbClient(Mutex<Client>);
1718

1819
#[pymethods]
1920
impl DuckdbClient {
2021
#[new]
21-
#[pyo3(signature = (*, use_s3_credential_chain=false, use_azure_credential_chain=false, use_httpfs=false, use_hive_partitioning=false, install_extensions=true, custom_extension_repository=None, extension_directory=None))]
22+
#[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_spatial=true, use_hive_partitioning=false))]
2223
fn new(
23-
use_s3_credential_chain: bool,
24-
use_azure_credential_chain: bool,
25-
use_httpfs: bool,
24+
extension_directory: Option<PathBuf>,
25+
extensions: Vec<String>,
26+
install_spatial: bool,
2627
use_hive_partitioning: bool,
27-
install_extensions: bool,
28-
custom_extension_repository: Option<String>,
29-
extension_directory: Option<String>,
3028
) -> Result<DuckdbClient> {
31-
let config = Config {
32-
use_s3_credential_chain,
33-
use_azure_credential_chain,
34-
use_httpfs,
35-
use_hive_partitioning,
36-
install_extensions,
37-
custom_extension_repository,
38-
extension_directory,
39-
convert_wkb: true,
40-
};
41-
let client = Client::with_config(config)?;
29+
let connection = Connection::open_in_memory()?;
30+
if let Some(extension_directory) = extension_directory {
31+
connection.execute(
32+
"SET extension_directory = ?",
33+
[extension_directory.to_string_lossy()],
34+
)?;
35+
}
36+
if install_spatial {
37+
connection.execute("INSTALL spatial", [])?;
38+
}
39+
for extension in extensions {
40+
connection.execute(&format!("LOAD '{}'", extension), [])?;
41+
}
42+
connection.execute("LOAD spatial", [])?;
43+
let mut client = Client::from(connection);
44+
client.use_hive_partitioning = use_hive_partitioning;
4245
Ok(DuckdbClient(Mutex::new(client)))
4346
}
4447

48+
#[pyo3(signature = (sql, params = Vec::new()))]
49+
fn execute<'py>(&self, sql: String, params: Vec<String>) -> Result<usize> {
50+
let client = self
51+
.0
52+
.lock()
53+
.map_err(|err| PyException::new_err(err.to_string()))?;
54+
let count = client.execute(&sql, duckdb::params_from_iter(params))?;
55+
Ok(count)
56+
}
57+
4558
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, **kwargs))]
4659
fn search<'py>(
4760
&self,
@@ -123,10 +136,11 @@ impl DuckdbClient {
123136
.0
124137
.lock()
125138
.map_err(|err| PyException::new_err(err.to_string()))?;
126-
let convert_wkb = client.config.convert_wkb;
127-
client.config.convert_wkb = false;
139+
// FIXME this is awkward
140+
let convert_wkb = client.convert_wkb;
141+
client.convert_wkb = false;
128142
let result = client.search_to_arrow(&href, search);
129-
client.config.convert_wkb = convert_wkb;
143+
client.convert_wkb = convert_wkb;
130144
result?
131145
};
132146
if record_batches.is_empty() {

src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ create_exception!(rustac, RustacError, PyException);
99

1010
#[derive(Debug, Error)]
1111
pub enum Error {
12+
#[error(transparent)]
13+
Duckdb(#[from] duckdb::Error),
14+
1215
#[error(transparent)]
1316
Geojson(#[from] geojson::Error),
1417

tests/test_duckdb.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1+
from pathlib import Path
2+
13
import pytest
24
import rustac
35
from geopandas import GeoDataFrame
4-
from rustac import DuckdbClient
6+
from rustac import DuckdbClient, RustacError
57

68

79
@pytest.fixture
810
def client() -> DuckdbClient:
911
return DuckdbClient()
1012

1113

14+
@pytest.fixture
15+
def extension_directory() -> Path:
16+
return Path(__file__).parent / "duckdb-extensions"
17+
18+
1219
def test_search(client: DuckdbClient) -> None:
1320
items = client.search("data/extended-item.parquet")
1421
assert len(items) == 1
@@ -27,11 +34,6 @@ def test_get_collections(client: DuckdbClient) -> None:
2734
assert len(collections) == 1
2835

2936

30-
@pytest.mark.skip("slow")
31-
def test_init_with_config() -> None:
32-
DuckdbClient(use_s3_credential_chain=True, use_hive_partitioning=True)
33-
34-
3537
def test_search_to_arrow(client: DuckdbClient) -> None:
3638
pytest.importorskip("arro3.core")
3739
table = client.search_to_arrow("data/100-sentinel-2-items.parquet")
@@ -40,3 +42,39 @@ def test_search_to_arrow(client: DuckdbClient) -> None:
4042
data_frame_table = data_frame.to_arrow()
4143
item_collection = rustac.from_arrow(data_frame_table)
4244
assert len(item_collection["features"]) == 100
45+
46+
47+
def test_custom_extension_directory(extension_directory: Path) -> None:
48+
client = DuckdbClient(extension_directory=extension_directory)
49+
# Search to ensure we trigger everything
50+
client.search("data/100-sentinel-2-items.parquet")
51+
52+
53+
def test_no_install(tmp_path: Path) -> None:
54+
with pytest.raises(RustacError):
55+
DuckdbClient(extension_directory=tmp_path, install_spatial=False)
56+
57+
58+
def test_extensions(extension_directory: Path, tmp_path: Path) -> None:
59+
# Ensure we've fetched the extension
60+
DuckdbClient(extension_directory=extension_directory)
61+
62+
extension = next(extension_directory.glob("**/spatial.duckdb_extension"))
63+
client = DuckdbClient(
64+
extensions=[str(extension)], extension_directory=tmp_path, install_spatial=False
65+
)
66+
client.search("data/100-sentinel-2-items.parquet")
67+
68+
69+
def test_execute(client: DuckdbClient, extension_directory: Path) -> None:
70+
# Just a smoke test
71+
client.execute("SET extension_directory = ?", [str(extension_directory)])
72+
73+
74+
def test_load_spatial() -> None:
75+
DuckdbClient(extensions=["spatial"])
76+
77+
78+
@pytest.mark.skip("slow")
79+
def test_aws_credential_chain(client: DuckdbClient) -> None:
80+
client.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)")

0 commit comments

Comments
 (0)