39
39
import copy
40
40
import functools
41
41
import os
42
+ import queue
42
43
import random
43
44
import re
44
45
import threading
45
46
import urllib.parse
47
+ from concurrent.futures import ThreadPoolExecutor
46
48
import warnings
47
49
from datetime import date, datetime, time, timedelta, timezone, tzinfo
48
50
from decimal import Decimal
49
51
from time import sleep
50
- from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
52
+ from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
51
53
52
54
import pytz
53
55
import requests
@@ -684,6 +686,27 @@ def _verify_extra_credential(self, header):
684
686
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
685
687
686
688
689
class ResultDownloader():
    """Run fetch callables on one background thread and queue their results.

    Use as a context manager: entering creates a single-worker
    ``ThreadPoolExecutor``; leaving shuts it down. Each callable passed to
    ``submit`` is executed on the worker and its return value is put on
    ``self.queue``, from which the consumer retrieves it with
    ``self.queue.get()``.

    If a callable raises, the exception is queued in place of a result and
    re-raised by the next ``self.queue.get()``; the consumer is therefore
    never left blocked waiting for a result that will not arrive.
    """

    class _Failure:
        """Queue item marking that the fetch callable raised ``error``."""

        def __init__(self, error: BaseException):
            self.error = error

    class _ResultQueue(queue.Queue):
        """Queue whose ``get`` re-raises failures recorded by the worker."""

        def get(self, block=True, timeout=None):
            item = super().get(block, timeout)
            if isinstance(item, ResultDownloader._Failure):
                raise item.error
            return item

    def __init__(self):
        self.queue: queue.Queue = ResultDownloader._ResultQueue()
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]):
        """Schedule ``fetch_func`` on the background worker.

        Raises:
            RuntimeError: if called outside the ``with`` block, i.e. when no
                executor is active.
        """
        # An explicit check instead of `assert`, which is stripped under -O.
        if self.executor is None:
            raise RuntimeError("ResultDownloader must be entered as a context manager before submit()")
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func):
        """Call ``fetch_func`` and queue its result, or the exception it raised."""
        try:
            result = fetch_func()
        except BaseException as error:
            # The future returned by the executor is discarded, so the error
            # would otherwise vanish and the consumer would block forever.
            self.queue.put(ResultDownloader._Failure(error))
        else:
            self.queue.put(result)

    def __enter__(self):
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        executor, self.executor = self.executor, None
        if executor is not None:
            executor.shutdown()
708
+
709
+
687
710
class TrinoResult(object):
688
711
"""
689
712
Represent the result of a Trino query as an iterator on rows.
@@ -711,16 +734,21 @@ def rownumber(self) -> int:
711
734
return self._rownumber
712
735
713
736
def __iter__(self):
    """Yield the result rows one by one, prefetching the next page in the background.

    While the rows of the current page are being yielded, the next page is
    downloaded on the ``ResultDownloader`` worker thread. ``self._rownumber``
    is incremented for every row yielded.
    """
    with ResultDownloader() as result_downloader:
        # A query only transitions to a FINISHED state when the results are fully consumed:
        # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.

        # Track locally whether a fetch is outstanding. `self._query.finished`
        # is written by the background fetch, so reading it while a fetch is
        # in flight could skip the `get()` and drop the final page.
        fetch_pending = not self._query.finished
        if fetch_pending:
            result_downloader.submit(self._query.fetch)

        while fetch_pending or self._rows is not None:
            next_rows = None
            if fetch_pending:
                # fetch() returns raw rows; map them like the other callers of fetch() do.
                # Mapping happens before the next fetch is submitted, so no fetch is
                # running while the query's row mapper is in use.
                next_rows = self._query.map_rows(result_downloader.queue.get())
                # The fetch has completed here, so `finished` is stable to read.
                fetch_pending = not self._query.finished
                if fetch_pending:
                    result_downloader.submit(self._query.fetch)

            # `or ()` tolerates the "not finished but no buffered rows" state
            # that the loop condition allows.
            for row in self._rows or ():
                self._rownumber += 1
                logger.debug("row %s", row)
                yield row

            self._rows = next_rows
724
752
725
753
726
754
class TrinoQuery(object):
@@ -753,7 +781,7 @@ def columns(self):
753
781
while not self._columns and not self.finished and not self.cancelled:
754
782
# Columns are not returned immediately after query is submitted.
755
783
# Continue fetching data until columns information is available and push fetched rows into buffer.
756
- self._result.rows += self.fetch()
784
+ self._result.rows += self.map_rows(self. fetch() )
757
785
return self._columns
758
786
759
787
@property
@@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
802
830
803
831
# Execute should block until at least one row is received or query is finished or cancelled
804
832
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
805
- self._result.rows += self.fetch()
833
+ self._result.rows += self.map_rows(self. fetch() )
806
834
return self._result
807
835
808
836
def _update_state(self, status):
@@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]:
822
850
logger.debug(status)
823
851
if status.next_uri is None:
824
852
self._finished = True
853
+ return status.rows
825
854
855
def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
    """Pass ``rows`` through this query's row mapper.

    Returns the result of ``self._row_mapper.map(rows)``, or an empty list
    when no row mapper is set (``rows`` is discarded in that case).
    """
    mapper = self._row_mapper
    return mapper.map(rows) if mapper else []
830
859
831
860
def cancel(self) -> None:
832
861
"""Cancel the current query"""
0 commit comments