Skip to content

Commit c6836d8

Browse files
committed
add arrow stream test
1 parent 1b23062 commit c6836d8

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

connectorx-python/connectorx/tests/test_arrow.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,73 @@ def test_arrow(postgres_url: str) -> None:
4444
df.sort_values(by="test_int", inplace=True, ignore_index=True)
4545
assert_frame_equal(df, expected, check_names=True)
4646

47+
def test_arrow_stream(postgres_url: str) -> None:
    """Read ``test_table`` as a stream of Arrow record batches and verify it.

    The query is issued with ``return_type="arrow_record_batches"`` and
    ``record_batch_size=2``. Every batch yielded by the reader is collected,
    assembled into a single Arrow table, converted to pandas and sorted by
    ``test_int`` so the comparison does not depend on batch arrival order.
    """
    import pyarrow as pa

    sql = "SELECT * FROM test_table"
    stream = read_sql(
        postgres_url,
        sql,
        return_type="arrow_record_batches",
        record_batch_size=2,
    )

    # Drain the reader completely before building the table.
    collected = list(stream)
    df = pa.Table.from_batches(collected).to_pandas()
    df = df.sort_values(by="test_int", ignore_index=True)

    columns = {
        "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
        "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
        "test_str": pd.Series(
            ["a", "str1", "str2", "b", "c", None], dtype="object"
        ),
        "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
        "test_bool": pd.Series(
            [None, True, False, False, None, True], dtype="object"
        ),
    }
    expected = pd.DataFrame(index=range(6), data=columns)

    assert_frame_equal(df, expected, check_names=True)
78+
79+
def test_arrow_stream_with_partition(postgres_url: str) -> None:
    """Read ``test_table`` as Arrow record batches using a partitioned query.

    Same check as the unpartitioned streaming test, but the read is split on
    ``test_int`` over the range ``(0, 2000)`` into 3 partitions. The batches
    are collected into one Arrow table, converted to pandas and sorted by
    ``test_int`` so the comparison does not depend on the order in which
    batches are produced.
    """
    import pyarrow as pa

    sql = "SELECT * FROM test_table"
    stream = read_sql(
        postgres_url,
        sql,
        partition_on="test_int",
        partition_range=(0, 2000),
        partition_num=3,
        return_type="arrow_record_batches",
        record_batch_size=2,
    )

    # Drain the reader completely before building the table.
    collected = list(stream)
    df = pa.Table.from_batches(collected).to_pandas()
    df = df.sort_values(by="test_int", ignore_index=True)

    columns = {
        "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
        "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
        "test_str": pd.Series(
            ["a", "str1", "str2", "b", "c", None], dtype="object"
        ),
        "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
        "test_bool": pd.Series(
            [None, True, False, False, None, True], dtype="object"
        ),
    }
    expected = pd.DataFrame(index=range(6), data=columns)

    assert_frame_equal(df, expected, check_names=True)
113+
47114
def decimal_s10(val):
    """Return ``val`` as a ``Decimal`` quantized to ten decimal places."""
    number = Decimal(val)
    ten_places = Decimal("0.0000000001")
    return number.quantize(ten_places)
49116

0 commit comments

Comments
 (0)