|
9 | 9 | from multiprocessing.pool import ThreadPool |
10 | 10 | from threading import (Condition, Lock) |
11 | 11 |
|
| 12 | +from .compat import ITERATOR |
12 | 13 | from snowflake.connector.network import ResultIterWithTimings |
| 14 | +from snowflake.connector.gzip_decoder import decompress_raw_data |
| 15 | +from snowflake.connector.util_text import split_rows_from_stream |
13 | 16 | from .errorcode import (ER_NO_ADDITIONAL_CHUNK, ER_CHUNK_DOWNLOAD_FAILED) |
14 | 17 | from .errors import (Error, OperationalError) |
| 18 | +import json |
| 19 | +from io import StringIO |
| 20 | +from gzip import GzipFile |
| 21 | + |
| 22 | +try: |
| 23 | + from pyarrow.ipc import open_stream |
| 24 | +except ImportError: |
| 25 | + pass |
15 | 26 |
|
16 | 27 | DEFAULT_REQUEST_TIMEOUT = 3600 |
17 | 28 |
|
@@ -42,9 +53,11 @@ class SnowflakeChunkDownloader(object): |
42 | 53 | """ |
43 | 54 |
|
44 | 55 | def _pre_init(self, chunks, connection, cursor, qrmk, chunk_headers, |
| 56 | + query_result_format='JSON', |
45 | 57 | prefetch_threads=DEFAULT_CLIENT_PREFETCH_THREADS, |
46 | 58 | use_ijson=False): |
47 | 59 | self._use_ijson = use_ijson |
| 60 | + self._query_result_format = query_result_format |
48 | 61 |
|
49 | 62 | self._downloader_error = None |
50 | 63 |
|
@@ -87,9 +100,11 @@ def _pre_init(self, chunks, connection, cursor, qrmk, chunk_headers, |
87 | 100 | self._next_chunk_to_consume = 0 |
88 | 101 |
|
89 | 102 | def __init__(self, chunks, connection, cursor, qrmk, chunk_headers, |
| 103 | + query_result_format='JSON', |
90 | 104 | prefetch_threads=DEFAULT_CLIENT_PREFETCH_THREADS, |
91 | 105 | use_ijson=False): |
92 | 106 | self._pre_init(chunks, connection, cursor, qrmk, chunk_headers, |
| 107 | + query_result_format=query_result_format, |
93 | 108 | prefetch_threads=prefetch_threads, |
94 | 109 | use_ijson=use_ijson) |
95 | 110 | logger.debug('Chunk Downloader in memory') |
@@ -251,10 +266,95 @@ def _fetch_chunk(self, url, headers): |
251 | 266 | """ |
252 | 267 | Fetch the chunk from S3. |
253 | 268 | """ |
| 269 | + handler = JsonBinaryHandler(is_raw_binary_iterator=True, |
| 270 | + use_ijson=self._use_ijson) \ |
| 271 | +        if self._query_result_format.lower() == 'json' else \ |
| 272 | + ArrowBinaryHandler() |
| 273 | + |
254 | 274 | return self._connection.rest.fetch( |
255 | 275 | u'get', url, headers, |
256 | 276 | timeout=DEFAULT_REQUEST_TIMEOUT, |
257 | 277 | is_raw_binary=True, |
258 | | - is_raw_binary_iterator=True, |
259 | | - use_ijson=self._use_ijson, |
| 278 | + binary_data_handler=handler, |
260 | 279 | return_timing_metrics=True) |
| 280 | + |
| 281 | + |
| 282 | +class RawBinaryDataHandler: |
| 283 | + """ |
| 284 | + Abstract class passed to network.py to handle raw binary data |
| 285 | + """ |
| 286 | + def to_iterator(self, raw_data_fd): |
| 287 | + pass |
| 288 | + |
| 289 | + |
| 290 | +class JsonBinaryHandler(RawBinaryDataHandler): |
| 291 | + """ |
| 292 | + Convert a result chunk in JSON format into an iterator |
| 293 | + """ |
| 294 | + def __init__(self, is_raw_binary_iterator, use_ijson): |
| 295 | + self._is_raw_binary_iterator = is_raw_binary_iterator |
| 296 | + self._use_ijson = use_ijson |
| 297 | + |
| 298 | + def to_iterator(self, raw_data_fd): |
| 299 | + raw_data = decompress_raw_data( |
| 300 | + raw_data_fd, add_bracket=True |
| 301 | + ).decode('utf-8', 'replace') |
| 302 | + if not self._is_raw_binary_iterator: |
| 303 | + ret = json.loads(raw_data) |
| 304 | + elif not self._use_ijson: |
| 305 | + ret = iter(json.loads(raw_data)) |
| 306 | + else: |
| 307 | + ret = split_rows_from_stream(StringIO(raw_data)) |
| 308 | + return ret |
| 309 | + |
| 310 | + |
| 311 | +class ArrowBinaryHandler(RawBinaryDataHandler): |
| 312 | + """ |
| 313 | + Handler to consume data as an Arrow stream |
| 314 | + """ |
| 315 | + def to_iterator(self, raw_data_fd): |
| 316 | + gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r') |
| 317 | + reader = open_stream(gzip_decoder) |
| 318 | + return ArrowChunkIterator(reader) |
| 319 | + |
| 320 | + |
| 321 | +class ArrowChunkIterator(ITERATOR): |
| 322 | + """ |
| 323 | + Given an Arrow stream reader, collect its record batches |
| 324 | + and iterate over them row by row |
| 325 | + """ |
| 326 | + |
| 327 | + def __init__(self, arrow_stream_reader): |
| 328 | + self._batches = [] |
| 329 | + for record_batch in arrow_stream_reader: |
| 330 | + self._batches.append(record_batch.columns) |
| 331 | + |
| 332 | + self._column_count = len(self._batches[0]) |
| 333 | + self._batch_count = len(self._batches) |
| 334 | + self._batch_index = -1 |
| 335 | + self._index_in_batch = -1 |
| 336 | + self._row_count_in_batch = 0 |
| 337 | + |
| 338 | + def next(self): |
| 339 | + return self.__next__() |
| 340 | + |
| 341 | + def __next__(self): |
| 342 | + self._index_in_batch += 1 |
| 343 | + if self._index_in_batch < self._row_count_in_batch: |
| 344 | + return self._return_row() |
| 345 | + else: |
| 346 | + self._batch_index += 1 |
| 347 | + if self._batch_index < self._batch_count: |
| 348 | + self._index_in_batch = 0 |
| 349 | + self._row_count_in_batch = len(self._batches[self._batch_index][0]) |
| 350 | + return self._return_row() |
| 351 | + |
| 352 | + raise StopIteration |
| 353 | + |
| 354 | + def _return_row(self): |
| 355 | + ret = [] |
| 356 | + current_batch = self._batches[self._batch_index] |
| 357 | + for col_array in current_batch: |
| 358 | + ret.append(col_array[self._index_in_batch]) |
| 359 | + |
| 360 | + return ret |
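
A minimal usage sketch of the Arrow path (not part of this change): it builds an in-memory Arrow IPC stream with pyarrow and walks it with ArrowChunkIterator. The column names and values are made up for illustration, and pyarrow is assumed to be installed.

    import pyarrow as pa

    # Build a tiny two-column record batch and serialize it as an IPC stream.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array(['a', 'b', 'c'])],
        names=['id', 'name'])
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()

    # Read the stream back and step through it row by row, as ArrowBinaryHandler
    # would do after gunzipping the chunk body.
    reader = pa.ipc.open_stream(sink.getvalue())
    it = ArrowChunkIterator(reader)
    print(next(it))  # [<pyarrow scalar: 1>, <pyarrow scalar: 'a'>]
    print(next(it))  # [<pyarrow scalar: 2>, <pyarrow scalar: 'b'>]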
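For comparison, a rough sketch of the JSON path, assuming decompress_raw_data accepts any file-like object of gzip-compressed bytes and that a chunk body is a comma-separated run of row arrays without the enclosing brackets (which add_bracket=True restores before parsing); the payload below is illustrative only.

    import gzip
    import io

    # Hypothetical chunk body: two rows, no outer brackets.
    payload = gzip.compress(b'["1","a"],["2","b"]')

    handler = JsonBinaryHandler(is_raw_binary_iterator=True, use_ijson=False)
    for row in handler.to_iterator(io.BytesIO(payload)):
        print(row)  # ['1', 'a'] then ['2', 'b']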