Skip to content

Commit e4221e5

Browse files
sfc-gh-stakeda, ankit-bhatnagar167
authored and committed
SNOW-98725: new fetch pandas API added to client perf test
1 parent 3ddf772 commit e4221e5

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

test/test_arrow_pandas.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,7 @@ def validate_pandas(conn_cnx, sql, cases, col_count, method='one', data_type='fl
508508
reason="arrow_iterator extension is not built.")
509509
def test_num_batch(conn_cnx):
510510
print('Test fetching dataframes in batch')
511-
row_count = 50000
511+
row_count = 1000000
512512
col_count = 2
513513
random_seed = get_random_seed()
514514
sql_exec = ("select seq4() as c1, uniform(1, 10, random({})) as c2 from ".format(random_seed) +
@@ -559,7 +559,18 @@ def fetch_pandas(conn_cnx, sql, row_count, col_count, method='one'):
559559
# actually its exec time would be different from `pd.read_sql()` via sqlalchemy as most people use
560560
# further perf test can be done separately
561561
start_time = time.time()
562-
df_old = pd.DataFrame(cursor_row.fetchall(), columns=['c{}'.format(i) for i in range(col_count)])
562+
rows = 0
563+
if method == 'one':
564+
df_old = pd.DataFrame(cursor_row.fetchall(), columns=['c{}'.format(i) for i in range(col_count)])
565+
else:
566+
print("use fetchmany")
567+
while True:
568+
dat = cursor_row.fetchmany(10000)
569+
if not dat:
570+
break
571+
else:
572+
df_old = pd.DataFrame(dat, columns=['c{}'.format(i) for i in range(col_count)])
573+
rows += df_old.shape[0]
563574
end_time = time.time()
564575
print('The original way took {}s'.format(end_time - start_time))
565576
cursor_row.close()
@@ -598,6 +609,8 @@ def fetch_pandas(conn_cnx, sql, row_count, col_count, method='one'):
598609
for j, (c_old, c_new) in enumerate(zip(col_old, col_new)):
599610
assert c_old == c_new, '{} row, {} column: old value is {}, new value is {}, \
600611
values are not equal'.format(i, j, c_old, c_new)
612+
else:
613+
assert rows == total_rows, 'the number of rows are not equal {} vs {}'.format(rows, total_rows)
601614

602615

603616
def init(conn_cnx, table, column, values, timezone=None):

0 commit comments

Comments
 (0)