
Commit 4213afb

1. fix tests 2. avoid clone 3. only return a single batch
1 parent c6836d8

2 files changed: +55 -41 lines changed

connectorx-python/connectorx/__init__.py

Lines changed: 24 additions & 17 deletions
@@ -411,9 +411,27 @@ def read_sql(
         dd = try_import_module("dask.dataframe")
         df = dd.from_pandas(df, npartitions=1)
 
-    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
+    elif return_type in {"arrow", "polars"}:
         try_import_module("pyarrow")
 
+        result = _read_sql(
+            conn,
+            "arrow",
+            queries=queries,
+            protocol=protocol,
+            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
+        )
+
+        df = reconstruct_arrow(result)
+        if return_type in {"polars"}:
+            pl = try_import_module("polars")
+            try:
+                df = pl.from_arrow(df)
+            except AttributeError:
+                # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
+                df = pl.DataFrame.from_arrow(df)
+    elif return_type in {"arrow_record_batches"}:
         record_batch_size = int(kwargs.get("record_batch_size", 10000))
         result = _read_sql(
             conn,
@@ -425,17 +443,7 @@ def read_sql(
             record_batch_size=record_batch_size
         )
 
-        if return_type == "arrow_record_batches":
-            df = reconstruct_arrow_rb(result)
-        else:
-            df = reconstruct_arrow(result)
-            if return_type in {"polars"}:
-                pl = try_import_module("polars")
-                try:
-                    df = pl.from_arrow(df)
-                except AttributeError:
-                    # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
-                    df = pl.DataFrame.from_arrow(df)
+        df = reconstruct_arrow_rb(result)
     else:
         raise ValueError(return_type)
 
@@ -455,11 +463,10 @@ def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
 
     def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
         for rb_ptrs in iterator:
-            names, chunk_ptrs_list = rb_ptrs.to_ptrs()
-            for chunk_ptrs in chunk_ptrs_list:
-                yield pa.RecordBatch.from_arrays(
-                    [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
-                )
+            chunk_ptrs = rb_ptrs.to_ptrs()
+            yield pa.RecordBatch.from_arrays(
+                [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+            )
 
     return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))

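For context, a minimal usage sketch of the public API this commit touches. The connection string, query, and batch size are placeholders, and the sketch assumes a connectorx build that exposes the arrow_record_batches return type and the record_batch_size keyword shown in the diff above.

import connectorx as cx
import pyarrow as pa

# Placeholder connection string and query -- substitute your own database.
conn = "postgresql://user:password@localhost:5432/mydb"
query = "SELECT * FROM lineitem"

# Streaming path: a pa.RecordBatchReader whose batches hold at most
# record_batch_size rows (default 10000 per the diff above).
reader = cx.read_sql(conn, query,
                     return_type="arrow_record_batches",
                     record_batch_size=10_000)
assert isinstance(reader, pa.RecordBatchReader)
for batch in reader:  # each item is a single pa.RecordBatch
    print(batch.num_rows, batch.num_columns)

# Materialized paths handled by the other branch of the change.
table = cx.read_sql(conn, query, return_type="arrow")   # pyarrow.Table
frame = cx.read_sql(conn, query, return_type="polars")  # polars.DataFrame
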
connectorx-python/src/arrow.rs

Lines changed: 31 additions & 24 deletions
@@ -1,4 +1,5 @@
 use crate::errors::ConnectorXPythonError;
+use anyhow::anyhow;
 use arrow::record_batch::RecordBatch;
 use connectorx::source_router::SourceConn;
 use connectorx::{prelude::*, sql::CXQuery};
@@ -12,7 +13,7 @@ use std::sync::Arc;
 
 /// Python-exposed RecordBatch wrapper
 #[pyclass]
-pub struct PyRecordBatch(RecordBatch);
+pub struct PyRecordBatch(Option<RecordBatch>);
 
 /// Python-exposed iterator over RecordBatches
 #[pyclass(module = "connectorx")]
@@ -21,20 +22,22 @@ pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);
 #[pymethods]
 impl PyRecordBatch {
     pub fn num_rows(&self) -> usize {
-        self.0.num_rows()
+        self.0.as_ref().map_or(0, |rb| rb.num_rows())
     }
 
     pub fn num_columns(&self) -> usize {
-        self.0.num_columns()
+        self.0.as_ref().map_or(0, |rb| rb.num_columns())
     }
 
     #[throws(ConnectorXPythonError)]
-    pub fn to_ptrs<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+    pub fn to_ptrs<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyAny> {
+        // Convert the RecordBatch to a vector of pointers, once the RecordBatch is taken, it cannot be reached again.
+        let rb = self
+            .0
+            .take()
+            .ok_or_else(|| anyhow!("RecordBatch is None, cannot convert to pointers"))?;
         let ptrs = py.allow_threads(
-            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
-                let rbs = vec![self.0.clone()];
-                Ok(to_ptrs(rbs))
-            },
+            || -> Result<Vec<(uintptr_t, uintptr_t)>, ConnectorXPythonError> { Ok(to_ptrs_rb(rb)) },
         )?;
         let obj: PyObject = ptrs.into_py(py);
         obj.into_bound(py)
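
The hunk above replaces the per-batch clone with an Option take: to_ptrs() now consumes the wrapped RecordBatch, so each batch can be exported at most once, and num_rows()/num_columns() report 0 afterwards. A toy Python analogue of that contract, illustrative only and not part of connectorx:

from typing import List, Optional, Tuple

class TakeOnceBatch:
    """Toy stand-in for PyRecordBatch(Option<RecordBatch>): the payload is
    handed out once instead of being cloned on every export."""

    def __init__(self, rows: List[Tuple[int, str]]) -> None:
        self._rows: Optional[List[Tuple[int, str]]] = rows

    def num_rows(self) -> int:
        # mirrors `self.0.as_ref().map_or(0, |rb| rb.num_rows())`
        return 0 if self._rows is None else len(self._rows)

    def to_ptrs(self) -> List[Tuple[int, str]]:
        # mirrors `self.0.take().ok_or_else(...)`: ownership moves out here
        if self._rows is None:
            raise RuntimeError("RecordBatch is None, cannot convert to pointers")
        rows, self._rows = self._rows, None
        return rows

batch = TakeOnceBatch([(1, "a"), (2, "b")])
assert batch.num_rows() == 2
exported = batch.to_ptrs()    # first export succeeds, nothing is copied
assert batch.num_rows() == 0  # the wrapper no longer owns the data
# batch.to_ptrs()             # a second export would raise RuntimeError
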
@@ -65,7 +68,7 @@ impl PyRecordBatchIterator {
     ) -> PyResult<Option<Py<PyRecordBatch>>> {
         match slf.0.next_batch() {
             Some(rb) => {
-                let wrapped = PyRecordBatch(rb);
+                let wrapped = PyRecordBatch(Some(rb));
                 let py_obj = Py::new(py, wrapped)?;
                 Ok(Some(py_obj))
             }
@@ -118,6 +121,24 @@ pub fn get_arrow_rb_iter<'py>(
     obj.into_bound(py)
 }
 
+pub fn to_ptrs_rb(rb: RecordBatch) -> Vec<(uintptr_t, uintptr_t)> {
+    let mut cols = vec![];
+
+    for array in rb.columns().into_iter() {
+        let data = array.to_data();
+        let array_ptr = Arc::new(arrow::ffi::FFI_ArrowArray::new(&data));
+        let schema_ptr = Arc::new(
+            arrow::ffi::FFI_ArrowSchema::try_from(data.data_type()).expect("export schema c"),
+        );
+        cols.push((
+            Arc::into_raw(array_ptr) as uintptr_t,
+            Arc::into_raw(schema_ptr) as uintptr_t,
+        ));
+    }
+
+    cols
+}
+
 pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
     if rbs.is_empty() {
         return (vec![], vec![]);
@@ -132,21 +153,7 @@ pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintp
         .collect();
 
     for rb in rbs.into_iter() {
-        let mut cols = vec![];
-
-        for array in rb.columns().into_iter() {
-            let data = array.to_data();
-            let array_ptr = Arc::new(arrow::ffi::FFI_ArrowArray::new(&data));
-            let schema_ptr = Arc::new(
-                arrow::ffi::FFI_ArrowSchema::try_from(data.data_type()).expect("export schema c"),
-            );
-            cols.push((
-                Arc::into_raw(array_ptr) as uintptr_t,
-                Arc::into_raw(schema_ptr) as uintptr_t,
-            ));
-        }
-
-        result.push(cols);
+        result.push(to_ptrs_rb(rb));
     }
     (names, result)
 }
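
The new to_ptrs_rb hands Python one (array_ptr, schema_ptr) address pair per column, and generate_batches re-imports each pair with pa.Array._import_from_c. A standalone sketch of that Arrow C data interface round trip, using pyarrow's bundled cffi helpers in place of the Rust exporter (assumes the cffi package is installed; the data here is illustrative, not connectorx's):

import pyarrow as pa
from pyarrow.cffi import ffi  # C-ABI struct definitions that ship with pyarrow

# Allocate empty ArrowArray/ArrowSchema structs and take their addresses,
# playing the role of the pointer pairs produced on the Rust side by to_ptrs_rb.
c_array = ffi.new("struct ArrowArray*")
c_schema = ffi.new("struct ArrowSchema*")
array_ptr = int(ffi.cast("uintptr_t", c_array))
schema_ptr = int(ffi.cast("uintptr_t", c_schema))

original = pa.array([1, 2, 3], type=pa.int64())
original._export_to_c(array_ptr, schema_ptr)                # producer side
roundtrip = pa.Array._import_from_c(array_ptr, schema_ptr)  # consumer side, as in generate_batches
assert roundtrip.equals(original)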
