@@ -44,6 +44,73 @@ def test_arrow(postgres_url: str) -> None:
4444 df .sort_values (by = "test_int" , inplace = True , ignore_index = True )
4545 assert_frame_equal (df , expected , check_names = True )
4646
def test_arrow_stream(postgres_url: str) -> None:
    """Read ``test_table`` as streamed Arrow record batches and verify the rows.

    The query is issued through ``read_sql`` with
    ``return_type="arrow_record_batches"`` and ``record_batch_size=2``; every
    batch yielded by the reader is collected, combined into a single Arrow
    table, converted to pandas and compared against the expected frame.
    """
    import pyarrow as pa

    reader = read_sql(
        postgres_url,
        "SELECT * FROM test_table",
        return_type="arrow_record_batches",
        record_batch_size=2,
    )

    # Drain the reader completely before assembling the table.
    collected = list(reader)
    frame = pa.Table.from_batches(collected).to_pandas()

    # Normalise row order (and reset the index) before comparing.
    actual = frame.sort_values(by="test_int", ignore_index=True)

    expected_columns = {
        "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
        "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
        "test_str": pd.Series(
            ["a", "str1", "str2", "b", "c", None], dtype="object"
        ),
        "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
        "test_bool": pd.Series(
            [None, True, False, False, None, True], dtype="object"
        ),
    }
    expected = pd.DataFrame(index=range(6), data=expected_columns)

    assert_frame_equal(actual, expected, check_names=True)
78+
def test_arrow_stream_with_partition(postgres_url: str) -> None:
    """Read ``test_table`` as streamed Arrow record batches using partitioning.

    Same check as the unpartitioned streaming test, but the query is issued
    with ``partition_on="test_int"``, ``partition_range=(0, 2000)`` and
    ``partition_num=3``. All batches yielded by the reader are combined into
    one Arrow table, converted to pandas and compared against the expected
    frame.
    """
    import pyarrow as pa

    reader = read_sql(
        postgres_url,
        "SELECT * FROM test_table",
        partition_on="test_int",
        partition_range=(0, 2000),
        partition_num=3,
        return_type="arrow_record_batches",
        record_batch_size=2,
    )

    # Drain the reader completely before assembling the table.
    collected = list(reader)
    frame = pa.Table.from_batches(collected).to_pandas()

    # Normalise row order (and reset the index) before comparing.
    actual = frame.sort_values(by="test_int", ignore_index=True)

    expected_columns = {
        "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
        "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
        "test_str": pd.Series(
            ["a", "str1", "str2", "b", "c", None], dtype="object"
        ),
        "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
        "test_bool": pd.Series(
            [None, True, False, False, None, True], dtype="object"
        ),
    }
    expected = pd.DataFrame(index=range(6), data=expected_columns)

    assert_frame_equal(actual, expected, check_names=True)
113+
def decimal_s10(val):
    """Return ``val`` as a ``Decimal`` quantized to ten decimal places."""
    # The quantum's exponent (-10) fixes the number of fractional digits.
    quantum = Decimal("0.0000000001")
    number = Decimal(val)
    return number.quantize(quantum)
49116
0 commit comments