@@ -2,7 +2,7 @@
 
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...
 
 
@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...
 
 
@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...
 
 
@@ -252,6 +257,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...
 
 
@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query and download the data from the database into a dataframe.
 
     Parameters
     ==========
     conn
-        the connection string, or dict of connection string mapping for federated query.
+        the connection string, or dict of connection string mapping for a federated query.
     query
         a SQL query or a list of SQL queries.
     return_type
-        the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+        the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
     protocol
         backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
         connection strings, where 'cursor' will be used instead).
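A minimal usage sketch of the new return type (editorial aside; the connection string, query, and batch size below are illustrative assumptions, not part of this diff):

import connectorx as cx

reader = cx.read_sql(
    "postgresql://user:pass@localhost:5432/db",  # hypothetical DSN
    "SELECT * FROM lineitem",                    # hypothetical query
    return_type="arrow_record_batches",
    record_batch_size=4096,  # forwarded through **kwargs; defaults to 10000
)
for batch in reader:  # each item is a pyarrow.RecordBatch
    print(batch.num_rows)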
@@ -403,31 +411,59 @@ def read_sql(
         dd = try_import_module("dask.dataframe")
         df = dd.from_pandas(df, npartitions=1)
 
-    elif return_type in {"arrow", "polars"}:
+    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
         try_import_module("pyarrow")
 
+        record_batch_size = int(kwargs.get("record_batch_size", 10000))
         result = _read_sql(
             conn,
-            "arrow",
+            "arrow_record_batches",
             queries=queries,
             protocol=protocol,
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
+            record_batch_size=record_batch_size,
         )
-        df = reconstruct_arrow(result)
-        if return_type in {"polars"}:
-            pl = try_import_module("polars")
-            try:
-                df = pl.from_arrow(df)
-            except AttributeError:
-                # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
-                df = pl.DataFrame.from_arrow(df)
+
+        if return_type == "arrow_record_batches":
+            df = reconstruct_arrow_rb(result)
+        else:
+            df = reconstruct_arrow(result)
+            if return_type in {"polars"}:
+                pl = try_import_module("polars")
+                try:
+                    df = pl.from_arrow(df)
+                except AttributeError:
+                    # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
+                    df = pl.DataFrame.from_arrow(df)
     else:
         raise ValueError(return_type)
 
     return df
 
 
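Consumption sketch (editorial aside): because "arrow_record_batches" returns a pyarrow.RecordBatchReader rather than a fully materialized table, a result can be written out with bounded memory. The helper below is hypothetical and uses only public pyarrow APIs:

import pyarrow as pa
import pyarrow.parquet as pq

def stream_to_parquet(reader: pa.RecordBatchReader, path: str) -> None:
    # one row group per incoming batch; only one batch is resident at a time
    with pq.ParquetWriter(path, reader.schema) as writer:
        for batch in reader:
            writer.write_batch(batch)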
+def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
+    import pyarrow as pa
+
+    # Import an empty record batch through the Arrow C Data Interface,
+    # solely to recover the schema of the result set.
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        # Import native chunks lazily, yielding one record batch at a time.
+        for rb_ptrs in iterator:
+            names, chunk_ptrs_list = rb_ptrs.to_ptrs()
+            for chunk_ptrs in chunk_ptrs_list:
+                yield pa.RecordBatch.from_arrays(
+                    [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+                )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
+
+
 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
 
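Pattern sketch (editorial aside): reconstruct_arrow_rb pairs an eagerly imported schema with a lazy generator via pa.RecordBatchReader.from_batches. The toy below shows the same public-API pattern, with an ordinary generator standing in for the internal FFI pointer import:

import pyarrow as pa
from collections.abc import Iterator

def toy_batches(n: int) -> Iterator[pa.RecordBatch]:
    for i in range(n):
        # a batch is materialized only when the reader advances
        yield pa.record_batch([pa.array([i, i + 1])], names=["x"])

schema = pa.schema([("x", pa.int64())])
reader = pa.RecordBatchReader.from_batches(schema, toy_batches(3))
assert reader.read_all().num_rows == 6  # 3 batches x 2 rows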