
Commit da319be

Merge pull request #819 from chitralverma/allow-record-batches
feat(arrow): Allow record batches output from read_sql
2 parents: c133134 + bc69438

8 files changed (+265, -31 lines)


connectorx-python/connectorx/__init__.py

Lines changed: 51 additions & 6 deletions
@@ -2,7 +2,7 @@
 
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...


@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...


@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...


@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...


@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...


@@ -252,6 +257,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...


@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_stream"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query, download the data from database into a dataframe.

     Parameters
     ==========
     conn
-        the connection string, or dict of connection string mapping for federated query.
+        the connection string, or dict of connection string mapping for a federated query.
     query
         a SQL query or a list of SQL queries.
     return_type
-        the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+        the return type of this function; one of "arrow", "arrow_stream", "pandas", "modin", "dask" or "polars".
     protocol
         backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
         connection strings, where 'cursor' will be used instead).
@@ -293,10 +301,12 @@ def read_sql(
     index_col
        the index column to set; only applicable for return type "pandas", "modin", "dask".
     strategy
-        strategy of rewriting the federated query for join pushdown
+        strategy of rewriting the federated query for join pushdown.
     pre_execution_query
        SQL query or list of SQL queries executed before main query; can be used to set runtime
        configurations using SET statements; only applicable for source "Postgres" and "MySQL".
+    batch_size
+        the maximum size of each batch when return type is `arrow_stream`.

     Examples
     ========
@@ -414,6 +424,7 @@ def read_sql(
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
         )
+
         df = reconstruct_arrow(result)
         if return_type in {"polars"}:
             pl = try_import_module("polars")
@@ -422,12 +433,46 @@ def read_sql(
             except AttributeError:
                 # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
                 df = pl.DataFrame.from_arrow(df)
+    elif return_type in {"arrow_stream"}:
+        batch_size = int(kwargs.get("batch_size", 10000))
+        result = _read_sql(
+            conn,
+            "arrow_stream",
+            queries=queries,
+            protocol=protocol,
+            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
+            batch_size=batch_size
+        )
+
+        df = reconstruct_arrow_rb(result)
     else:
         raise ValueError(return_type)

     return df


+def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
+    import pyarrow as pa
+
+    # Get Schema
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        for rb_ptrs in iterator:
+            chunk_ptrs = rb_ptrs.to_ptrs()
+            yield pa.RecordBatch.from_arrays(
+                [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+            )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))


 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
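
Taken together, these changes let read_sql hand back a pyarrow.RecordBatchReader instead of a fully materialized table: reconstruct_arrow_rb wraps the native result in a lazy generator over the Arrow C Data Interface, so batches of at most batch_size rows are produced on demand rather than as one pa.Table. A minimal usage sketch of the new return type (the connection string, table name, and process() callback are placeholders, not part of this PR):

import connectorx as cx

conn = "postgresql://user:password@localhost:5432/db"  # placeholder DSN

# return_type="arrow_stream" returns a pyarrow.RecordBatchReader;
# batch_size caps the rows per batch (the code above defaults it to 10000).
reader = cx.read_sql(
    conn,
    "SELECT * FROM lineitem",  # placeholder query
    return_type="arrow_stream",
    batch_size=4096,
)

for batch in reader:   # each item is a pyarrow.RecordBatch
    process(batch)     # placeholder for per-batch work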

connectorx-python/connectorx/connectorx.pyi

Lines changed: 3 additions & 1 deletion
@@ -26,15 +26,17 @@ def read_sql(
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _DataframeInfos: ...
 @overload
 def read_sql(
     conn: str,
-    return_type: Literal["arrow"],
+    return_type: Literal["arrow", "arrow_stream"],
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
 def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...

connectorx-python/connectorx/tests/test_arrow.py

Lines changed: 67 additions & 0 deletions
@@ -44,6 +44,73 @@ def test_arrow(postgres_url: str) -> None:
     df.sort_values(by="test_int", inplace=True, ignore_index=True)
     assert_frame_equal(df, expected, check_names=True)

+def test_arrow_stream(postgres_url: str) -> None:
+    import pyarrow as pa
+    query = "SELECT * FROM test_table"
+    reader = read_sql(
+        postgres_url,
+        query,
+        return_type="arrow_stream",
+        batch_size=2,
+    )
+    batches = []
+    for batch in reader:
+        batches.append(batch)
+    table = pa.Table.from_batches(batches)
+    df = table.to_pandas()
+    df.sort_values(by="test_int", inplace=True, ignore_index=True)
+
+    expected = pd.DataFrame(
+        index=range(6),
+        data={
+            "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
+            "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
+            "test_str": pd.Series(
+                ["a", "str1", "str2", "b", "c", None], dtype="object"
+            ),
+            "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
+            "test_bool": pd.Series(
+                [None, True, False, False, None, True], dtype="object"
+            ),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+def test_arrow_stream_with_partition(postgres_url: str) -> None:
+    import pyarrow as pa
+    query = "SELECT * FROM test_table"
+    reader = read_sql(
+        postgres_url,
+        query,
+        partition_on="test_int",
+        partition_range=(0, 2000),
+        partition_num=3,
+        return_type="arrow_stream",
+        batch_size=2,
+    )
+    batches = []
+    for batch in reader:
+        batches.append(batch)
+    table = pa.Table.from_batches(batches)
+    df = table.to_pandas()
+    df.sort_values(by="test_int", inplace=True, ignore_index=True)
+
+    expected = pd.DataFrame(
+        index=range(6),
+        data={
+            "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
+            "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
+            "test_str": pd.Series(
+                ["a", "str1", "str2", "b", "c", None], dtype="object"
+            ),
+            "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
+            "test_bool": pd.Series(
+                [None, True, False, False, None, True], dtype="object"
+            ),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
 def decimal_s10(val):
     return Decimal(val).quantize(Decimal("0.0000000001"))

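
The new tests drain the reader into a list and rebuild a table with pa.Table.from_batches; pyarrow's RecordBatchReader can also do the collapsing itself. A short equivalent sketch (table_from_stream is a hypothetical helper, reusing the postgres_url fixture from the tests):

from connectorx import read_sql


def table_from_stream(postgres_url: str):
    reader = read_sql(
        postgres_url,
        "SELECT * FROM test_table",
        return_type="arrow_stream",
        batch_size=2,
    )
    # read_all() drains the remaining batches into a single pyarrow.Table,
    # equivalent to the manual loop plus Table.from_batches in the tests above.
    return reader.read_all().to_pandas()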
