
Commit ccba54f

sfc-gh-stakeda and ankit-bhatnagar167 authored and committed

SNOW-83333: Support arrow result format in chunk downloader.

1 parent 805e181 · commit ccba54f

File tree

9 files changed: +279 −21 lines

chunk_downloader.py

Lines changed: 102 additions & 2 deletions

@@ -9,9 +9,20 @@
 from multiprocessing.pool import ThreadPool
 from threading import (Condition, Lock)
 
+from .compat import ITERATOR
 from snowflake.connector.network import ResultIterWithTimings
+from snowflake.connector.gzip_decoder import decompress_raw_data
+from snowflake.connector.util_text import split_rows_from_stream
 from .errorcode import (ER_NO_ADDITIONAL_CHUNK, ER_CHUNK_DOWNLOAD_FAILED)
 from .errors import (Error, OperationalError)
+import json
+from io import StringIO
+from gzip import GzipFile
+
+try:
+    from pyarrow.ipc import open_stream
+except ImportError:
+    pass
 
 DEFAULT_REQUEST_TIMEOUT = 3600
 
@@ -42,9 +53,11 @@ class SnowflakeChunkDownloader(object):
     """
 
     def _pre_init(self, chunks, connection, cursor, qrmk, chunk_headers,
+                  query_result_format='JSON',
                   prefetch_threads=DEFAULT_CLIENT_PREFETCH_THREADS,
                   use_ijson=False):
        self._use_ijson = use_ijson
+       self._query_result_format = query_result_format

        self._downloader_error = None

@@ -87,9 +100,11 @@ def _pre_init(self, chunks, connection, cursor, qrmk, chunk_headers,
        self._next_chunk_to_consume = 0

    def __init__(self, chunks, connection, cursor, qrmk, chunk_headers,
+                query_result_format='JSON',
                 prefetch_threads=DEFAULT_CLIENT_PREFETCH_THREADS,
                 use_ijson=False):
        self._pre_init(chunks, connection, cursor, qrmk, chunk_headers,
+                      query_result_format=query_result_format,
                       prefetch_threads=prefetch_threads,
                       use_ijson=use_ijson)
        logger.debug('Chunk Downloader in memory')

@@ -251,10 +266,95 @@ def _fetch_chunk(self, url, headers):
        """
        Fetch the chunk from S3.
        """
+       handler = JsonBinaryHandler(is_raw_binary_iterator=True,
+                                   use_ijson=self._use_ijson) \
+           if self._query_result_format == 'json' else \
+           ArrowBinaryHandler()
+
        return self._connection.rest.fetch(
            u'get', url, headers,
            timeout=DEFAULT_REQUEST_TIMEOUT,
            is_raw_binary=True,
-           is_raw_binary_iterator=True,
-           use_ijson=self._use_ijson,
+           binary_data_handler=handler,
            return_timing_metrics=True)
+
+
+class RawBinaryDataHandler:
+    """
+    Abstract class being passed to network.py to handle raw binary data
+    """
+    def to_iterator(self, raw_data_fd):
+        pass
+
+
+class JsonBinaryHandler(RawBinaryDataHandler):
+    """
+    Convert result chunk in json format into an iterator
+    """
+    def __init__(self, is_raw_binary_iterator, use_ijson):
+        self._is_raw_binary_iterator = is_raw_binary_iterator
+        self._use_ijson = use_ijson
+
+    def to_iterator(self, raw_data_fd):
+        raw_data = decompress_raw_data(
+            raw_data_fd, add_bracket=True
+        ).decode('utf-8', 'replace')
+        if not self._is_raw_binary_iterator:
+            ret = json.loads(raw_data)
+        elif not self._use_ijson:
+            ret = iter(json.loads(raw_data))
+        else:
+            ret = split_rows_from_stream(StringIO(raw_data))
+        return ret
+
+
+class ArrowBinaryHandler(RawBinaryDataHandler):
+    """
+    Handler to consume data as arrow stream
+    """
+    def to_iterator(self, raw_data_fd):
+        gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r')
+        reader = open_stream(gzip_decoder)
+        return ArrowChunkIterator(reader)
+
+
+class ArrowChunkIterator(ITERATOR):
+    """
+    Given a list of record batches, iterate over
+    these batches row by row
+    """
+
+    def __init__(self, arrow_stream_reader):
+        self._batches = []
+        for record_batch in arrow_stream_reader:
+            self._batches.append(record_batch.columns)
+
+        self._column_count = len(self._batches[0])
+        self._batch_count = len(self._batches)
+        self._batch_index = -1
+        self._index_in_batch = -1
+        self._row_count_in_batch = 0
+
+    def next(self):
+        return self.__next__()
+
+    def __next__(self):
+        self._index_in_batch += 1
+        if self._index_in_batch < self._row_count_in_batch:
+            return self._return_row()
+        else:
+            self._batch_index += 1
+            if self._batch_index < self._batch_count:
+                self._index_in_batch = 0
+                self._row_count_in_batch = len(self._batches[self._batch_index][0])
+                return self._return_row()
+
+        raise StopIteration
+
+    def _return_row(self):
+        ret = []
+        current_batch = self._batches[self._batch_index]
+        for col_array in current_batch:
+            ret.append(col_array[self._index_in_batch])
+
+        return ret
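
A minimal sketch (not part of the commit) of how ArrowChunkIterator behaves: the constructor only iterates its argument, so any iterable of pyarrow RecordBatch objects can stand in for a real stream reader; this assumes ITERATOR from .compat supplies __iter__, and that column values expose .as_py() as in recent pyarrow releases.

    import pyarrow as pa
    from snowflake.connector.chunk_downloader import ArrowChunkIterator

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array(['a', 'b', 'c'])],
        names=['id', 'name'])

    # Each row comes back as a list of per-column pyarrow values.
    for row in ArrowChunkIterator([batch]):
        print([v.as_py() for v in row])   # [1, 'a'], [2, 'b'], [3, 'c']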

converter_arrow.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved.
+#
+
+from logging import getLogger
+from decimal import Context
+from datetime import datetime, timedelta, date
+from .converter import SnowflakeConverter
+
+logger = getLogger(__name__)
+
+ZERO_EPOCH = datetime.utcfromtimestamp(0)
+
+
+class SnowflakeArrowConverter(SnowflakeConverter):
+    """
+    Convert from arrow data into python native data types
+    """
+
+    def to_python_method(self, type_name, column):
+        ctx = column.copy()
+
+        if type_name == 'FIXED' and ctx['scale'] != 0:
+            ctx['decimalCtx'] = Context(prec=ctx['precision'])
+
+        converters = [u'_{type_name}_to_python'.format(type_name=type_name)]
+        if self._use_numpy:
+            converters.insert(0, u'_{type_name}_numpy_to_python'.format(
+                type_name=type_name))
+        for conv in converters:
+            try:
+                return getattr(self, conv)(ctx)
+            except AttributeError:
+                pass
+        logger.warning(
+            "No column converter found for type: %s", type_name)
+        return None  # Skip conversion
+
+    def _FIXED_to_python(self, ctx):
+        if ctx['scale'] == 0:
+            return lambda x: x.as_py()
+        else:
+            return lambda x, decimal_ctx=ctx['decimalCtx']: decimal_ctx.create_decimal(x.as_py())
+
+    def _REAL_to_python(self, _):
+        return lambda x: x.as_py()
+
+    def _TEXT_to_python(self, _):
+        return lambda x: x.as_py()
+
+    def _BINARY_to_python(self, _):
+        return lambda x: x.as_py()
+
+    def _VARIANT_to_python(self, _):
+        return lambda x: x.as_py()
+
+    def _BOOLEAN_to_python(self, _):
+        return lambda x: x.as_py() > 0
+
+    def _DATE_to_python(self, _):
+
+        def conv(value):
+            try:
+                return datetime.utcfromtimestamp(value * 86400).date()
+            except OSError as e:
+                logger.debug("Failed to convert: %s", e)
+                ts = ZERO_EPOCH + timedelta(
+                    seconds=value * (24 * 60 * 60))
+                return date(ts.year, ts.month, ts.day)
+
+        return conv
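
A minimal sketch (not part of the commit) of how these converters get used: to_python_method returns a per-column callable keyed off the server's 'rowtype' metadata, and each callable receives a pyarrow value exposing .as_py(). The column metadata dict below is illustrative, not taken from a real response.

    import pyarrow as pa
    from snowflake.connector.converter_arrow import SnowflakeArrowConverter

    converter = SnowflakeArrowConverter()

    # FIXED with a non-zero scale routes through a decimal.Context whose
    # precision mirrors the column metadata.
    column_meta = {'precision': 10, 'scale': 2}   # illustrative metadata
    to_python = converter.to_python_method('FIXED', column_meta)

    scalar = pa.array([12345])[0]   # a pyarrow value with .as_py()
    print(to_python(scalar))        # Decimal('12345') under a prec=10 context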

cursor.py

Lines changed: 24 additions & 5 deletions

@@ -10,7 +10,7 @@
 import uuid
 from logging import getLogger
 from threading import (Timer, Lock)
-
+from base64 import b64decode
 from six import u
 
 from .compat import (BASE_EXCEPTION_CLASS)

@@ -30,6 +30,13 @@
 from .sqlstate import (SQLSTATE_FEATURE_NOT_SUPPORTED)
 from .telemetry import (TelemetryData, TelemetryField)
 from .time_util import get_time_millis
+from .chunk_downloader import ArrowChunkIterator
+from .converter_arrow import SnowflakeArrowConverter
+
+try:
+    from pyarrow.ipc import open_stream
+except ImportError:
+    pass
 
 STATEMENT_TYPE_ID_DML = 0x3000
 STATEMENT_TYPE_ID_INSERT = STATEMENT_TYPE_ID_DML + 0x100

@@ -570,6 +577,7 @@ def _is_dml(self, data):
 
     def chunk_info(self, data, use_ijson=False):
         is_dml = self._is_dml(data)
+        self._query_result_format = data.get(u'queryResultFormat', u'json')
 
         if self._total_rowcount == -1 and not is_dml and data.get(u'total') \
                 is not None:

@@ -578,6 +586,10 @@
         self._description = []
         self._column_idx_to_name = {}
         self._column_converter = []
+
+        converter = SnowflakeArrowConverter() if \
+            self._query_result_format == 'arrow' else self._connection.converter
+
         for idx, column in enumerate(data[u'rowtype']):
             self._column_idx_to_name[idx] = column[u'name']
             type_value = FIELD_NAME_TO_ID[column[u'type'].upper()]

@@ -589,15 +601,21 @@ def chunk_info(self, data, use_ijson=False):
                     column[u'scale'],
                     column[u'nullable']))
             self._column_converter.append(
-                self._connection.converter.to_python_method(
-                    column[u'type'].upper(), column))
+                converter.to_python_method(
+                    column[u'type'].upper(), column))
 
         self._total_row_index = -1  # last fetched number of rows
 
         self._chunk_index = 0
         self._chunk_count = 0
-        self._current_chunk_row = iter(data.get(u'rowset'))
-        self._current_chunk_row_count = len(data.get(u'rowset'))
+        if self._query_result_format == 'arrow':
+            # result as arrow chunk
+            arrow_bytes = b64decode(data.get(u'rowsetBase64'))
+            arrow_reader = open_stream(arrow_bytes)
+            self._current_chunk_row = ArrowChunkIterator(arrow_reader)
+        else:
+            self._current_chunk_row = iter(data.get(u'rowset'))
+            self._current_chunk_row_count = len(data.get(u'rowset'))
 
         if u'chunks' in data:
             chunks = data[u'chunks']

@@ -619,6 +637,7 @@
             logger.debug(u'qrmk=%s', qrmk)
             self._chunk_downloader = self._connection._chunk_downloader_class(
                 chunks, self._connection, self, qrmk, chunk_headers,
+                query_result_format=self._query_result_format,
                 prefetch_threads=self._connection.client_prefetch_threads,
                 use_ijson=use_ijson)
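
A minimal sketch (not part of the commit) of the new inline-result path in chunk_info: the first chunk arrives base64-encoded under 'rowsetBase64' and decodes to an Arrow IPC stream. The payload below is fabricated locally just to exercise the same open_stream call; writer names follow recent pyarrow releases.

    from base64 import b64decode, b64encode
    import pyarrow as pa
    from pyarrow.ipc import open_stream

    # Fabricate a payload shaped like data[u'rowsetBase64'].
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=['n'])
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()
    rowset_base64 = b64encode(sink.getvalue().to_pybytes())

    reader = open_stream(b64decode(rowset_base64))  # same call cursor.py makes
    for record_batch in reader:
        print(record_batch.num_rows)                # 2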

network.py

Lines changed: 3 additions & 14 deletions

@@ -13,7 +13,7 @@
 import sys
 import time
 import uuid
-from io import StringIO, BytesIO
+from io import BytesIO
 from threading import Lock
 
 import OpenSSL.SSL

@@ -47,7 +47,6 @@
     InterfaceError, InternalServerError, ForbiddenError,
     BadGatewayError, BadRequest, MethodNotAllowed,
     OtherHTTPRetryableError)
-from .gzip_decoder import decompress_raw_data
 from .sqlstate import (SQLSTATE_CONNECTION_NOT_EXISTS,
                        SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
                        SQLSTATE_CONNECTION_REJECTED)

@@ -56,7 +55,6 @@
     DEFAULT_MASTER_VALIDITY_IN_SECONDS
 )
 from .tool.probe_connection import probe_connection
-from .util_text import split_rows_from_stream
 from .version import VERSION
 
 if PY2:

@@ -739,8 +737,7 @@ def _request_exec(
             catch_okta_unauthorized_error=False,
             is_raw_text=False,
             is_raw_binary=False,
-            is_raw_binary_iterator=True,
-            use_ijson=False,
+            binary_data_handler=None,
             socket_timeout=DEFAULT_SOCKET_CONNECT_TIMEOUT,
             return_timing_metrics=False):
         if socket_timeout > DEFAULT_SOCKET_CONNECT_TIMEOUT:

@@ -785,15 +782,7 @@
             ret = raw_ret.text
         elif is_raw_binary:
             start_time = get_time_millis()
-            raw_data = decompress_raw_data(
-                raw_ret.raw, add_bracket=True
-            ).decode('utf-8', 'replace')
-            if not is_raw_binary_iterator:
-                ret = json.loads(raw_data)
-            elif not use_ijson:
-                ret = iter(json.loads(raw_data))
-            else:
-                ret = split_rows_from_stream(StringIO(raw_data))
+            ret = binary_data_handler.to_iterator(raw_ret.raw)
             timing_metrics[
                 ResultIterWithTimings.PARSE] = get_time_millis() - start_time
 
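
A minimal sketch (not in the commit) of what the binary_data_handler hook buys: network.py no longer hard-codes JSON parsing, so any object implementing to_iterator(raw_data_fd) can plug in. The handler below is hypothetical, purely to illustrate the seam.

    from snowflake.connector.chunk_downloader import RawBinaryDataHandler

    class ByteCountingHandler(RawBinaryDataHandler):
        """Hypothetical handler: yields the payload size instead of rows."""
        def to_iterator(self, raw_data_fd):
            total = len(raw_data_fd.read())   # raw_data_fd is a readable stream
            return iter([total])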

scripts/install.bat

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ call env\Scripts\activate
 # https://github.com/pypa/pip/issues/6566
 python -m pip install --upgrade pip==18.1
 pip install pendulum
+pip install pyarrow
 pip install numpy
 pip install pytest pytest-cov pytest-rerunfailures
 pip install .

scripts/install.sh

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ fi
 
 source ./venv/bin/activate
 pip install numpy pendulum
+pip install pyarrow
 pip install pytest pytest-cov pytest-rerunfailures
 if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]] || [[ $PYTHON_VERSION == "2.7"* ]]; then
     pip install mock

setup.py

Lines changed: 4 additions & 0 deletions

@@ -85,6 +85,10 @@
         "secure-local-storage": [
             'keyring!=16.1.0'
         ],
+        "arrow-result": [
+            'pyarrow>=0.13.0;python_version>"3.4"',
+            'pyarrow>=0.13.0;python_version<"3.0"'
+        ]
     },
 
     classifiers=[
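
With this extras group, the optional pyarrow dependency can be pulled in at install time; assuming the package name as published on PyPI, that looks like:

    pip install "snowflake-connector-python[arrow-result]"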

test/conftest.py

Lines changed: 2 additions & 0 deletions

@@ -126,6 +126,8 @@ def get_db_parameters():
     ret['name_wh'] = ret['name'] + 'wh'
 
     ret['schema'] = TEST_SCHEMA
+
+    # This reduces the chance of exposing the password in test output.
     ret['a00'] = 'dummy parameter'
     ret['a01'] = 'dummy parameter'
     ret['a02'] = 'dummy parameter'
