@@ -10,7 +10,8 @@ from .telemetry import TelemetryField
 from .time_util import get_time_millis
 try:
     from pyarrow.ipc import open_stream
-    from .arrow_iterator import PyArrowChunkIterator
+    from pyarrow import concat_tables
+    from .arrow_iterator import PyArrowIterator, ROW_UNIT, TABLE_UNIT, EMPTY_UNIT
     from .arrow_context import ArrowConverterContext
 except ImportError:
     pass
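The unit constants imported here come from `arrow_iterator`, which this diff does not show. A minimal sketch of what they plausibly look like, assuming they are plain string markers (the exact values are an assumption, not confirmed by this diff):

    # Presumed definitions in arrow_iterator (not shown here; values assumed):
    ROW_UNIT = 'row'      # iterator yields one row at a time
    TABLE_UNIT = 'table'  # iterator yields a whole pyarrow Table per chunk
    EMPTY_UNIT = ''       # no iteration mode chosen yet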
@@ -32,6 +33,7 @@ cdef class ArrowResult:
         object _current_chunk_row
         object _chunk_downloader
         object _arrow_context
+        str _iter_unit
 
     def __init__(self, raw_response, cursor):
         self._reset()
@@ -51,9 +53,10 @@ cdef class ArrowResult:
5153 arrow_bytes = b64decode(rowset_b64)
5254 arrow_reader = open_stream(arrow_bytes)
5355 self ._arrow_context = ArrowConverterContext(self ._connection._session_parameters)
54- self ._current_chunk_row = PyArrowChunkIterator (arrow_reader, self ._arrow_context)
56+ self ._current_chunk_row = PyArrowIterator (arrow_reader, self ._arrow_context)
5557 else :
56- self ._current_chunk_row = iter ([])
58+ self ._current_chunk_row = iter (())
59+ self ._iter_unit = EMPTY_UNIT
5760
5861 if u ' chunks' in data:
5962 chunks = data[u ' chunks' ]
@@ -83,6 +86,13 @@ cdef class ArrowResult:
8386 return self
8487
8588 def __next__ (self ):
89+ if self ._iter_unit == EMPTY_UNIT:
90+ self ._iter_unit = ROW_UNIT
91+ self ._current_chunk_row.init(self ._iter_unit)
92+ elif self ._iter_unit == TABLE_UNIT:
93+ logger.debug(u ' The iterator has been built for fetching arrow table' )
94+ raise RuntimeError
95+
8696 is_done = False
8797 try :
8898 row = None
@@ -96,6 +106,7 @@ cdef class ArrowResult:
                         self._chunk_index, self._chunk_count)
                     next_chunk = self._chunk_downloader.next_chunk()
                     self._current_chunk_row = next_chunk.result_data
+                    self._current_chunk_row.init(self._iter_unit)
                     self._chunk_index += 1
                     try:
                         row = self._current_chunk_row.__next__()
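Taken together, the hunks above make row fetching and table fetching mutually exclusive on a single result: the first consumer locks `_iter_unit`, and the other mode then raises `RuntimeError`. A standalone sketch of that guard, assuming only string unit constants (this is illustrative, not connector code):

    EMPTY_UNIT, ROW_UNIT, TABLE_UNIT = '', 'row', 'table'

    class ModeGuard:
        def __init__(self):
            self.unit = EMPTY_UNIT

        def claim(self, unit):
            if self.unit == EMPTY_UNIT:
                self.unit = unit      # first caller picks the iteration mode
            elif self.unit != unit:
                raise RuntimeError    # the other mode already owns this result

    g = ModeGuard()
    g.claim(ROW_UNIT)                 # ok: result now fetches row by row
    try:
        g.claim(TABLE_UNIT)           # mixing modes is rejected
    except RuntimeError:
        print('already iterating by row')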
@@ -146,4 +157,88 @@ cdef class ArrowResult:
         self._chunk_count = 0
         self._chunk_downloader = None
         self._arrow_context = None
+        self._iter_unit = EMPTY_UNIT
+
+    def _fetch_arrow_batches(self):
+        '''
+        Fetch Arrow Tables in batches, where each 'batch' is a Snowflake chunk;
+        the batch size (the number of rows per table) may therefore vary.
+        '''
+        if self._iter_unit == EMPTY_UNIT:
+            self._iter_unit = TABLE_UNIT
+        elif self._iter_unit == ROW_UNIT:
+            logger.debug(u'The iterator has been built for fetching row')
+            raise RuntimeError
+
+        try:
+            self._current_chunk_row.init(self._iter_unit)  # AttributeError if it is iter(())
+            while self._chunk_index <= self._chunk_count:
+                table = self._current_chunk_row.__next__()
+                if self._chunk_index < self._chunk_count:  # multiple chunks
+                    logger.debug(
+                        u"chunk index: %s, chunk_count: %s",
+                        self._chunk_index, self._chunk_count)
+                    next_chunk = self._chunk_downloader.next_chunk()
+                    self._current_chunk_row = next_chunk.result_data
+                    self._current_chunk_row.init(self._iter_unit)
+                self._chunk_index += 1
+                yield table
+            else:
+                if self._chunk_count > 0 and \
+                        self._chunk_downloader is not None:
+                    self._chunk_downloader.terminate()
+                    self._cursor._log_telemetry_job_data(
+                        TelemetryField.TIME_DOWNLOADING_CHUNKS,
+                        self._chunk_downloader._total_millis_downloading_chunks)
+                    self._cursor._log_telemetry_job_data(
+                        TelemetryField.TIME_PARSING_CHUNKS,
+                        self._chunk_downloader._total_millis_parsing_chunks)
+                self._chunk_downloader = None
+                self._chunk_count = 0
+                self._current_chunk_row = iter(())
+        except AttributeError:
+            # just for handling the case of an empty result
+            return None
+        finally:
+            if self._cursor._first_chunk_time:
+                logger.info("fetching data into pandas dataframe done")
+                time_consume_last_result = get_time_millis() - self._cursor._first_chunk_time
+                self._cursor._log_telemetry_job_data(
+                    TelemetryField.TIME_CONSUME_LAST_RESULT,
+                    time_consume_last_result)
 
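`_fetch_arrow_batches` is a generator yielding one `pyarrow.Table` per chunk. A hedged usage sketch, assuming `result` is an `ArrowResult` instance (no public wrapper appears in this diff):

    # Assumes `result` is an ArrowResult; one pyarrow.Table per Snowflake chunk.
    total_rows = 0
    for table in result._fetch_arrow_batches():
        total_rows += table.num_rows  # pyarrow.Table.num_rows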
+    def _fetch_arrow_all(self):
+        '''
+        Fetch a single Arrow Table.
+        '''
+        tables = list(self._fetch_arrow_batches())
+        if tables:
+            return concat_tables(tables)
+        else:
+            return None
+
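`concat_tables` requires that all inputs share one schema, which holds here because every chunk of a result set carries the same result schema. A small self-contained example of the concatenation step:

    import pyarrow as pa

    t1 = pa.table({'c': [1, 2]})
    t2 = pa.table({'c': [3]})
    combined = pa.concat_tables([t1, t2])  # schemas match, as chunk schemas do
    assert combined.num_rows == 3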
+    def _fetch_pandas_batches(self):
+        '''
+        Fetch pandas dataframes in batches, where each 'batch' is a Snowflake chunk;
+        the batch size (the number of rows per dataframe) may therefore vary.
+        TODO: take a look at the pyarrow to_pandas() API, which provides some useful arguments,
+        e.g. 1. use `use_threads=True` for acceleration
+             2. use `strings_to_categorical` and `categories` to encode categorical data,
+                which is really different from `string` in data science.
+                For example, some data may be marked 0 and 1 as binary classes in a dataset,
+                which the user wishes to interpret as categorical data instead of integers.
+             3. use `zero_copy_only` to catch potentially unnecessary memory copying.
+        We'd better also provide these handy arguments to make data scientists happy :)
+        '''
+        for table in self._fetch_arrow_batches():
+            yield table.to_pandas()
+
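The keywords named in the TODO are real `pyarrow.Table.to_pandas` parameters. A sketch of how they could be passed through; the wrapper below is hypothetical, not part of this diff:

    # Hypothetical pass-through of the to_pandas() options named in the TODO:
    def _table_to_df(table, **kwargs):
        # use_threads=True parallelizes column conversion;
        # strings_to_categorical=True yields pandas Categorical columns;
        # zero_copy_only=True raises if a copy would be required.
        return table.to_pandas(**kwargs)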
+    def _fetch_pandas_all(self):
+        '''
+        Fetch a single pandas dataframe.
+        '''
+        table = self._fetch_arrow_all()
+        if table:
+            return table.to_pandas()
+        else:
+            return None
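End to end, the pandas path is simply chunk tables, one concatenated table, one dataframe. Illustratively, again assuming `result` is an `ArrowResult`:

    df = result._fetch_pandas_all()   # None for an empty result
    if df is not None:
        print(df.shape)               # rows across every chunk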