Skip to content

Commit c542091

Browse files
sfc-gh-stakedaankit-bhatnagar167
authored and committed
SNOW-83333: Updated arrow result set iterator
1 parent 3ad9acf commit c542091

File tree

9 files changed

+284
-68
lines changed

9 files changed

+284
-68
lines changed

arrow_iterator.pyx

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#
2+
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from decimal import Context
6+
from logging import getLogger
7+
from datetime import datetime, timedelta, date
8+
9+
logger = getLogger(__name__)
10+
11+
ZERO_EPOCH = datetime.utcfromtimestamp(0)
12+
13+
cdef class ArrowChunkIterator:
    """Iterate row by row over a stream of Arrow record batches.

    Each record batch's columns are wrapped in ColumnConverter
    instances; iteration walks batch by batch, row by row, yielding
    every row as a list of Python-native values.
    """

    cdef:
        list _batches            # one list of converters per record batch
        int _column_count        # converters (columns) per batch
        int _batch_count         # number of record batches consumed
        int _batch_index         # index of the batch currently iterated
        int _index_in_batch      # row index inside the current batch
        int _row_count_in_batch  # rows in the current batch
        list _current_batch      # converter list of the current batch

    def __init__(self, arrow_stream_reader, meta):
        """
        :param arrow_stream_reader: iterable of pyarrow RecordBatch objects
        :param meta: per-column metadata tuples with the same form as
                     cursor.description
        """
        self._batches = []
        for record_batch in arrow_stream_reader:
            converters = [
                ColumnConverter.init_converter(column, meta[index])
                for index, column in enumerate(record_batch.columns)
            ]
            self._batches.append(converters)

        self._batch_count = len(self._batches)
        # Guard against an empty stream: the previous code indexed
        # self._batches[0] unconditionally and raised IndexError.
        self._column_count = len(self._batches[0]) if self._batches else 0
        self._batch_index = -1
        self._index_in_batch = -1
        self._row_count_in_batch = 0
        self._current_batch = None

    def __iter__(self):
        # Required for use in for-loops; the deleted pure-Python
        # implementation inherited this from the compat ITERATOR base
        # class, and the cdef rewrite must provide it explicitly.
        return self

    def next(self):
        # Python 2 iterator protocol compatibility.
        return self.__next__()

    def __next__(self):
        # Loop (instead of a single advance) so that zero-row batches
        # are skipped rather than indexed out of range.
        while True:
            self._index_in_batch += 1
            if self._index_in_batch < self._row_count_in_batch:
                return self._return_row()

            self._batch_index += 1
            if self._batch_index >= self._batch_count:
                raise StopIteration
            self._current_batch = self._batches[self._batch_index]
            self._index_in_batch = -1
            self._row_count_in_batch = self._current_batch[0].row_count()

    cdef _return_row(self):
        """Materialize the current row as a list of Python values."""
        return [col.to_python_native(self._index_in_batch)
                for col in self._current_batch]
62+
63+
64+
cdef class ColumnConverter:
    """Convert values of one Arrow column array into Python native types.

    The base implementation delegates to pyarrow's as_py(); subclasses
    override to_python_native() for Snowflake types that need custom
    conversion.
    """

    cdef object _arrow_column_array
    cdef object _meta

    def __init__(self, arrow_column_array, meta):
        """
        Base Column Converter constructor
        :param arrow_column_array: arrow array
        :param meta: column metadata, which is a tuple with same form as cursor.description
        """
        self._arrow_column_array = arrow_column_array
        self._meta = meta

    def to_python_native(self, index):
        """Return the value at row *index* converted via pyarrow."""
        return self._arrow_column_array[index].as_py()

    def row_count(self):
        """Number of rows in the underlying column array."""
        return len(self._arrow_column_array)

    @staticmethod
    def init_converter(column_array, meta):
        """Factory: choose the converter subclass for this column's
        Snowflake type code (index 1 of the description tuple)."""
        type_code = meta[1]
        if type_code == 'FIXED':
            return FixedColumnConverter(column_array, meta)
        elif type_code == 'DATE':
            # DateColumnConverter was defined but never dispatched,
            # leaving DATE columns on the raw as_py() fallback path.
            return DateColumnConverter(column_array, meta)
        else:
            return ColumnConverter(column_array, meta)
92+
93+
cdef class FixedColumnConverter(ColumnConverter):
    """Convert Snowflake FIXED (exact numeric) columns.

    Scale-0 columns map to int; columns with a non-zero scale map to
    Decimal, shifting the decimal point left by *scale* digits.
    """

    cdef int _scale
    # Must be declared here: cdef classes raise AttributeError on
    # assignment to undeclared attributes (the original assigned
    # _decimal_ctx without declaring it).
    cdef object _decimal_ctx
    cdef object _convert_method

    def __init__(self, arrow_column_array, meta):
        super().__init__(arrow_column_array, meta)
        # Description tuple layout (PEP 249): (name, type_code,
        # display_size, internal_size, precision, scale, null_ok).
        self._scale = meta[5]
        if self._scale == 0:
            self._convert_method = self._to_int
        else:
            # meta is a tuple, not a dict: precision is meta[4];
            # meta['precision'] raised TypeError at runtime.
            self._decimal_ctx = Context(prec=meta[4])
            self._convert_method = self._to_decimal

    def to_python_native(self, index):
        val = self._arrow_column_array[index]
        return self._convert_method(val)

    def _to_int(self, val):
        return val.as_py()

    def _to_decimal(self, val):
        # Previously a stub that returned 0 for every value.
        # NOTE(review): assumes Arrow carries FIXED values as scaled
        # integers (value * 10**scale) — confirm against the server's
        # Arrow schema.
        raw = val.as_py()
        if raw is None:
            return None
        return self._decimal_ctx.create_decimal(raw).scaleb(
            -self._scale, self._decimal_ctx)
115+
116+
cdef class DateColumnConverter(ColumnConverter):
    """Convert Arrow date columns to datetime.date.

    NOTE(review): assumes as_py() yields an integer count of days since
    the Unix epoch (matching the original code) — confirm against the
    Arrow schema; pyarrow date32 scalars would yield datetime.date
    directly.
    """

    def __init__(self, arrow_column_array, meta):
        super().__init__(arrow_column_array, meta)

    def to_python_native(self, index):
        value = self._arrow_column_array[index]
        epoch_days = value.as_py()
        if epoch_days is None:
            # NULL cell: the original multiplied None and raised TypeError.
            return None
        try:
            return datetime.utcfromtimestamp(epoch_days * 86400).date()
        except (OSError, ValueError, OverflowError) as e:
            # utcfromtimestamp cannot represent timestamps far from the
            # epoch on some platforms (OSError on Windows, ValueError/
            # OverflowError elsewhere); fall back to timedelta arithmetic.
            logger.debug("Failed to convert: %s", e)
            # Bug fix: the fallback previously multiplied the raw Arrow
            # scalar (`value`) instead of its Python value.
            ts = ZERO_EPOCH + timedelta(seconds=epoch_days * 86400)
            return date(ts.year, ts.month, ts.day)

chunk_downloader.py

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from multiprocessing.pool import ThreadPool
1010
from threading import (Condition, Lock)
1111

12-
from .compat import ITERATOR
1312
from snowflake.connector.network import ResultIterWithTimings
1413
from snowflake.connector.gzip_decoder import decompress_raw_data
1514
from snowflake.connector.util_text import split_rows_from_stream
@@ -21,6 +20,7 @@
2120

2221
try:
2322
from pyarrow.ipc import open_stream
23+
from .arrow_iterator import ArrowChunkIterator
2424
except ImportError:
2525
pass
2626

@@ -269,7 +269,7 @@ def _fetch_chunk(self, url, headers):
269269
handler = JsonBinaryHandler(is_raw_binary_iterator=True,
270270
use_ijson=self._use_ijson) \
271271
if self._query_result_format == 'json' else \
272-
ArrowBinaryHandler()
272+
ArrowBinaryHandler(self._cursor.description)
273273

274274
return self._connection.rest.fetch(
275275
u'get', url, headers,
@@ -309,52 +309,14 @@ def to_iterator(self, raw_data_fd):
309309

310310

311311
class ArrowBinaryHandler(RawBinaryDataHandler):
312+
313+
def __init__(self, meta):
314+
self._meta = meta
315+
312316
"""
313317
Handler to consume data as arrow stream
314318
"""
315319
def to_iterator(self, raw_data_fd):
316320
gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r')
317321
reader = open_stream(gzip_decoder)
318-
return ArrowChunkIterator(reader)
319-
320-
321-
class ArrowChunkIterator(ITERATOR):
322-
"""
323-
Given a list of record batches, iterate over
324-
these batches row by row
325-
"""
326-
327-
def __init__(self, arrow_stream_reader):
328-
self._batches = []
329-
for record_batch in arrow_stream_reader:
330-
self._batches.append(record_batch.columns)
331-
332-
self._column_count = len(self._batches[0])
333-
self._batch_count = len(self._batches)
334-
self._batch_index = -1
335-
self._index_in_batch = -1
336-
self._row_count_in_batch = 0
337-
338-
def next(self):
339-
return self.__next__()
340-
341-
def __next__(self):
342-
self._index_in_batch += 1
343-
if self._index_in_batch < self._row_count_in_batch:
344-
return self._return_row()
345-
else:
346-
self._batch_index += 1
347-
if self._batch_index < self._batch_count:
348-
self._index_in_batch = 0
349-
self._row_count_in_batch = len(self._batches[self._batch_index][0])
350-
return self._return_row()
351-
352-
raise StopIteration
353-
354-
def _return_row(self):
355-
ret = []
356-
current_batch = self._batches[self._batch_index]
357-
for col_array in current_batch:
358-
ret.append(col_array[self._index_in_batch])
359-
360-
return ret
322+
return ArrowChunkIterator(reader, self._meta)

cursor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030
from .sqlstate import (SQLSTATE_FEATURE_NOT_SUPPORTED)
3131
from .telemetry import (TelemetryData, TelemetryField)
3232
from .time_util import get_time_millis
33-
from .chunk_downloader import ArrowChunkIterator
34-
from .converter_arrow import SnowflakeArrowConverter
3533

3634
try:
3735
from pyarrow.ipc import open_stream
36+
from .arrow_iterator import ArrowChunkIterator
3837
except ImportError:
3938
pass
4039

@@ -587,8 +586,8 @@ def chunk_info(self, data, use_ijson=False):
587586
self._column_idx_to_name = {}
588587
self._column_converter = []
589588

590-
converter = SnowflakeArrowConverter() if \
591-
self._query_result_format == 'arrow' else self._connection.converter
589+
self._return_row_method = self._arrow_return_row if \
590+
self._query_result_format == 'arrow' else self._json_return_row
592591

593592
for idx, column in enumerate(data[u'rowtype']):
594593
self._column_idx_to_name[idx] = column[u'name']
@@ -601,7 +600,7 @@ def chunk_info(self, data, use_ijson=False):
601600
column[u'scale'],
602601
column[u'nullable']))
603602
self._column_converter.append(
604-
converter.to_python_method(
603+
self._connection.converter.to_python_method(
605604
column[u'type'].upper(), column))
606605

607606
self._total_row_index = -1 # last fetched number of rows
@@ -612,7 +611,7 @@ def chunk_info(self, data, use_ijson=False):
612611
# result as arrow chunk
613612
arrow_bytes = b64decode(data.get(u'rowsetBase64'))
614613
arrow_reader = open_stream(arrow_bytes)
615-
self._current_chunk_row = ArrowChunkIterator(arrow_reader)
614+
self._current_chunk_row = ArrowChunkIterator(arrow_reader, self._description)
616615
else:
617616
self._current_chunk_row = iter(data.get(u'rowset'))
618617
self._current_chunk_row_count = len(data.get(u'rowset'))
@@ -801,7 +800,7 @@ def fetchone(self):
801800
self._current_chunk_row = iter(())
802801
is_done = True
803802

804-
return self._row_to_python(row) if row is not None else None
803+
return self._return_row_method(row)
805804

806805
except IndexError:
807806
# returns None if the iteration is completed so that iter() stops
@@ -814,6 +813,12 @@ def fetchone(self):
814813
TelemetryField.TIME_CONSUME_LAST_RESULT,
815814
time_consume_last_result)
816815

816+
def _json_return_row(self, row):
817+
return self._row_to_python(row) if row is not None else None
818+
819+
def _arrow_return_row(self, row):
820+
return row
821+
817822
def fetchmany(self, size=None):
818823
u"""
819824
Fetch the number of specified rows

ocsp_snowflake.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def find_cache(ocsp, cert_id, subject):
560560
return True, cache
561561
else:
562562
OCSPCache.delete_cache(ocsp, cert_id)
563-
except Exception as ex:
563+
except Exception:
564564
OCSPCache.delete_cache(ocsp, cert_id)
565565
else:
566566
logger.debug("Could not validate cache entry %s %s",

setup.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
#
66
from codecs import open
77
from os import path
8+
import os
89

910
from setuptools import setup
11+
from os.path import join
1012

1113
THIS_DIR = path.dirname(path.realpath(__file__))
1214

@@ -19,19 +21,31 @@
1921
with open(path.join(THIS_DIR, 'DESCRIPTION.rst'), encoding='utf-8') as f:
2022
long_description = f.read()
2123

24+
cython_build_dir = join("build", "cython")
25+
cython_source = [
26+
"arrow_iterator.pyx"
27+
]
28+
enable_ext_modules = os.environ.get("ENABLE_EXT_MODULES", "false")
29+
ext_modules = None
30+
if enable_ext_modules == "true":
31+
from Cython.Build import cythonize
32+
ext_modules = cythonize(cython_source, build_dir=cython_build_dir)
33+
2234
setup(
2335
name='snowflake-connector-python',
2436
version=version,
2537
description=u"Snowflake Connector for Python",
38+
ext_modules=ext_modules,
2639
long_description=long_description,
27-
author='Snowflake Computing, Inc',
28-
author_email='support@snowflake.net',
40+
author='Snowflake, Inc',
41+
author_email='support@snowflake.com',
2942
license='Apache License, Version 2.0',
3043
keywords="Snowflake db database cloud analytics warehouse",
31-
url='https://www.snowflake.net/',
32-
download_url='https://www.snowflake.net/',
44+
url='https://www.snowflake.com/',
45+
download_url='https://www.snowflake.com/',
3346
use_2to3=False,
3447

48+
# NOTE: Python 3.4 will be dropped within one month.
3549
python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*',
3650

3751
install_requires=[
@@ -108,7 +122,6 @@
108122

109123
'Programming Language :: SQL',
110124
'Programming Language :: Python :: 2.7',
111-
'Programming Language :: Python :: 3.4',
112125
'Programming Language :: Python :: 3.5',
113126
'Programming Language :: Python :: 3.6',
114127
'Programming Language :: Python :: 3.7',

ssl_wrap_socket.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def recv(self, *args, **kwargs):
201201
return b''
202202
else:
203203
raise SocketError(str(e))
204-
except OpenSSL.SSL.ZeroReturnError as e:
204+
except OpenSSL.SSL.ZeroReturnError:
205205
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
206206
return b''
207207
else:
@@ -223,7 +223,7 @@ def recv_into(self, *args, **kwargs):
223223
return 0
224224
else:
225225
raise SocketError(str(e))
226-
except OpenSSL.SSL.ZeroReturnError as e:
226+
except OpenSSL.SSL.ZeroReturnError:
227227
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
228228
return 0
229229
else:

test/test_arrow_result.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
5+
#
6+
7+
import pytest
8+
9+
@pytest.mark.skip(
    reason="Cython is not enabled in build env")
def test_select_with_num(conn_cnx):
    """Verify the ARROW result format matches JSON row for row.

    Runs the same deterministic query through a JSON-format session and
    an ARROW_FORCE-format session and compares the fetched rows.
    """
    with conn_cnx() as json_cnx:
        with conn_cnx() as arrow_cnx:
            row_count = 50000
            # Derive the SQL from row_count so the generator size and
            # the comparison loop below cannot drift apart (previously
            # 50000 was hard-coded in both places independently).
            sql_text = ("select seq4() as c1, uniform(1, 10, random(12)) as c2 "
                        "from table(generator(rowcount=>{0})) "
                        "order by c1").format(row_count)
            cursor_json = json_cnx.cursor()
            cursor_json.execute("alter session set query_result_format='JSON'")
            cursor_json.execute(sql_text)

            cursor_arrow = arrow_cnx.cursor()
            cursor_arrow.execute("alter session set query_result_format='ARROW_FORCE'")
            cursor_arrow.execute(sql_text)

            for _ in range(row_count):
                (json_c1, json_c2) = cursor_json.fetchone()
                (arrow_c1, arrow_c2) = cursor_arrow.fetchone()
                assert json_c1 == arrow_c1
                assert json_c2 == arrow_c2

0 commit comments

Comments
 (0)