diff --git a/python/rustac/rustac.pyi b/python/rustac/rustac.pyi index 5871ca5..69ee246 100644 --- a/python/rustac/rustac.pyi +++ b/python/rustac/rustac.pyi @@ -18,7 +18,7 @@ class DuckdbClient: *, extension_directory: Path | None = None, extensions: list[str] | None = None, - install_spatial: bool = True, + install_extensions: bool = True, use_hive_partitioning: bool = False, ) -> None: """Creates a new duckdb client. @@ -26,7 +26,7 @@ class DuckdbClient: Args: extension_directory: A non-standard extension directory to use. extensions: A list of extensions to LOAD on client initialization. - install_spatial: Whether to install the spatial extension on client initialization. + install_extensions: Whether to install the spatial and icu extensions on client initialization. use_hive_partitioning: Whether to use hive partitioning for geoparquet queries. """ diff --git a/src/duckdb.rs b/src/duckdb.rs index 02a3b62..1e77425 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -19,11 +19,11 @@ pub struct DuckdbClient(Mutex); #[pymethods] impl DuckdbClient { #[new] - #[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_spatial=true, use_hive_partitioning=false))] + #[pyo3(signature = (*, extension_directory=None, extensions=Vec::new(), install_extensions=true, use_hive_partitioning=false))] fn new( extension_directory: Option, extensions: Vec, - install_spatial: bool, + install_extensions: bool, use_hive_partitioning: bool, ) -> Result { let connection = Connection::open_in_memory()?; @@ -33,13 +33,15 @@ impl DuckdbClient { [extension_directory.to_string_lossy()], )?; } - if install_spatial { + if install_extensions { connection.execute("INSTALL spatial", [])?; + connection.execute("INSTALL icu", [])?; } for extension in extensions { connection.execute(&format!("LOAD '{}'", extension), [])?; } connection.execute("LOAD spatial", [])?; + connection.execute("LOAD icu", [])?; let mut client = Client::from(connection); client.use_hive_partitioning = use_hive_partitioning; Ok(DuckdbClient(Mutex::new(client))) diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py index f163ff4..d7ceb64 100644 --- a/tests/test_duckdb.py +++ b/tests/test_duckdb.py @@ -52,16 +52,18 @@ def test_custom_extension_directory(extension_directory: Path) -> None: def test_no_install(tmp_path: Path) -> None: with pytest.raises(RustacError): - DuckdbClient(extension_directory=tmp_path, install_spatial=False) + DuckdbClient(extension_directory=tmp_path, install_extensions=False) def test_extensions(extension_directory: Path, tmp_path: Path) -> None: # Ensure we've fetched the extension DuckdbClient(extension_directory=extension_directory) - extension = next(extension_directory.glob("**/spatial.duckdb_extension")) + extensions = list(str(e) for e in extension_directory.glob("**/*.duckdb_extension")) client = DuckdbClient( - extensions=[str(extension)], extension_directory=tmp_path, install_spatial=False + extensions=extensions, + extension_directory=tmp_path, + install_extensions=False, ) client.search("data/100-sentinel-2-items.parquet")