
Commit 85130be

sfc-gh-stakedaankit-bhatnagar167 authored and committed

SNOW-84977: Enable exposing pandas dataframe with arrow format

1 parent debbb3c · commit 85130be

14 files changed: +766 -46 lines changed

arrow_iterator.pyx

Lines changed: 53 additions & 17 deletions
@@ -9,39 +9,75 @@ from cpython.ref cimport PyObject
 
 logger = getLogger(__name__)
 
-cdef extern from "cpp/ArrowIterator/CArrowChunkIterator.hpp" namespace "sf":
-    cdef cppclass CArrowChunkIterator:
-        CArrowChunkIterator(PyObject* context)
+'''
+the unit of this iterator
+EMPTY_UNIT: default
+ROW_UNIT: fetch row by row if the user calls `fetchone()`
+TABLE_UNIT: fetch one arrow table if the user calls `fetch_pandas()`
+'''
+ROW_UNIT, TABLE_UNIT, EMPTY_UNIT = 'row', 'table', ''
+
 
+cdef extern from "cpp/ArrowIterator/CArrowIterator.hpp" namespace "sf":
+    cdef cppclass CArrowIterator:
         void addRecordBatch(PyObject * rb)
 
-        PyObject *nextRow();
+        PyObject* next();
 
         void reset();
 
 
-cdef class PyArrowChunkIterator:
-    cdef CArrowChunkIterator* cIterator
+cdef extern from "cpp/ArrowIterator/CArrowChunkIterator.hpp" namespace "sf":
+    cdef cppclass CArrowChunkIterator(CArrowIterator):
+        CArrowChunkIterator(PyObject* context) except +
+
+
+cdef extern from "cpp/ArrowIterator/CArrowTableIterator.hpp" namespace "sf":
+    cdef cppclass CArrowTableIterator(CArrowIterator):
+        CArrowTableIterator(PyObject* context) except +
+
+
+cdef class PyArrowIterator:
+    cdef object reader
+    cdef object context
+    cdef CArrowIterator* cIterator
+    cdef str unit
     cdef PyObject* cret
 
-    def __cinit__(PyArrowChunkIterator self, object arrow_stream_reader, object arrow_context):
-        self.cIterator = new CArrowChunkIterator(<PyObject*>arrow_context)
-        for rb in arrow_stream_reader:
-            self.cIterator.addRecordBatch(<PyObject*>rb)
-        self.cIterator.reset()
+    def __cinit__(self, object arrow_stream_reader, object arrow_context):
+        self.reader = arrow_stream_reader
+        self.context = arrow_context
+        self.cIterator = NULL
+        self.unit = ''
 
-    def __dealloc__(PyArrowChunkIterator self):
+    def __dealloc__(self):
         del self.cIterator
 
-    def __next__(PyArrowChunkIterator self):
-        cret = self.cIterator.nextRow()
-        if not cret:
-            logger.error("Internal error from CArrowChunkIterator\n")
+    def __next__(self):
+        self.cret = self.cIterator.next()
+
+        if not self.cret:
+            logger.error("Internal error from CArrowIterator\n")
             # it looks like this line can help us get into python and detect the global variable immediately
             # however, this log will not show up for unclear reason
-        ret = <object>cret
+        ret = <object>self.cret
 
         if ret is None:
             raise StopIteration
         else:
             return ret
+
+    def init(self, str iter_unit):
+        # init chunk (row) iterator or table iterator
+        if iter_unit != ROW_UNIT and iter_unit != TABLE_UNIT:
+            raise NotImplementedError
+        elif iter_unit == ROW_UNIT:
+            self.cIterator = new CArrowChunkIterator(<PyObject*>self.context)
+        elif iter_unit == TABLE_UNIT:
+            self.cIterator = new CArrowTableIterator(<PyObject*>self.context)
+        self.unit = iter_unit
+
+        # read record batches from the stream reader into the C++ iterator
+        for rb in self.reader:
+            self.cIterator.addRecordBatch(<PyObject*>rb)
+        self.cIterator.reset()
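
The rewritten iterator splits construction into two phases: __cinit__ only stores the reader and the converter context (cIterator stays NULL), while init() picks the C++ iterator class from the requested unit and only then feeds it the record batches. A minimal sketch of how a caller drives this, assuming the compiled extension module and a context object are at hand (arrow_bytes stands in for an Arrow IPC payload; only PyArrowIterator, ROW_UNIT, and TABLE_UNIT come from this diff):

    # Sketch only, not part of the commit.
    from pyarrow.ipc import open_stream
    from snowflake.connector.arrow_iterator import PyArrowIterator, ROW_UNIT

    reader = open_stream(arrow_bytes)            # arrow_bytes: Arrow IPC stream payload
    it = PyArrowIterator(reader, arrow_context)  # __cinit__: stores refs, no C++ iterator yet
    it.init(ROW_UNIT)                            # builds CArrowChunkIterator and loads batches
    row = next(it)                               # one row as a Python tuple; StopIteration at end

Passing TABLE_UNIT to init() instead selects CArrowTableIterator, whose next() returns one Arrow table per chunk; that is the mode the new fetch methods in arrow_result.pyx rely on.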

arrow_result.pyx

Lines changed: 98 additions & 3 deletions
@@ -10,7 +10,8 @@ from .telemetry import TelemetryField
 from .time_util import get_time_millis
 try:
     from pyarrow.ipc import open_stream
-    from .arrow_iterator import PyArrowChunkIterator
+    from pyarrow import concat_tables
+    from .arrow_iterator import PyArrowIterator, ROW_UNIT, TABLE_UNIT, EMPTY_UNIT
     from .arrow_context import ArrowConverterContext
 except ImportError:
     pass
@@ -32,6 +33,7 @@ cdef class ArrowResult:
     object _current_chunk_row
     object _chunk_downloader
     object _arrow_context
+    str _iter_unit
 
     def __init__(self, raw_response, cursor):
         self._reset()
@@ -51,9 +53,10 @@
             arrow_bytes = b64decode(rowset_b64)
             arrow_reader = open_stream(arrow_bytes)
             self._arrow_context = ArrowConverterContext(self._connection._session_parameters)
-            self._current_chunk_row = PyArrowChunkIterator(arrow_reader, self._arrow_context)
+            self._current_chunk_row = PyArrowIterator(arrow_reader, self._arrow_context)
         else:
-            self._current_chunk_row = iter([])
+            self._current_chunk_row = iter(())
+            self._iter_unit = EMPTY_UNIT
 
         if u'chunks' in data:
             chunks = data[u'chunks']
@@ -83,6 +86,13 @@ cdef class ArrowResult:
         return self
 
     def __next__(self):
+        if self._iter_unit == EMPTY_UNIT:
+            self._iter_unit = ROW_UNIT
+            self._current_chunk_row.init(self._iter_unit)
+        elif self._iter_unit == TABLE_UNIT:
+            logger.debug(u'The iterator has been built for fetching arrow table')
+            raise RuntimeError
+
         is_done = False
         try:
             row = None
@@ -96,6 +106,7 @@
                         self._chunk_index, self._chunk_count)
                 next_chunk = self._chunk_downloader.next_chunk()
                 self._current_chunk_row = next_chunk.result_data
+                self._current_chunk_row.init(self._iter_unit)
                 self._chunk_index += 1
                 try:
                     row = self._current_chunk_row.__next__()
@@ -146,4 +157,88 @@
         self._chunk_count = 0
         self._chunk_downloader = None
         self._arrow_context = None
+        self._iter_unit = EMPTY_UNIT
+
+    def _fetch_arrow_batches(self):
+        '''
+        Fetch Arrow Tables in batches, where a 'batch' is one Snowflake chunk;
+        thus, the batch size (the number of rows in each table) may vary
+        '''
+        if self._iter_unit == EMPTY_UNIT:
+            self._iter_unit = TABLE_UNIT
+        elif self._iter_unit == ROW_UNIT:
+            logger.debug(u'The iterator has been built for fetching rows')
+            raise RuntimeError
+
+        try:
+            self._current_chunk_row.init(self._iter_unit)  # AttributeError if it is iter(())
+            while self._chunk_index <= self._chunk_count:
+                table = self._current_chunk_row.__next__()
+                if self._chunk_index < self._chunk_count:  # multiple chunks
+                    logger.debug(
+                        u"chunk index: %s, chunk_count: %s",
+                        self._chunk_index, self._chunk_count)
+                    next_chunk = self._chunk_downloader.next_chunk()
+                    self._current_chunk_row = next_chunk.result_data
+                    self._current_chunk_row.init(self._iter_unit)
+                self._chunk_index += 1
+                yield table
+            else:
+                if self._chunk_count > 0 and \
+                        self._chunk_downloader is not None:
+                    self._chunk_downloader.terminate()
+                    self._cursor._log_telemetry_job_data(
+                        TelemetryField.TIME_DOWNLOADING_CHUNKS,
+                        self._chunk_downloader._total_millis_downloading_chunks)
+                    self._cursor._log_telemetry_job_data(
+                        TelemetryField.TIME_PARSING_CHUNKS,
+                        self._chunk_downloader._total_millis_parsing_chunks)
+                self._chunk_downloader = None
+                self._chunk_count = 0
+                self._current_chunk_row = iter(())
+        except AttributeError:
+            # just for handling the case of an empty result
+            return None
+        finally:
+            if self._cursor._first_chunk_time:
+                logger.info("fetching data into pandas dataframe done")
+                time_consume_last_result = get_time_millis() - self._cursor._first_chunk_time
+                self._cursor._log_telemetry_job_data(
+                    TelemetryField.TIME_CONSUME_LAST_RESULT,
+                    time_consume_last_result)
 
+    def _fetch_arrow_all(self):
+        '''
+        Fetch a single Arrow Table
+        '''
+        tables = list(self._fetch_arrow_batches())
+        if tables:
+            return concat_tables(tables)
+        else:
+            return None
+
+    def _fetch_pandas_batches(self):
+        '''
+        Fetch pandas dataframes in batches, where a 'batch' is one Snowflake chunk;
+        thus, the batch size (the number of rows in each dataframe) may vary
+        TODO: take a look at the pyarrow to_pandas() API, which provides some useful arguments
+        e.g. 1. use `use_threads=True` for acceleration
+             2. use `strings_to_categorical` and `categories` to encode categorical data,
+                which is really different from `string` in data science.
+                For example, some data may be marked as 0 and 1 as binary classes in a dataset,
+                which the user wishes to interpret as categorical data instead of integers.
+             3. use `zero_copy_only` to catch potential unnecessary memory copying
+        we'd better also provide these handy arguments to make data scientists happy :)
+        '''
+        for table in self._fetch_arrow_batches():
+            yield table.to_pandas()
+
+    def _fetch_pandas_all(self):
+        '''
+        Fetch a single pandas dataframe
+        '''
+        table = self._fetch_arrow_all()
+        if table:
+            return table.to_pandas()
+        else:
+            return None
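
Together these give ArrowResult four private entry points: _fetch_arrow_batches, _fetch_arrow_all, _fetch_pandas_batches, and _fetch_pandas_all. A sketch of how they might be consumed, plus the pyarrow to_pandas() keyword arguments the TODO above refers to; note that `result` here is an ArrowResult instance and any public cursor-level wrapper is an assumption, not part of this diff:

    # Sketch only: `result` is an ArrowResult instance (assumed reachable from
    # the cursor); the private methods are exactly those added above.
    df_all = result._fetch_pandas_all()         # one DataFrame, or None if the result is empty

    for df in result._fetch_pandas_batches():   # one DataFrame per Snowflake chunk
        print(df.shape)

    # The TODO in _fetch_pandas_batches points at options pyarrow already accepts:
    table = result._fetch_arrow_all()
    if table is not None:
        df = table.to_pandas(
            use_threads=True,              # convert columns in parallel
            strings_to_categorical=False,  # keep strings as object dtype
            zero_copy_only=False,          # permit copies where conversion needs them
        )

Note the unit guard: once a result has been iterated row by row, _fetch_arrow_batches raises RuntimeError rather than silently re-reading, and __next__ does the same after table-mode fetching.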

chunk_downloader.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 
 try:
     from pyarrow.ipc import open_stream
-    from .arrow_iterator import PyArrowChunkIterator
+    from .arrow_iterator import PyArrowIterator
     from .arrow_context import ArrowConverterContext
 except ImportError:
     pass
@@ -332,5 +332,5 @@ def __init__(self, meta, connection):
     def to_iterator(self, raw_data_fd, download_time):
         gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r')
         reader = open_stream(gzip_decoder)
-        it = PyArrowChunkIterator(reader, self._arrow_context)
+        it = PyArrowIterator(reader, self._arrow_context)
         return it

cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 3 additions & 5 deletions
@@ -16,7 +16,6 @@
 
 namespace sf
 {
-Logger CArrowChunkIterator::logger("snowflake.connector.CArrowChunkIterator");
 
 CArrowChunkIterator::CArrowChunkIterator(PyObject* context)
 : m_latestReturnedRow(nullptr), m_context(context)
@@ -25,9 +24,8 @@ CArrowChunkIterator::CArrowChunkIterator(PyObject* context)
 
 void CArrowChunkIterator::addRecordBatch(PyObject* rb)
 {
-  std::shared_ptr<arrow::RecordBatch> cRecordBatch;
-  arrow::Status status = arrow::py::unwrap_record_batch(rb, &cRecordBatch);
-  m_cRecordBatches.push_back(cRecordBatch);
+  // may add some specific behaviors for this iterator
+  CArrowIterator::addRecordBatch(rb);
 }
 
 void CArrowChunkIterator::reset()
@@ -43,7 +41,7 @@ void CArrowChunkIterator::reset()
                            m_columnCount);
 }
 
-PyObject* CArrowChunkIterator::nextRow()
+PyObject* CArrowChunkIterator::next()
 {
   m_rowIndexInBatch++;
 

cpp/ArrowIterator/CArrowChunkIterator.hpp

Lines changed: 5 additions & 15 deletions
@@ -4,13 +4,8 @@
 #ifndef PC_ARROWCHUNKITERATOR_HPP
 #define PC_ARROWCHUNKITERATOR_HPP
 
-#include <Python.h>
-#include <vector>
-#include <arrow/python/platform.h>
-#include <arrow/api.h>
-#include <arrow/python/pyarrow.h>
+#include "CArrowIterator.hpp"
 #include "IColumnConverter.hpp"
-#include "logging.hpp"
 #include "Python/Common.hpp"
 
 namespace sf
@@ -21,7 +16,7 @@ namespace sf
 * iterator object)
 * will ask for nextRow to be returned back to Python
 */
-class CArrowChunkIterator
+class CArrowChunkIterator : public CArrowIterator
 {
 public:
   /**
@@ -38,19 +33,16 @@ class CArrowChunkIterator
   * Add Arrow RecordBatch to current chunk
   * @param rb recordbatch to be added
   */
-  void addRecordBatch(PyObject* rb);
+  void addRecordBatch(PyObject* rb) override;
 
   /**
   * @return a python tuple object which contains all data in current row
   */
-  PyObject* nextRow();
+  PyObject* next() override;
 
-  void reset();
+  void reset() override;
 
 private:
-  /** list of all record batch in current chunk */
-  std::vector<std::shared_ptr<arrow::RecordBatch>> m_cRecordBatches;
-
   /** number of columns */
   int m_columnCount;
 
@@ -80,8 +72,6 @@ class CArrowChunkIterator
   */
   void currentRowAsTuple();
 
-  static Logger logger;
-
   void initColumnConverters();
 };
 }
cpp/ArrowIterator/CArrowIterator.cpp

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2013-2019 Snowflake Computing
+ */
+
+#include "CArrowIterator.hpp"
+
+namespace sf
+{
+
+Logger CArrowIterator::logger("snowflake.connector.CArrowIterator");
+
+void CArrowIterator::addRecordBatch(PyObject* rb)
+{
+  std::shared_ptr<arrow::RecordBatch> cRecordBatch;
+  arrow::Status status = arrow::py::unwrap_record_batch(rb, &cRecordBatch);
+  m_cRecordBatches.push_back(cRecordBatch);
+}
+
+}
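
This new translation unit is the base-class implementation the earlier diffs delegate to; its header, CArrowIterator.hpp, is referenced but not shown in this commit view. The refactor moves the shared record-batch storage (m_cRecordBatches) and the logger up from CArrowChunkIterator so the new CArrowTableIterator can reuse them, while next() and reset() are presumably left pure virtual. Restated as a Python sketch for orientation (names mirror the C++ diffs; this is not generated code):

    # Sketch only: the C++ class hierarchy, expressed in Python terms.
    from abc import ABC, abstractmethod

    class ArrowIteratorBase(ABC):               # ~ sf::CArrowIterator
        def __init__(self):
            self.record_batches = []            # ~ m_cRecordBatches

        def add_record_batch(self, rb):         # ~ CArrowIterator::addRecordBatch
            self.record_batches.append(rb)

        @abstractmethod
        def next(self): ...                     # row tuple vs. whole table, per subclass

        @abstractmethod
        def reset(self): ...

    class ChunkIterator(ArrowIteratorBase):     # ~ CArrowChunkIterator: row by row
        def next(self): ...
        def reset(self): ...

    class TableIterator(ArrowIteratorBase):     # ~ CArrowTableIterator: one table per chunk
        def next(self): ...
        def reset(self): ...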
