
Commit 016adc0

Allow record batches
1 parent c133134 commit 016adc0

File tree: 7 files changed (+171, -19 lines)


connectorx-python/connectorx/__init__.py

Lines changed: 51 additions & 15 deletions
@@ -2,7 +2,7 @@
 
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...


@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...


@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...


@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...


@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...


@@ -252,6 +257,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...


@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query, download the data from database into a dataframe.

     Parameters
     ==========
     conn
-      the connection string, or dict of connection string mapping for federated query.
+      the connection string, or dict of connection string mapping for a federated query.
     query
       a SQL query or a list of SQL queries.
     return_type
-      the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+      the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
     protocol
       backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
       connection strings, where 'cursor' will be used instead).
@@ -403,31 +411,59 @@ def read_sql(
         dd = try_import_module("dask.dataframe")
         df = dd.from_pandas(df, npartitions=1)

-    elif return_type in {"arrow", "polars"}:
+    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
         try_import_module("pyarrow")

+        record_batch_size = int(kwargs.get("record_batch_size", 10000))
         result = _read_sql(
             conn,
-            "arrow",
+            "arrow_record_batches",
             queries=queries,
             protocol=protocol,
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
+            record_batch_size=record_batch_size
         )
-        df = reconstruct_arrow(result)
-        if return_type in {"polars"}:
-            pl = try_import_module("polars")
-            try:
-                df = pl.from_arrow(df)
-            except AttributeError:
-                # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
-                df = pl.DataFrame.from_arrow(df)
+
+        if return_type == "arrow_record_batches":
+            df = reconstruct_arrow_rb(result)
+        else:
+            df = reconstruct_arrow(result)
+            if return_type in {"polars"}:
+                pl = try_import_module("polars")
+                try:
+                    df = pl.from_arrow(df)
+                except AttributeError:
+                    # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
+                    df = pl.DataFrame.from_arrow(df)
     else:
         raise ValueError(return_type)

     return df


+def reconstruct_arrow_rb(results) -> Iterator[pa.RecordBatch]:
+    import pyarrow as pa
+
+    # Get Schema
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        for rb_ptrs in iterator:
+            names, chunk_ptrs_list = rb_ptrs.to_ptrs()
+            for chunk_ptrs in chunk_ptrs_list:
+                yield pa.RecordBatch.from_arrays(
+                    [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+                )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
+
+
 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa

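The new "arrow_record_batches" return type makes read_sql hand back a pyarrow.RecordBatchReader instead of a fully materialized table. A minimal usage sketch (connection string, query, and batch size below are placeholders, not part of this commit):

import connectorx as cx

# Hypothetical connection string and query; adjust to your database.
conn = "postgresql://user:pass@localhost:5432/db"
query = "SELECT * FROM lineitem"

# Returns a pyarrow.RecordBatchReader instead of a fully materialized table.
reader = cx.read_sql(
    conn,
    query,
    return_type="arrow_record_batches",
    record_batch_size=10000,  # forwarded via **kwargs; defaults to 10000
)

for batch in reader:  # each item is a pyarrow.RecordBatch
    print(batch.num_rows, batch.num_columns)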
connectorx-python/connectorx/connectorx.pyi

Lines changed: 3 additions & 1 deletion
@@ -26,15 +26,17 @@ def read_sql(
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _DataframeInfos: ...
 @overload
 def read_sql(
     conn: str,
-    return_type: Literal["arrow"],
+    return_type: Literal["arrow", "arrow_record_batches"],
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
 def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...

connectorx-python/src/arrow.rs

Lines changed: 91 additions & 0 deletions
@@ -5,10 +5,77 @@ use connectorx::{prelude::*, sql::CXQuery};
 use fehler::throws;
 use libc::uintptr_t;
 use pyo3::prelude::*;
+use pyo3::pyclass;
 use pyo3::{PyAny, Python};
 use std::convert::TryFrom;
 use std::sync::Arc;

+/// Python-exposed RecordBatch wrapper
+#[pyclass]
+pub struct PyRecordBatch(RecordBatch);
+
+/// Python-exposed iterator over RecordBatches
+#[pyclass(unsendable, module = "connectorx")]
+pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);
+
+#[pymethods]
+impl PyRecordBatch {
+    pub fn num_rows(&self) -> usize {
+        self.0.num_rows()
+    }
+
+    pub fn num_columns(&self) -> usize {
+        self.0.num_columns()
+    }
+
+    #[throws(ConnectorXPythonError)]
+    pub fn to_ptrs<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        let ptrs = py.allow_threads(
+            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
+                let rbs = vec![self.0.clone()];
+                Ok(to_ptrs(rbs))
+            },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+}
+
+#[pymethods]
+impl PyRecordBatchIterator {
+
+    #[throws(ConnectorXPythonError)]
+    fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        let (rb, _) = self.0.get_schema();
+        let ptrs = py.allow_threads(
+            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
+                let rbs = vec![rb];
+                Ok(to_ptrs(rbs))
+            },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __next__<'py>(
+        mut slf: PyRefMut<'py, Self>,
+        py: Python<'py>,
+    ) -> PyResult<Option<Py<PyRecordBatch>>> {
+        match slf.0.next_batch() {
+            Some(rb) => {
+                let wrapped = PyRecordBatch(rb);
+                let py_obj = Py::new(py, wrapped)?;
+                Ok(Some(py_obj))
+            }
+
+            None => Ok(None),
+        }
+    }
+}
+
 #[throws(ConnectorXPythonError)]
 pub fn write_arrow<'py>(
     py: Python<'py>,
@@ -28,6 +95,30 @@ pub fn write_arrow<'py>(
     obj.into_bound(py)
 }

+#[throws(ConnectorXPythonError)]
+pub fn get_arrow_rb_iter<'py>(
+    py: Python<'py>,
+    source_conn: &SourceConn,
+    origin_query: Option<String>,
+    queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
+    batch_size: usize,
+) -> Bound<'py, PyAny> {
+    let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
+        source_conn,
+        origin_query,
+        queries,
+        batch_size,
+        pre_execution_queries,
+    );
+
+    arrow_iter.prepare();
+    let py_rb_iter = PyRecordBatchIterator(arrow_iter);
+
+    let obj: PyObject = py_rb_iter.into_py(py);
+    obj.into_bound(py)
+}
+
 pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
     if rbs.is_empty() {
         return (vec![], vec![]);

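The (uintptr_t, uintptr_t) pairs returned by to_ptrs and schema_ptr are Arrow C Data Interface pointers (one ArrowArray plus one ArrowSchema struct per column), which the Python side re-imports with pa.Array._import_from_c, as reconstruct_arrow_rb above does. A minimal round-trip sketch of that mechanism using pyarrow's cffi helpers, independent of connectorx:

import pyarrow as pa
from pyarrow.cffi import ffi

# Allocate empty ArrowArray / ArrowSchema C structs and take their addresses.
c_array = ffi.new("struct ArrowArray*")
c_schema = ffi.new("struct ArrowSchema*")
array_ptr = int(ffi.cast("uintptr_t", c_array))
schema_ptr = int(ffi.cast("uintptr_t", c_schema))

# Export an array across the C boundary, then import it back zero-copy.
original = pa.array([1, 2, 3])
original._export_to_c(array_ptr, schema_ptr)
roundtripped = pa.Array._import_from_c(array_ptr, schema_ptr)
assert roundtripped.equals(original)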
connectorx-python/src/cx_read_sql.rs

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@ use pyo3::prelude::*;
 use pyo3::{exceptions::PyValueError, PyResult};

 use crate::errors::ConnectorXPythonError;
+use pyo3::types::PyDict;

 #[derive(FromPyObject)]
 #[pyo3(from_item_all)]
@@ -39,6 +40,7 @@ pub fn read_sql<'py>(
     queries: Option<Vec<String>>,
     partition_query: Option<PyPartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
     let (queries, origin_query) = match (queries, partition_query) {
@@ -72,6 +74,22 @@ pub fn read_sql<'py>(
             &queries,
             pre_execution_queries.as_deref(),
         )?),
+        "arrow_record_batches" => {
+            let batch_size = kwargs
+                .and_then(|dict| dict.get_item("record_batch_size").ok().flatten())
+                .and_then(|obj| obj.extract::<usize>().ok())
+                .unwrap_or(10000);
+
+            Ok(crate::arrow::get_arrow_rb_iter(
+                py,
+                &source_conn,
+                origin_query,
+                &queries,
+                pre_execution_queries.as_deref(),
+                batch_size,
+            )?)
+        }
+
         _ => Err(PyValueError::new_err(format!(
             "return type should be 'pandas' or 'arrow', got '{}'",
             return_type

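Because get_arrow_rb_iter produces batches lazily (and record_batch_size falls back to 10000 when the kwarg is absent), the reader can be drained incrementally without holding the full result set in memory. A hypothetical sketch that spills each batch to Parquet as it arrives (connection string, query, and file name are placeholders):

import connectorx as cx
import pyarrow as pa
import pyarrow.parquet as pq

reader = cx.read_sql(
    "postgresql://user:pass@localhost:5432/db",  # placeholder connection string
    "SELECT * FROM lineitem",                    # placeholder query
    return_type="arrow_record_batches",
    record_batch_size=20000,
)

# Write each batch as it arrives; only one batch is resident at a time.
with pq.ParquetWriter("lineitem.parquet", reader.schema) as writer:
    for batch in reader:
        writer.write_table(pa.Table.from_batches([batch]))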
connectorx-python/src/lib.rs

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,7 @@ use crate::constants::J4RS_BASE_PATH;
 use ::connectorx::{fed_dispatcher::run, partition::partition, source_router::parse_source};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
+use pyo3::types::PyDict;
 use pyo3::{wrap_pyfunction, PyResult};
 use std::collections::HashMap;
 use std::env;
@@ -35,11 +36,13 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(partition_sql))?;
     m.add_wrapped(wrap_pyfunction!(get_meta))?;
     m.add_class::<pandas::PandasBlockInfo>()?;
+    m.add_class::<arrow::PyRecordBatch>()?;
+    m.add_class::<arrow::PyRecordBatchIterator>()?;
     Ok(())
 }

 #[pyfunction]
-#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
+#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None, *, **kwargs))]
 pub fn read_sql<'py>(
     py: Python<'py>,
     conn: &str,
@@ -48,6 +51,7 @@ pub fn read_sql<'py>(
     queries: Option<Vec<String>>,
     partition_query: Option<cx_read_sql::PyPartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     cx_read_sql::read_sql(
         py,
@@ -57,6 +61,7 @@ pub fn read_sql<'py>(
         queries,
         partition_query,
         pre_execution_queries,
+        kwargs,
     )
 }

connectorx/src/arrow_batch_iter.rs

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ where
     type Item = RecordBatch;
     /// NOTE: not thread safe
     fn next(&mut self) -> Option<Self::Item> {
-        self.dst.record_batch().unwrap()
+        self.dst.record_batch().ok().flatten()
     }
 }

connectorx/src/destinations/arrowstream/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ impl ArrowPartitionWriter {
             .map(|(builder, &dt)| Realize::<FFinishBuilder>::realize(dt)?(builder))
             .collect::<std::result::Result<Vec<_>, crate::errors::ConnectorXError>>()?;
         let rb = RecordBatch::try_new(Arc::clone(&self.arrow_schema), columns)?;
-        self.sender.as_ref().unwrap().send(rb).unwrap();
+        self.sender.as_ref().and_then(|s| s.send(rb).ok());

         self.current_row = 0;
         self.current_col = 0;
