@@ -2,7 +2,7 @@
 
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...
 
 
@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...
 
 
@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...
 
 
@@ -252,6 +257,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...
 
 
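Each of the overload stubs above pins one `return_type` literal to one concrete return annotation, so a type checker can resolve `read_sql`'s return type at the call site; the diff threads `**kwargs` through every stub so `batch_size` stays accepted under each literal. A minimal, self-contained sketch of the pattern, using hypothetical stand-in result types rather than the real pandas/pyarrow classes:

    from typing import Literal, overload

    class FrameResult: ...   # stand-in for pd.DataFrame
    class TableResult: ...   # stand-in for pa.Table

    @overload
    def fetch(q: str, *, return_type: Literal["pandas"] = ...) -> FrameResult: ...
    @overload
    def fetch(q: str, *, return_type: Literal["arrow"]) -> TableResult: ...

    def fetch(q: str, *, return_type: str = "pandas") -> FrameResult | TableResult:
        # a checker infers fetch(q) -> FrameResult and
        # fetch(q, return_type="arrow") -> TableResult
        return TableResult() if return_type == "arrow" else FrameResult()
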
@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_stream"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query, download the data from database into a dataframe.
 
     Parameters
     ==========
     conn
-        the connection string, or dict of connection string mapping for federated query.
+        the connection string, or dict of connection string mapping for a federated query.
     query
         a SQL query or a list of SQL queries.
     return_type
-        the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+        the return type of this function; one of "arrow", "arrow_stream", "pandas", "modin", "dask" or "polars".
     protocol
         backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
         connection strings, where 'cursor' will be used instead).
@@ -293,10 +301,12 @@ def read_sql(
     index_col
         the index column to set; only applicable for return type "pandas", "modin", "dask".
     strategy
-        strategy of rewriting the federated query for join pushdown
+        strategy of rewriting the federated query for join pushdown.
     pre_execution_query
         SQL query or list of SQL queries executed before main query; can be used to set runtime
         configurations using SET statements; only applicable for source "Postgres" and "MySQL".
+    batch_size
+        the maximum size of each batch when the return type is `arrow_stream`.
 
     Examples
     ========
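To illustrate the new parameter in use (a hypothetical session: the DSN and table name below are placeholders, and the package is assumed to be imported as `connectorx`):

    import connectorx as cx

    # placeholder Postgres DSN; any supported source string works here
    conn = "postgresql://user:pass@localhost:5432/db"

    reader = cx.read_sql(
        conn,
        "SELECT * FROM lineitem",
        return_type="arrow_stream",
        batch_size=4096,  # forwarded via **kwargs; the code defaults to 10000
    )

    rows = 0
    for batch in reader:      # pa.RecordBatch instances, produced lazily
        rows += batch.num_rows
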
@@ -414,6 +424,7 @@ def read_sql(
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
         )
+
         df = reconstruct_arrow(result)
         if return_type in {"polars"}:
             pl = try_import_module("polars")
@@ -422,12 +433,46 @@ def read_sql(
             except AttributeError:
                 # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
                 df = pl.DataFrame.from_arrow(df)
+    elif return_type in {"arrow_stream"}:
+        batch_size = int(kwargs.get("batch_size", 10000))
+        result = _read_sql(
+            conn,
+            "arrow_stream",
+            queries=queries,
+            protocol=protocol,
+            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
+            batch_size=batch_size,
+        )
+
+        df = reconstruct_arrow_rb(result)
     else:
         raise ValueError(return_type)
 
     return df
 
 
+def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
+    import pyarrow as pa
+
+    # Get Schema
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        for rb_ptrs in iterator:
+            chunk_ptrs = rb_ptrs.to_ptrs()
+            yield pa.RecordBatch.from_arrays(
+                [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+            )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
+
+
 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
 
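The `RecordBatchReader.from_batches` construction that `reconstruct_arrow_rb` relies on can be exercised without connectorx at all; a small pyarrow-only sketch of the same generator-backed reader pattern, with toy data standing in for the C Data Interface pointers:

    import pyarrow as pa
    from collections.abc import Iterator

    schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

    def generate_batches() -> Iterator[pa.RecordBatch]:
        # nothing is materialized until the reader pulls a batch
        for start in range(0, 6, 2):
            yield pa.RecordBatch.from_arrays(
                [pa.array([start, start + 1]),
                 pa.array([f"row{start}", f"row{start + 1}"])],
                schema=schema,
            )

    reader = pa.RecordBatchReader.from_batches(schema, generate_batches())
    print(reader.read_all().num_rows)  # 6

Consuming the reader (by iteration or `read_all`) drains the generator as it goes, which is why the `arrow_stream` branch hands the raw result object to the generator instead of collecting batches into a list first.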